com.anhth12.lambda.ml.MLUpdate.java Source code


Introduction

Here is the source code for com.anhth12.lambda.ml.MLUpdate.java. MLUpdate is the abstract batch-layer update: given new and historical data as Spark RDDs, it builds candidate PMML models over a set of hyperparameter combinations, evaluates them, keeps the best candidate, and publishes it to a model update topic. Concrete subclasses supply buildModel() and evaluate().

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.lambda.ml;

import com.anhth12.lambda.BatchLayerUpdate;
import com.anhth12.lambda.TopicProducer;
import com.anhth12.lambda.common.collection.Pair;
import com.anhth12.lambda.common.pmml.PMMLUtils;
import com.anhth12.lambda.common.random.RandomManager;
import com.anhth12.lambda.fn.Functions;
import com.anhth12.lambda.ml.param.HyperParamValues;
import com.anhth12.lambda.ml.param.HyperParams;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Abstract batch-layer update that builds candidate PMML models over hyperparameter
 * combinations, evaluates them on held-out data, and publishes the best candidate.
 *
 * @author Tong Hoang Anh
 * @param <M> type of the input messages the models are built from
 */
public abstract class MLUpdate<M> implements BatchLayerUpdate<String, M, String> {

    public static final Logger log = LoggerFactory.getLogger(MLUpdate.class);

    public static final String MODEL_FILE_NAME = "model.pmml.gz";

    private final double testFraction;
    private final int candidates;
    private final int evalParallelism;

    protected MLUpdate(Config config) {
        this.testFraction = config.getDouble("lambda.ml.eval.test-fraction");
        this.candidates = config.getInt("lambda.ml.eval.candidates");
        this.evalParallelism = config.getInt("lambda.ml.eval.parallism");

        Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0);
        Preconditions.checkArgument(candidates > 0);
        Preconditions.checkArgument(evalParallelism > 0);
    }

    protected final double getTestFraction() {
        return testFraction;
    }

    public List<HyperParamValues<?>> getHyperParamValues() {
        return Collections.emptyList();
    }

    public abstract PMML buildModel(JavaSparkContext sparkContext, JavaRDD<M> trainData, List<?> hyperParameters,
            Path candidatePath);

    public void publishAdditionalModelData(JavaSparkContext sparkContext, PMML pmml, JavaRDD<M> newData,
            JavaRDD<M> pastData, Path modelParentPath, TopicProducer<String, String> modelUpdateTopic) {
        //nothing to do by default
    }

    public abstract double evaluate(JavaSparkContext sparkContext, PMML model, Path modelParentPath,
            JavaRDD<M> testData, JavaRDD<M> trainData);

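    // Orchestrates one batch update: materializes the new and past data, builds and evaluates
    // candidate models across the hyperparameter combinations, promotes the best candidate to a
    // timestamped directory, and publishes the resulting PMML to the model update topic.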
    @Override
    public void runUpdate(JavaSparkContext sparkContext, long timestamp, JavaPairRDD<String, M> newKeyMessageData,
            JavaPairRDD<String, M> pastKeyMessageData, String modelDirString,
            TopicProducer<String, String> modelUpdateTopic) throws IOException, InterruptedException {

        Preconditions.checkNotNull(newKeyMessageData);

        JavaRDD<M> newData = newKeyMessageData.values();
        JavaRDD<M> pastData = pastKeyMessageData == null ? null : pastKeyMessageData.values();

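        // Cache and force materialization of both RDDs so every candidate build and evaluation
        // reuses the same computed data instead of recomputing it from the source.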
        if (newData != null) {
            newData.cache();
            newData.foreachPartition(Functions.<Iterator<M>>noOp());
        }
        if (pastData != null) {
            pastData.cache();
            pastData.foreachPartition(Functions.<Iterator<M>>noOp());
        }

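        // Expand the declared hyperparameter ranges into concrete combinations; each candidate
        // model is built from one combination (reused round-robin if there are fewer combinations
        // than candidates).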
        List<HyperParamValues<?>> hyperParamValues = getHyperParamValues();

        int valuesPerHyperParam = HyperParams.chooseValuesPerHyperParam(hyperParamValues.size(), candidates);

        List<List<?>> hyperParameterCombos = HyperParams.chooseHyperParameterCombos(hyperParamValues, candidates,
                valuesPerHyperParam);

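        // Candidates are written under a hidden ".temporary" working directory named by the
        // current timestamp; only the winning candidate is later moved into the model directory.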
        FileSystem fs = FileSystem.get(sparkContext.hadoopConfiguration());

        Path modelDir = new Path(modelDirString);
        Path tempModelPath = new Path(modelDir, ".temporary");
        Path candidatesPath = new Path(tempModelPath, Long.toString(System.currentTimeMillis()));
        fs.mkdirs(candidatesPath);

        Path bestCandidatePath = findBestCandidatePath(sparkContext, newData, pastData, hyperParameterCombos,
                candidatesPath);

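        // Promote the best candidate (if any) to a permanent, timestamped directory and discard
        // the remaining candidates.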
        Path finalPath = new Path(modelDir, Long.toString(System.currentTimeMillis()));
        if (bestCandidatePath == null) {
            log.info("Unable to build any model");
        } else {
            fs.rename(bestCandidatePath, finalPath);
        }

        fs.delete(candidatesPath, true);

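        // If a best model was actually written, read it back and publish it, along with any
        // implementation-specific extra data, to the update topic.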
        Path bestModelPath = new Path(finalPath, MODEL_FILE_NAME);

        if (fs.exists(bestModelPath)) {
            PMML bestModel;
            try (InputStream in = new GZIPInputStream(fs.open(bestModelPath), 1 << 16)) {
                bestModel = PMMLUtils.read(in);
            }

            modelUpdateTopic.send("MODEL", PMMLUtils.toString(bestModel));
            publishAdditionalModelData(sparkContext, bestModel, newData, pastData, finalPath, modelUpdateTopic);
        }

        if (newData != null) {
            newData.unpersist();
        }

        if (pastData != null) {
            pastData.unpersist();
        }

    }

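    // Builds and evaluates each candidate, in parallel when evalParallelism > 1, and returns the
    // candidate directory with the highest evaluation score; a candidate without a score is kept
    // only as a fallback when no scored candidate has been chosen.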
    private Path findBestCandidatePath(JavaSparkContext sparkContext, JavaRDD<M> newData, JavaRDD<M> pastData,
            List<List<?>> hyperParameterCombos, Path candidatesPath) throws InterruptedException, IOException {

        Map<Path, Double> pathToEval = new HashMap<>(candidates);
        if (evalParallelism > 1) {
            Collection<Future<Tuple2<Path, Double>>> futures = new ArrayList<>(candidates);
            ExecutorService executor = Executors.newFixedThreadPool(evalParallelism);

            try {
                for (int i = 0; i < candidates; i++) {
                    futures.add(executor.submit(new BuildAndEvalWorker(i, hyperParameterCombos, sparkContext,
                            newData, pastData, candidatesPath)));
                }
            } finally {
                executor.shutdown();
            }

            for (Future<Tuple2<Path, Double>> future : futures) {
                Tuple2<Path, Double> pathEval;
                try {
                    pathEval = future.get();
                } catch (ExecutionException ex) {
                    throw new IllegalStateException(ex);
                }
                pathToEval.put(pathEval._1, pathEval._2);
            }
        } else {
            for (int i = 0; i < candidates; i++) {
                Tuple2<Path, Double> pathEval = new BuildAndEvalWorker(i, hyperParameterCombos, sparkContext,
                        newData, pastData, candidatesPath).call();
                pathToEval.put(pathEval._1, pathEval._2);
            }
        }

        FileSystem fs = FileSystem.get(sparkContext.hadoopConfiguration());

        Path bestCandidatePath = null;

        double bestEval = Double.NEGATIVE_INFINITY;

        for (Map.Entry<Path, Double> pathEval : pathToEval.entrySet()) {
            Path path = pathEval.getKey();
            Double eval = pathEval.getValue();

            if ((bestCandidatePath == null || (eval != null && eval > bestEval)) && fs.exists(path)) {
                log.info("Best eval / path is now {} / {}", eval, path);
                if (eval != null) {
                    bestEval = eval;
                }
                bestCandidatePath = path;
            }
        }

        return bestCandidatePath;
    }

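    // Callable that builds one candidate model from a single hyperparameter combination, writes it
    // as gzipped PMML into its own candidate subdirectory, and returns that path together with the
    // evaluation score (null when no model could be built or no test data was available).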
    final class BuildAndEvalWorker implements Callable<Tuple2<Path, Double>> {

        private final int i;
        private final List<List<?>> hyperParameterCombos;
        private final JavaSparkContext sparkContext;
        private final JavaRDD<M> newData;
        private final JavaRDD<M> pastData;
        private final Path candidatesPath;

        public BuildAndEvalWorker(int i, List<List<?>> hyperParameterCombos, JavaSparkContext sparkContext,
                JavaRDD<M> newData, JavaRDD<M> pastData, Path candidatePath) {
            this.i = i;
            this.hyperParameterCombos = hyperParameterCombos;
            this.sparkContext = sparkContext;
            this.newData = newData;
            this.pastData = pastData;
            this.candidatesPath = candidatePath;
        }

        @Override
        public Tuple2<Path, Double> call() throws IOException {
            List<?> hyperParameters = hyperParameterCombos.get(i % hyperParameterCombos.size());
            Path candidatePath = new Path(candidatesPath, Integer.toString(i));
            log.info("Building candidate {} with param {}", i, hyperParameters);

            Pair<JavaRDD<M>, JavaRDD<M>> trainTestData = splitTrainTest(newData, pastData);
            JavaRDD<M> allTrainData = trainTestData.getFirst();
            JavaRDD<M> testData = trainTestData.getSecond();

            Double eval = null;
            if (empty(allTrainData)) {
                log.info("No train data to build model");
            } else {
                PMML model = buildModel(sparkContext, allTrainData, hyperParameters, candidatePath);
                if (model == null) {
                    log.info("Unable to build a model");
                } else {
                    Path modelPath = new Path(candidatePath, MODEL_FILE_NAME);
                    log.info("Writing model to {}", modelPath);
                    FileSystem fs = FileSystem.get(sparkContext.hadoopConfiguration());
                    fs.mkdirs(candidatePath);
                    try (OutputStream out = new GZIPOutputStream(fs.create(modelPath), 1 << 16)) {
                        PMMLUtils.write(model, out);
                    }
                    if (empty(testData)) {
                        log.info("No test data available to evaluate model");
                    } else {
                        log.info("Evaluating model");
                        double thisEval = evaluate(sparkContext, model, candidatePath, testData, allTrainData);
                        eval = Double.isNaN(thisEval) ? null : thisEval;
                    }
                }
            }

            log.info("Model eval for params {}: {} ({})", hyperParameters, eval, candidatePath);
            return new Tuple2<>(candidatePath, eval);

        }

    }

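    // Splits the input into (train, test): with testFraction <= 0 everything is training data,
    // with testFraction >= 1 all new data becomes test data; otherwise the new data is sampled
    // randomly and past data always stays on the training side.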
    private Pair<JavaRDD<M>, JavaRDD<M>> splitTrainTest(JavaRDD<M> newData, JavaRDD<M> pastData) {
        Preconditions.checkNotNull(newData);
        if (testFraction <= 0.0) {
            return new Pair<>(pastData == null ? newData : newData.union(pastData), null);
        }
        if (testFraction >= 1.0) {
            return new Pair<>(pastData, newData);
        }

        if (empty(newData)) {
            return new Pair<>(pastData, null);
        }

        Pair<JavaRDD<M>, JavaRDD<M>> newTrainTest = splitNewDataToTrainTest(newData);
        JavaRDD<M> newTrainData = newTrainTest.getFirst();
        return new Pair<>(pastData == null ? newTrainData : newTrainData.union(pastData), newTrainTest.getSecond());
    }

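    // Randomly partitions only the new data into train and test subsets according to testFraction.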
    private Pair<JavaRDD<M>, JavaRDD<M>> splitNewDataToTrainTest(JavaRDD<M> newData) {
        // randomSplit puts the (1 - testFraction) training portion at index 0 and the test portion at index 1
        RDD<M>[] trainTestRDDs = newData.rdd().randomSplit(new double[] { 1.0 - testFraction, testFraction },
                RandomManager.getRandom().nextLong());
        return new Pair<>(newData.wrapRDD(trainTestRDDs[0]), newData.wrapRDD(trainTestRDDs[1]));
    }

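    // An RDD counts as empty when it is null or yields no elements; take(1) avoids a full count.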
    private static boolean empty(JavaRDD<?> rdd) {
        return rdd == null || rdd.take(1).isEmpty();
    }

}
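
Example

The following is a minimal, hypothetical sketch (not part of this project) of how a concrete update plugs into the scaffolding above. Only buildModel() and evaluate() must be supplied; MLUpdate itself handles caching, train/test splitting, candidate selection, and publishing. The class name ExampleMLUpdate and the trivial method bodies are illustrative assumptions: buildModel() returns null here, which MLUpdate treats as "no model built" and skips, whereas a real implementation would construct and return a PMML document.

package com.anhth12.lambda.ml;

import com.typesafe.config.Config;
import java.util.List;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.dmg.pmml.PMML;

public final class ExampleMLUpdate extends MLUpdate<String> {

    public ExampleMLUpdate(Config config) {
        // The superclass reads lambda.ml.eval.test-fraction, lambda.ml.eval.candidates
        // and lambda.ml.eval.parallism from this config.
        super(config);
    }

    @Override
    public PMML buildModel(JavaSparkContext sparkContext, JavaRDD<String> trainData,
            List<?> hyperParameters, Path candidatePath) {
        // A real implementation would train on trainData and encode the result as PMML.
        // Returning null is tolerated: MLUpdate logs "Unable to build a model" and skips
        // this candidate.
        return null;
    }

    @Override
    public double evaluate(JavaSparkContext sparkContext, PMML model, Path modelParentPath,
            JavaRDD<String> testData, JavaRDD<String> trainData) {
        // A real implementation would score the model against testData; larger values win,
        // because findBestCandidatePath keeps the candidate with the highest score.
        return 0.0;
    }
}

A subclass that also needs to push extra, non-PMML data to consumers can override publishAdditionalModelData(), and one with tunable parameters overrides getHyperParamValues() to return the ranges from which the candidate combinations are drawn.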