tv.floe.metronome.classification.logisticregression.iterativereduce.POLRWorkerNode.java Source code

Java tutorial

Introduction

Here is the source code for tv.floe.metronome.classification.logisticregression.iterativereduce.POLRWorkerNode.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package tv.floe.metronome.classification.logisticregression.iterativereduce;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.UniformPrior;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import tv.floe.metronome.classification.logisticregression.POLRModelParameters;
import tv.floe.metronome.classification.logisticregression.ParallelOnlineLogisticRegression;
import tv.floe.metronome.classification.logisticregression.metrics.POLRMetrics;
import tv.floe.metronome.io.records.RCV1RecordFactory;
import tv.floe.metronome.io.records.RecordFactory;

import com.cloudera.iterativereduce.ComputableWorker;
import com.cloudera.iterativereduce.yarn.appworker.ApplicationWorker;

import com.cloudera.iterativereduce.io.RecordParser;
import com.cloudera.iterativereduce.io.TextRecordParser;
import com.google.common.collect.Lists;

/**
 * The Worker node for IterativeReduce - performs work on the shard of input
 * data for the parallel iterative algorithm - runs the SGD algorithm locally on
 * its shard of data
 * 
 * @author jpatterson
 * 
 */
public class POLRWorkerNode extends POLRNodeBase implements ComputableWorker<ParameterVectorUpdatable> {

    private static final Log LOG = LogFactory.getLog(POLRWorkerNode.class);

    int masterTotal = 0;

    // Local copy of the parallel online logistic regression model; built in SetupPOLR().
    public ParallelOnlineLogisticRegression polr = null;
    public POLRModelParameters polr_modelparams;

    public String internalID = "0";
    // Translates raw input lines into feature vectors; chosen by RecordFactoryClassname.
    private RecordFactory VectorFactory = null;

    // Parser over this worker's input shard; injected via setRecordParser().
    private TextRecordParser lineParser = null;

    private boolean IterationComplete = false;
    private int CurrentIteration = 0;

    // basic stats tracking
    POLRMetrics metrics = new POLRMetrics();

    double averageLineCount = 0.0;
    // Total number of records trained so far (accumulates across iterations).
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[] { 1, 2, 5 };
    double lineCount = 0;

    /**
     * Sends a full copy of the multinomial logistic regression array of parameter
     * vectors to the master - this method plugs the local parameter vector into
     * the message.
     *
     * @return a {@link ParameterVector} message carrying a clone of the local
     *         beta matrix plus running accuracy/log-likelihood statistics
     */
    public ParameterVector GenerateUpdate() {

        ParameterVector gradient = new ParameterVector();
        gradient.parameter_vector = this.polr.getBeta().clone();
        gradient.SrcWorkerPassCount = this.LocalBatchCountForIteration;

        // Tell the master whether this worker has exhausted its input shard
        // (1 == done with the current iteration's data).
        if (this.lineParser.hasMoreRecords()) {
            gradient.IterationComplete = 0;
        } else {
            gradient.IterationComplete = 1;
        }

        gradient.CurrentIteration = this.CurrentIteration;

        // Plain narrowing casts; the deprecated new Double(..)/new Long(..)
        // boxing constructors added nothing but allocation.
        gradient.AvgLogLikelihood = (float) metrics.AvgLogLikelihood;
        gradient.PercentCorrect = (float) (metrics.AvgCorrect * 100);
        gradient.TrainedRecords = (int) metrics.TotalRecordsProcessed;

        return gradient;

    }

