Java tutorial
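
The class below, MLUpdate, is the batch-layer update step of a Spark-based lambda architecture. On each run it caches the new and historical data, enumerates hyperparameter combinations, builds and evaluates one candidate PMML model per combination (optionally in parallel), promotes the best candidate into a timestamped model directory, and publishes the winning model to a model-update topic.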
package com.anhth12.lambda.ml;

import com.anhth12.lambda.BatchLayerUpdate;
import com.anhth12.lambda.TopicProducer;
import com.anhth12.lambda.common.collection.Pair;
import com.anhth12.lambda.common.pmml.PMMLUtils;
import com.anhth12.lambda.common.random.RandomManager;
import com.anhth12.lambda.fn.Functions;
import com.anhth12.lambda.ml.param.HyperParamValues;
import com.anhth12.lambda.ml.param.HyperParams;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Batch-layer update that builds candidate PMML models over a hyperparameter grid,
 * evaluates them, keeps the best one, and publishes it to the model-update topic.
 *
 * @author Tong Hoang Anh
 * @param <M> type of message (data point) the model is built from
 */
public abstract class MLUpdate<M> implements BatchLayerUpdate<String, M, String> {

    public static final Logger log = LoggerFactory.getLogger(MLUpdate.class);

    public static final String MODEL_FILE_NAME = "model.pmml.gz";

    private final double testFraction;
    private final int candidates;
    private final int evalParallelism;

    protected MLUpdate(Config config) {
        this.testFraction = config.getDouble("lambda.ml.eval.test-fraction");
        this.candidates = config.getInt("lambda.ml.eval.candidates");
        this.evalParallelism = config.getInt("lambda.ml.eval.parallism");
        Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0);
        Preconditions.checkArgument(candidates > 0);
        Preconditions.checkArgument(evalParallelism > 0);
    }

    protected final double getTestFraction() {
        return testFraction;
    }

    /**
     * @return hyperparameter value ranges to try for each candidate; empty by default
     */
    public List<HyperParamValues<?>> getHyperParamValues() {
        return Collections.emptyList();
    }

    public abstract PMML buildModel(JavaSparkContext sparkContext,
                                    JavaRDD<M> trainData,
                                    List<?> hyperParameters,
                                    Path candidatePath);

    public void publishAdditionalModelData(JavaSparkContext sparkContext,
                                           PMML pmml,
                                           JavaRDD<M> newData,
                                           JavaRDD<M> pastData,
                                           Path modelParentPath,
                                           TopicProducer<String, String> modelUpdateTopic) {
        // nothing to do by default
    }

    public abstract double evaluate(JavaSparkContext sparkContext,
                                    PMML model,
                                    Path modelParentPath,
                                    JavaRDD<M> testData,
                                    JavaRDD<M> trainData);

    @Override
    public void runUpdate(JavaSparkContext sparkContext,
                          long timestamp,
                          JavaPairRDD<String, M> newKeyMessageData,
                          JavaPairRDD<String, M> pastKeyMessageData,
                          String modelDirString,
                          TopicProducer<String, String> modelUpdateTopic)
            throws IOException, InterruptedException {

        Preconditions.checkNotNull(newKeyMessageData);

        JavaRDD<M> newData = newKeyMessageData.values();
        JavaRDD<M> pastData = pastKeyMessageData == null ? null : pastKeyMessageData.values();

        // Cache and materialize the inputs, since every candidate reuses them
        if (newData != null) {
            newData.cache();
            newData.foreachPartition(Functions.<Iterator<M>>noOp());
        }
        if (pastData != null) {
            pastData.cache();
            pastData.foreachPartition(Functions.<Iterator<M>>noOp());
        }

        List<HyperParamValues<?>> hyperParamValues = getHyperParamValues();
        int valuesPerHyperParam = HyperParams.chooseValuesPerHyperParam(hyperParamValues.size(), candidates);
        List<List<?>> hyperParameterCombos =
                HyperParams.chooseHyperParameterCombos(hyperParamValues, candidates, valuesPerHyperParam);

        FileSystem fs = FileSystem.get(sparkContext.hadoopConfiguration());

        Path modelDir = new Path(modelDirString);
        Path tempModelPath = new Path(modelDir, ".temporary");
        Path candidatesPath = new Path(tempModelPath, Long.toString(System.currentTimeMillis()));
        fs.mkdirs(candidatesPath);

        Path bestCandidatePath =
                findBestCandidatePath(sparkContext, newData, pastData, hyperParameterCombos, candidatesPath);

        Path finalPath = new Path(modelDir, Long.toString(System.currentTimeMillis()));
        if (bestCandidatePath == null) {
            log.info("Unable to build any model");
        } else {
            // Promote the best candidate out of the temporary directory
            fs.rename(bestCandidatePath, finalPath);
        }
        fs.delete(candidatesPath, true);

        Path bestModelPath = new Path(finalPath, MODEL_FILE_NAME);
        if (fs.exists(bestModelPath)) {
            PMML bestModel;
            try (InputStream in = new GZIPInputStream(fs.open(bestModelPath), 1 << 16)) {
                bestModel = PMMLUtils.read(in);
            }
            modelUpdateTopic.send("MODEL", PMMLUtils.toString(bestModel));
            publishAdditionalModelData(sparkContext, bestModel, newData, pastData, finalPath, modelUpdateTopic);
        }

        if (newData != null) {
            newData.unpersist();
        }
        if (pastData != null) {
            pastData.unpersist();
        }
    }

    private Path findBestCandidatePath(JavaSparkContext sparkContext,
                                       JavaRDD<M> newData,
                                       JavaRDD<M> pastData,
                                       List<List<?>> hyperParameterCombos,
                                       Path candidatesPath) throws InterruptedException, IOException {
        Map<Path, Double> pathToEval = new HashMap<>(candidates);
        if (evalParallelism > 1) {
            // Build and evaluate candidates concurrently
            Collection<Future<Tuple2<Path, Double>>> futures = new ArrayList<>(candidates);
            ExecutorService executor = Executors.newFixedThreadPool(evalParallelism);
            try {
                for (int i = 0; i < candidates; i++) {
                    futures.add(executor.submit(new BuildAndEvalWorker(
                            i, hyperParameterCombos, sparkContext, newData, pastData, candidatesPath)));
                }
            } finally {
                executor.shutdown();
            }
            for (Future<Tuple2<Path, Double>> future : futures) {
                Tuple2<Path, Double> pathEval;
                try {
                    pathEval = future.get();
                } catch (ExecutionException ex) {
                    throw new IllegalStateException(ex);
                }
                pathToEval.put(pathEval._1, pathEval._2);
            }
        } else {
            // Build and evaluate candidates sequentially
            for (int i = 0; i < candidates; i++) {
                Tuple2<Path, Double> pathEval = new BuildAndEvalWorker(
                        i, hyperParameterCombos, sparkContext, newData, pastData, candidatesPath).call();
                pathToEval.put(pathEval._1, pathEval._2);
            }
        }

        // Pick the candidate with the highest evaluation whose output actually exists
        FileSystem fs = FileSystem.get(sparkContext.hadoopConfiguration());
        Path bestCandidatePath = null;
        double bestEval = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Path, Double> pathEval : pathToEval.entrySet()) {
            Path path = pathEval.getKey();
            Double eval = pathEval.getValue();
            if (fs.exists(path) && (bestCandidatePath == null || (eval != null && eval > bestEval))) {
                log.info("Best eval / path is now {} / {}", eval, path);
                if (eval != null) {
                    bestEval = eval;
                }
                bestCandidatePath = path;
            }
        }
        return bestCandidatePath;
    }

    final class BuildAndEvalWorker implements Callable<Tuple2<Path, Double>> {

        private final int i;
        private final List<List<?>> hyperParameterCombos;
        private final JavaSparkContext sparkContext;
        private final JavaRDD<M> newData;
        private final JavaRDD<M> pastData;
        private final Path candidatesPath;

        public BuildAndEvalWorker(int i,
                                  List<List<?>> hyperParameterCombos,
                                  JavaSparkContext sparkContext,
                                  JavaRDD<M> newData,
                                  JavaRDD<M> pastData,
                                  Path candidatesPath) {
            this.i = i;
            this.hyperParameterCombos = hyperParameterCombos;
            this.sparkContext = sparkContext;
            this.newData = newData;
            this.pastData = pastData;
            this.candidatesPath = candidatesPath;
        }

        @Override
        public Tuple2<Path, Double> call() throws IOException {
            List<?> hyperParameters = hyperParameterCombos.get(i % hyperParameterCombos.size());
            Path candidatePath = new Path(candidatesPath, Integer.toString(i));
            log.info("Building candidate {} with params {}", i, hyperParameters);

            Pair<JavaRDD<M>, JavaRDD<M>> trainTestData = splitTrainTest(newData, pastData);
            JavaRDD<M> allTrainData = trainTestData.getFirst();
            JavaRDD<M> testData = trainTestData.getSecond();

            Double eval = null;
            if (empty(allTrainData)) {
                log.info("No train data to build model");
            } else {
                PMML model = buildModel(sparkContext, allTrainData, hyperParameters, candidatePath);
                if (model == null) {
                    log.info("Unable to build a model");
                } else {
                    Path modelPath = new Path(candidatePath, MODEL_FILE_NAME);
                    log.info("Writing model to {}", modelPath);
                    FileSystem fs = FileSystem.get(sparkContext.hadoopConfiguration());
                    fs.mkdirs(candidatePath);
                    try (OutputStream out = new GZIPOutputStream(fs.create(modelPath), 1 << 16)) {
                        PMMLUtils.write(model, out);
                    }
                    if (empty(testData)) {
                        log.info("No test data available to evaluate model");
                    } else {
                        log.info("Evaluating model");
                        double thisEval = evaluate(sparkContext, model, candidatePath, testData, allTrainData);
                        eval = Double.isNaN(thisEval) ? null : thisEval;
                    }
                }
            }
            log.info("Model eval for params {}: {} ({})", hyperParameters, eval, candidatePath);
            return new Tuple2<>(candidatePath, eval);
        }
    }

    private Pair<JavaRDD<M>, JavaRDD<M>> splitTrainTest(JavaRDD<M> newData, JavaRDD<M> pastData) {
        Preconditions.checkNotNull(newData);
        if (testFraction <= 0.0) {
            return new Pair<>(pastData == null ? newData : newData.union(pastData), null);
        }
        if (testFraction >= 1.0) {
            return new Pair<>(pastData, newData);
        }
        if (empty(newData)) {
            return new Pair<>(pastData, null);
        }
        // Only new data is split; past data always goes into the training set
        Pair<JavaRDD<M>, JavaRDD<M>> newTrainTest = splitNewDataToTrainTest(newData);
        JavaRDD<M> newTrainData = newTrainTest.getFirst();
        return new Pair<>(pastData == null ? newTrainData : newTrainData.union(pastData),
                          newTrainTest.getSecond());
    }

    private Pair<JavaRDD<M>, JavaRDD<M>> splitNewDataToTrainTest(JavaRDD<M> newData) {
        RDD<M>[] trainTestRDDs = newData.rdd().randomSplit(
                new double[]{1.0 - testFraction, testFraction},
                RandomManager.getRandom().nextLong());
        return new Pair<>(newData.wrapRDD(trainTestRDDs[0]), newData.wrapRDD(trainTestRDDs[1]));
    }

    private static boolean empty(JavaRDD<?> rdd) {
        return rdd == null || rdd.take(1).isEmpty();
    }
}
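
To use MLUpdate, extend it and implement buildModel and evaluate. The skeleton below is a minimal sketch of that contract, not part of this project: the class name NoOpMLUpdate is hypothetical, the null return and the NaN score are placeholders that a real subclass would replace with a trained PMML document and a genuine evaluation metric.

package com.anhth12.lambda.ml;

import com.typesafe.config.Config;
import java.util.List;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.dmg.pmml.PMML;

// Hypothetical example subclass; illustrates the contract only.
public final class NoOpMLUpdate extends MLUpdate<String> {

    public NoOpMLUpdate(Config config) {
        // The superclass constructor reads lambda.ml.eval.test-fraction,
        // lambda.ml.eval.candidates and lambda.ml.eval.parallism from this config.
        super(config);
    }

    @Override
    public PMML buildModel(JavaSparkContext sparkContext,
                           JavaRDD<String> trainData,
                           List<?> hyperParameters,
                           Path candidatePath) {
        // A real implementation trains on trainData with the given hyperparameter
        // combination and returns the resulting PMML document. Returning null
        // tells BuildAndEvalWorker that no model could be built for this candidate.
        return null;
    }

    @Override
    public double evaluate(JavaSparkContext sparkContext,
                           PMML model,
                           Path modelParentPath,
                           JavaRDD<String> testData,
                           JavaRDD<String> trainData) {
        // A real implementation scores the model against testData; MLUpdate keeps
        // the candidate with the highest value, and NaN means "no evaluation".
        return Double.NaN;
    }
}

A concrete subclass would typically also override getHyperParamValues() to expose the hyperparameter ranges that runUpdate turns into candidate combinations.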