Java examples for Big Data: Apache Spark
Random Forest Regression via Apache Spark
import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.util.List; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.RandomForestRegressionModel; import org.apache.spark.ml.regression.RandomForestRegressor; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class RandomForestRegression { public static void main(String[] args) throws IOException { SparkConf conf = new SparkConf().setAppName( "RandomForestRegression").setMaster("local"); JavaSparkContext sc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(sc); DataFrame traindata = sqlContext.read().format("libsvm") .load("finalTrainFiles/*"); VectorIndexerModel featureIndexer = new VectorIndexer() .setInputCol("features").setOutputCol("indexedFeatures") .setMaxCategories(4).fit(traindata); DataFrame[] splits = traindata//from ww w.j a va 2 s .com .randomSplit(new double[] { 0.8, 0.2 }); DataFrame train = splits[0]; DataFrame validate = splits[1]; RandomForestRegressor rf = new RandomForestRegressor().setLabelCol( "label").setFeaturesCol("indexedFeatures"); Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { featureIndexer, rf }); PipelineModel model = pipeline.fit(train); DataFrame predictions = model.transform(validate); List<Row> mylist = predictions.select("prediction", "label") .collectAsList(); StringBuilder sb; for (Row r : mylist) { sb = new StringBuilder(); String entry = r.toString(); sb.append(entry.substring(1, entry.length() - 1)) .append("\r\n"); FileWriter writer = new FileWriter("prediction.txt", true); 
writer.append(sb.toString()); writer.flush(); writer.close(); } RegressionEvaluator evaluator = new RegressionEvaluator() .setLabelCol("label").setPredictionCol("prediction") .setMetricName("rmse"); double rmse = evaluator.evaluate(predictions); System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); RandomForestRegressionModel rfModel = (RandomForestRegressionModel) (model .stages()[1]); } }