Random Forest Regression via Apache Spark - Java Big Data

Java examples for Big Data: Apache Spark

Description

Random Forest Regression via Apache Spark

Demo Code


import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

/**
 * Demo: trains a random-forest regression pipeline with Spark ML on
 * libsvm-formatted data, appends the validation predictions to
 * {@code prediction.txt}, and prints the RMSE of those predictions.
 */
public class RandomForestRegression {

    /** File that receives one "prediction,label" line per validation row. */
    private static final String PREDICTION_FILE = "prediction.txt";

    /**
     * Runs the whole train / predict / evaluate flow on a local Spark master.
     *
     * @param args unused
     * @throws IOException if the prediction file cannot be written
     */
    public static void main(String[] args) throws IOException {

        SparkConf conf = new SparkConf().setAppName(
                "RandomForestRegression").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            SQLContext sqlContext = new SQLContext(sc);
            DataFrame traindata = sqlContext.read().format("libsvm")
                    .load("finalTrainFiles/*");

            // The indexer is fitted on the full data set (before the split),
            // so it has seen every feature value that can occur in either
            // the training or the validation part.
            VectorIndexerModel featureIndexer = new VectorIndexer()
                    .setInputCol("features").setOutputCol("indexedFeatures")
                    .setMaxCategories(4).fit(traindata);

            // 80% for training, 20% held out for validation.
            DataFrame[] splits = traindata
                    .randomSplit(new double[] { 0.8, 0.2 });
            DataFrame train = splits[0];
            DataFrame validate = splits[1];

            RandomForestRegressor rf = new RandomForestRegressor()
                    .setLabelCol("label").setFeaturesCol("indexedFeatures");

            Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {
                    featureIndexer, rf });

            PipelineModel model = pipeline.fit(train);

            DataFrame predictions = model.transform(validate);

            writePredictions(predictions.select("prediction", "label")
                    .collectAsList());

            RegressionEvaluator evaluator = new RegressionEvaluator()
                    .setLabelCol("label").setPredictionCol("prediction")
                    .setMetricName("rmse");
            double rmse = evaluator.evaluate(predictions);
            System.out.println("Root Mean Squared Error (RMSE) on test data = "
                    + rmse);
        } finally {
            // Release the Spark context even when training or I/O fails.
            sc.stop();
        }
    }

    /**
     * Appends one line per row to {@link #PREDICTION_FILE}.
     *
     * Each line is the row's {@code toString()} with its first and last
     * character removed, followed by {@code "\r\n"}. The file is opened once
     * in append mode and is closed even if a write fails.
     *
     * @param rows collected (prediction, label) rows; nothing is written and
     *             the file is not touched when the list is empty
     * @throws IOException if the file cannot be opened or written
     */
    private static void writePredictions(List<Row> rows) throws IOException {
        if (rows.isEmpty()) {
            return;
        }
        try (FileWriter writer = new FileWriter(PREDICTION_FILE, true)) {
            for (Row r : rows) {
                String entry = r.toString();
                writer.append(entry.substring(1, entry.length() - 1))
                        .append("\r\n");
            }
        }
    }

}

Related Tutorials