Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.lens.ml.algo.spark; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import org.apache.lens.server.api.error.LensException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.io.WritableComparable; import org.apache.hive.hcatalog.data.HCatRecord; import org.apache.hive.hcatalog.data.schema.HCatFieldSchema; import org.apache.hive.hcatalog.data.schema.HCatSchema; import org.apache.hive.hcatalog.mapreduce.HCatInputFormat; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.rdd.RDD; import com.google.common.base.Preconditions; import lombok.Getter; import lombok.ToString; /** * The Class TableTrainingSpec. */ @ToString public class TableTrainingSpec implements Serializable { /** The Constant LOG. */ public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class); /** The training rdd. */ @Getter private transient RDD<LabeledPoint> trainingRDD; /** The testing rdd. */ @Getter private transient RDD<LabeledPoint> testingRDD; /** The database. */ @Getter private String database; /** The table. */ @Getter private String table; /** The partition filter. */ @Getter private String partitionFilter; /** The feature columns. */ @Getter private List<String> featureColumns; /** The label column. */ @Getter private String labelColumn; /** The conf. */ @Getter private transient HiveConf conf; // By default all samples are considered for training /** The split training. */ private boolean splitTraining; /** The training fraction. */ private double trainingFraction = 1.0; /** The label pos. */ int labelPos; /** The feature positions. */ int[] featurePositions; /** The num features. */ int numFeatures; /** The labeled rdd. */ transient JavaRDD<LabeledPoint> labeledRDD; /** * New builder. * * @return the table training spec builder */ public static TableTrainingSpecBuilder newBuilder() { return new TableTrainingSpecBuilder(); } /** * The Class TableTrainingSpecBuilder. */ public static class TableTrainingSpecBuilder { /** The spec. */ final TableTrainingSpec spec; /** * Instantiates a new table training spec builder. */ public TableTrainingSpecBuilder() { spec = new TableTrainingSpec(); } /** * Hive conf. * * @param conf the conf * @return the table training spec builder */ public TableTrainingSpecBuilder hiveConf(HiveConf conf) { spec.conf = conf; return this; } /** * Database. * * @param db the db * @return the table training spec builder */ public TableTrainingSpecBuilder database(String db) { spec.database = db; return this; } /** * Table. * * @param table the table * @return the table training spec builder */ public TableTrainingSpecBuilder table(String table) { spec.table = table; return this; } /** * Partition filter. * * @param partFilter the part filter * @return the table training spec builder */ public TableTrainingSpecBuilder partitionFilter(String partFilter) { spec.partitionFilter = partFilter; return this; } /** * Label column. * * @param labelColumn the label column * @return the table training spec builder */ public TableTrainingSpecBuilder labelColumn(String labelColumn) { spec.labelColumn = labelColumn; return this; } /** * Feature columns. * * @param featureColumns the feature columns * @return the table training spec builder */ public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) { spec.featureColumns = featureColumns; return this; } /** * Builds the. * * @return the table training spec */ public TableTrainingSpec build() { return spec; } /** * Training fraction. * * @param trainingFraction the training fraction * @return the table training spec builder */ public TableTrainingSpecBuilder trainingFraction(double trainingFraction) { Preconditions.checkArgument(trainingFraction >= 0 && trainingFraction <= 1.0, "Training fraction shoule be between 0 and 1"); spec.trainingFraction = trainingFraction; spec.splitTraining = true; return this; } } /** * The Class DataSample. */ public static class DataSample implements Serializable { /** The labeled point. */ private final LabeledPoint labeledPoint; /** The sample. */ private final double sample; /** * Instantiates a new data sample. * * @param labeledPoint the labeled point */ public DataSample(LabeledPoint labeledPoint) { sample = Math.random(); this.labeledPoint = labeledPoint; } } /** * The Class TrainingFilter. */ public static class TrainingFilter implements Function<DataSample, Boolean> { /** The training fraction. */ private double trainingFraction; /** * Instantiates a new training filter. * * @param fraction the fraction */ public TrainingFilter(double fraction) { trainingFraction = fraction; } /* * (non-Javadoc) * * @see org.apache.spark.api.java.function.Function#call(java.lang.Object) */ @Override public Boolean call(DataSample v1) throws Exception { return v1.sample <= trainingFraction; } } /** * The Class TestingFilter. */ public static class TestingFilter implements Function<DataSample, Boolean> { /** The training fraction. */ private double trainingFraction; /** * Instantiates a new testing filter. * * @param fraction the fraction */ public TestingFilter(double fraction) { trainingFraction = fraction; } /* * (non-Javadoc) * * @see org.apache.spark.api.java.function.Function#call(java.lang.Object) */ @Override public Boolean call(DataSample v1) throws Exception { return v1.sample > trainingFraction; } } /** * The Class GetLabeledPoint. */ public static class GetLabeledPoint implements Function<DataSample, LabeledPoint> { /* * (non-Javadoc) * * @see org.apache.spark.api.java.function.Function#call(java.lang.Object) */ @Override public LabeledPoint call(DataSample v1) throws Exception { return v1.labeledPoint; } } /** * Validate. * * @return true, if successful */ boolean validate() { List<HCatFieldSchema> columns; try { HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter); HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf); columns = tableSchema.getFields(); } catch (IOException exc) { LOG.error("Error getting table info " + toString(), exc); return false; } LOG.info(table + " columns " + columns.toString()); boolean valid = false; if (columns != null && !columns.isEmpty()) { // Check labeled column List<String> columnNames = new ArrayList<String>(); for (HCatFieldSchema col : columns) { columnNames.add(col.getName()); } // Need at least one feature column and one label column valid = columnNames.contains(labelColumn) && columnNames.size() > 1; if (valid) { labelPos = columnNames.indexOf(labelColumn); // Check feature columns if (featureColumns == null || featureColumns.isEmpty()) { // feature columns are not provided, so all columns except label column are feature columns featurePositions = new int[columnNames.size() - 1]; int p = 0; for (int i = 0; i < columnNames.size(); i++) { if (i == labelPos) { continue; } featurePositions[p++] = i; } columnNames.remove(labelPos); featureColumns = columnNames; } else { // Feature columns were provided, verify all feature columns are present in the table valid = columnNames.containsAll(featureColumns); if (valid) { // Get feature positions featurePositions = new int[featureColumns.size()]; for (int i = 0; i < featureColumns.size(); i++) { featurePositions[i] = columnNames.indexOf(featureColumns.get(i)); } } } numFeatures = featureColumns.size(); } } return valid; } /** * Creates the rd ds. * * @param sparkContext the spark context * @throws LensException the lens exception */ public void createRDDs(JavaSparkContext sparkContext) throws LensException { // Validate the spec if (!validate()) { throw new LensException("Table spec not valid: " + toString()); } LOG.info("Creating RDDs with spec " + toString()); // Get the RDD for table JavaPairRDD<WritableComparable, HCatRecord> tableRDD; try { tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter); } catch (IOException e) { throw new LensException(e); } // Map into trainable RDD // TODO: Figure out a way to use custom value mappers FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures]; final DoubleValueMapper doubleMapper = new DoubleValueMapper(); for (int i = 0; i < numFeatures; i++) { valueMappers[i] = doubleMapper; } ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos, numFeatures, 0); labeledRDD = tableRDD.map(trainPrepFunction); if (splitTraining) { // We have to split the RDD between a training RDD and a testing RDD LOG.info("Splitting RDD for table " + database + "." + table + " with split fraction " + trainingFraction); JavaRDD<DataSample> sampledRDD = labeledRDD.map(new Function<LabeledPoint, DataSample>() { @Override public DataSample call(LabeledPoint v1) throws Exception { return new DataSample(v1); } }); trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd(); testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd(); } else { LOG.info("Using same RDD for train and test"); trainingRDD = labeledRDD.rdd(); testingRDD = trainingRDD; } LOG.info("Generated RDDs"); } }