    /**
     * The IR::Compute method - trains the local SGD model on the remaining
     * records of this worker's shard, tracking a moving average of
     * log-likelihood and percent-correct as it goes.
     *
     * @return an updatable wrapping the freshly generated parameter-vector
     *         message for the master
     */
    @Override
    public ParameterVectorUpdatable compute() {

        Text value = new Text();
        long batch_vec_factory_time = 0;

        while (this.lineParser.hasMoreRecords()) {

            // BUGFIX: the flag is scoped per record and stays false on a read
            // failure. Previously a stale 'true' from the prior pass caused the
            // loop to re-train on the previous record after an IOException.
            boolean haveRecord = false;
            try {
                haveRecord = this.lineParser.next(value);
            } catch (IOException e) {
                LOG.error("Error reading next record from input shard", e);
            }

            if (haveRecord) {

                long startTime = System.currentTimeMillis();

                Vector v = new RandomAccessSparseVector(this.FeatureVectorSize);
                int actual = -1;
                try {
                    // Parse the line into a sparse feature vector; the returned
                    // value is the record's actual class label index.
                    actual = this.VectorFactory.processLine(value.toString(), v);
                } catch (Exception e) {
                    LOG.error("Could not vectorize record: " + value.toString(), e);
                }

                long endTime = System.currentTimeMillis();

                batch_vec_factory_time += (endTime - startTime);

                // calc stats ---------

                // Moving-average window: grows with k up to 200, so early records
                // weigh heavily and later ones decay with factor ~1/200.
                double mu = Math.min(k + 1, 200);
                double ll = this.polr.logLikelihood(actual, v);

                metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu;

                // Guard against NaN poisoning the running average permanently.
                if (Double.isNaN(metrics.AvgLogLikelihood)) {
                    metrics.AvgLogLikelihood = 0;
                }

                // Score the record before training on it so accuracy reflects
                // out-of-sample behavior for this record.
                Vector p = new DenseVector(this.num_categories);
                this.polr.classifyFull(p, v);
                int estimated = p.maxValueIndex();
                int correct = (estimated == actual ? 1 : 0);
                metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu;
                this.polr.train(actual, v);

                k++;
                metrics.TotalRecordsProcessed = k;

                // NOTE(review): close() is invoked after EVERY record, mirroring
                // the original code; presumably it flushes pending regularization
                // in ParallelOnlineLogisticRegression - confirm before moving it
                // out of the loop.
                this.polr.close();

            }

        } // while records remain in the shard

        System.err.printf(
                "Worker %s:\t Iteration: %s, Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
                this.internalID, this.CurrentIteration, k, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100,
                batch_vec_factory_time);

        return new ParameterVectorUpdatable(this.GenerateUpdate());
    }

    /** @return the current local state packaged as an update for the master. */
    public ParameterVectorUpdatable getResults() {
        return new ParameterVectorUpdatable(GenerateUpdate());
    }

    /**
     * This is called when we receive an update from the master.
     *
     * Here we replace the local parameter vector ("beta") with the new global
     * aggregate and record the global pass count.
     *
     * @param t the master's aggregated parameter-vector update
     */
    @Override
    public void update(ParameterVectorUpdatable t) {
        ParameterVector global_update = t.get();

        // set the local parameter vector to the global aggregate ("beta")
        this.polr.SetBeta(global_update.parameter_vector);

        // update global count
        this.GlobalBatchCountForIteration = global_update.GlobalPassCount;
    }

    /**
     * Loads training and record-factory configuration from the Hadoop
     * {@link Configuration}, then constructs the local POLR model.
     *
     * CSV-specific variables (predictor names/types, target variable, column
     * headers) are only read when the CSV record factory is configured.
     *
     * @param c job configuration supplied by the IterativeReduce framework
     */
    @Override
    public void setup(Configuration c) {

        this.conf = c;

        try {

            this.num_categories = this.conf.getInt("com.cloudera.knittingboar.setup.numCategories", 2);

            // feature vector size
            this.FeatureVectorSize = LoadIntConfVarOrException("com.cloudera.knittingboar.setup.FeatureVectorSize",
                    "Error loading config: could not load feature vector size");

            // app.iteration.count
            this.NumberIterations = this.conf.getInt("app.iteration.count", 1);

            // SGD regularization weight (lambda)
            this.Lambda = Double.parseDouble(this.conf.get("com.cloudera.knittingboar.setup.Lambda", "1.0e-4"));

            // SGD learning rate
            this.LearningRate = Double
                    .parseDouble(this.conf.get("com.cloudera.knittingboar.setup.LearningRate", "10"));

            // maps to either CSV, 20newsgroups, or RCV1
            this.RecordFactoryClassname = LoadStringConfVarOrException(
                    "com.cloudera.knittingboar.setup.RecordFactoryClassname",
                    "Error loading config: could not load RecordFactory classname");

            if (this.RecordFactoryClassname.equals(RecordFactory.CSV_RECORDFACTORY)) {

                // so load the CSV specific stuff ----------

                // predictor label names
                this.PredictorLabelNames = LoadStringConfVarOrException(
                        "com.cloudera.knittingboar.setup.PredictorLabelNames",
                        "Error loading config: could not load predictor label names");

                // predictor var types
                this.PredictorVariableTypes = LoadStringConfVarOrException(
                        "com.cloudera.knittingboar.setup.PredictorVariableTypes",
                        "Error loading config: could not load predictor variable types");

                // target variables
                this.TargetVariableName = LoadStringConfVarOrException(
                        "com.cloudera.knittingboar.setup.TargetVariableName",
                        "Error loading config: Target Variable Name");

                // column header names
                this.ColumnHeaderNames = LoadStringConfVarOrException(
                        "com.cloudera.knittingboar.setup.ColumnHeaderNames",
                        "Error loading config: Column Header Names");

            }

        } catch (Exception e) {
            // Log and continue so SetupPOLR() can still run with defaults;
            // a missing mandatory var will have been reported here.
            LOG.error("Error loading worker configuration", e);
        }

        this.SetupPOLR();
    }

    /**
     * Builds the POLR model parameters and the local
     * {@link ParallelOnlineLogisticRegression} instance from the loaded
     * configuration.
     *
     * TODO: throw a mis-configuration exception of some sort instead of
     * tolerating partial config.
     */
    private void SetupPOLR() {

        // BUGFIX: these config vars are only populated for the CSV record
        // factory; guard against NPE when RCV1 (or another factory) is used.
        String[] predictor_label_names = (this.PredictorLabelNames == null)
                ? new String[0]
                : this.PredictorLabelNames.split(",");
        String[] variable_types = (this.PredictorVariableTypes == null)
                ? new String[0]
                : this.PredictorVariableTypes.split(",");

        polr_modelparams = new POLRModelParameters();
        polr_modelparams.setTargetVariable(this.TargetVariableName);
        polr_modelparams.setNumFeatures(this.FeatureVectorSize);
        polr_modelparams.setUseBias(true);

        // Guava varargs overload copies the arrays into mutable ArrayLists,
        // identical to the original element-by-element loops.
        List<String> typeList = Lists.newArrayList(variable_types);
        List<String> predictorList = Lists.newArrayList(predictor_label_names);

        polr_modelparams.setTypeMap(predictorList, typeList);
        polr_modelparams.setLambda(this.Lambda); // based on defaults - match command line
        polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - match command line

        // setup record factory stuff here ---------

        if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) {

            this.VectorFactory = new RCV1RecordFactory();

        }
        // NOTE(review): no record factory is constructed for any other
        // classname (the CSV factory was disabled upstream), so VectorFactory
        // remains null and the call below will NPE - confirm intended configs.

        polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

        // ----- this normally is generated from the POLRModelParams ------

        this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize,
                new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda)
                        .learningRate(this.LearningRate);

        polr_modelparams.setPOLR(polr);
    }

    /**
     * Injects the parser over this worker's input shard.
     *
     * @param r must be a {@link TextRecordParser}; the framework interface is
     *          typed more loosely
     */
    @Override
    public void setRecordParser(RecordParser r) {
        this.lineParser = (TextRecordParser) r;
    }

    /**
     * Only implemented for completeness with the interface; currently a legacy
     * artifact - simply delegates to {@link #compute()} and ignores the list.
     */
    @Override
    public ParameterVectorUpdatable compute(List<ParameterVectorUpdatable> records) {
        return compute();
    }

    /** Launches this worker under the IterativeReduce ApplicationWorker harness. */
    public static void main(String[] args) throws Exception {
        TextRecordParser parser = new TextRecordParser();
        POLRWorkerNode pwn = new POLRWorkerNode();
        ApplicationWorker<ParameterVectorUpdatable> aw = new ApplicationWorker<ParameterVectorUpdatable>(parser,
                pwn, ParameterVectorUpdatable.class);

        ToolRunner.run(aw, args);
    }

    /**
     * Advances to the next iteration, resetting the shard parser to the start
     * of the split.
     *
     * @return false if we're done with iterating over the data, true otherwise
     */
    @Override
    public boolean IncrementIteration() {

        this.CurrentIteration++;
        this.IterationComplete = false;
        this.lineParser.reset();

        System.out.println("IncIteration > " + this.CurrentIteration + ", " + this.NumberIterations);

        if (this.CurrentIteration >= this.NumberIterations) {
            System.out.println("POLRWorkerNode: [ done with all iterations ]");
            return false;
        }

        return true;

    }
}