com.cloudera.knittingboar.sgd.POLRMasterDriver.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.knittingboar.sgd.POLRMasterDriver.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.cloudera.knittingboar.sgd;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.UniformPrior;

import com.cloudera.knittingboar.messages.GlobalParameterVectorUpdateMessage;
import com.cloudera.knittingboar.messages.GradientUpdateMessage;
import com.cloudera.knittingboar.records.CSVBasedDatasetRecordFactory;
import com.cloudera.knittingboar.records.RCV1RecordFactory;
import com.cloudera.knittingboar.records.RecordFactory;
import com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory;
import com.cloudera.knittingboar.utils.Utils;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import com.mongodb.util.Util;

/**
 * Master node of the parallel iterative logistic regression (POLR) algorithm.
 * 
 * Responsibilities:
 * 
 * - collects gradient updates from slave nodes and merges them into the
 * locally held global parameter vector
 * 
 * - publishes a copy of the global parameter vector so a slave node can update
 * its own parameter vector
 * 
 * - manages execution of the whole POLR process
 * 
 * This is the basic simulated version of the POLR master: both message queues
 * are plain in-memory lists owned by this object.
 * 
 * @author jpatterson
 * 
 */
public class POLRMasterDriver extends POLRBaseDriver {

    /** Global parameter vector that incoming gradients are accumulated into; created by {@link #Setup()}. */
    GradientBuffer global_parameter_vector = null;

    /** Parameter-vector messages waiting to be picked up, oldest first. */
    private final List<GlobalParameterVectorUpdateMessage> outgoing_parameter_updates = new ArrayList<GlobalParameterVectorUpdateMessage>();

    /** Gradient messages received from workers and not yet merged, oldest first. */
    private final List<GradientUpdateMessage> incoming_gradient_updates = new ArrayList<GradientUpdateMessage>();

    /** Highest worker pass count seen so far across all merged gradient messages. */
    private int GlobalMaxPassCount = 0;

    // these are only used for saving the model
    public ParallelOnlineLogisticRegression polr = null;
    public POLRModelParameters polr_modelparams;
    private RecordFactory VectorFactory = null;

    public POLRMasterDriver() {

    }

    /**
     * Takes the already loaded configuration values and builds the local data
     * structures: the global gradient buffer, the model parameters, the record
     * factory and the POLR model itself. Marks the driver as set up when done.
     */
    public void Setup() {

        this.global_parameter_vector = new GradientBuffer(this.num_categories, this.FeatureVectorSize);

        // both settings are comma-separated lists
        List<String> predictorList = Lists.newArrayList(this.PredictorLabelNames.split(","));
        List<String> typeList = Lists.newArrayList(this.PredictorVariableTypes.split(","));

        polr_modelparams = new POLRModelParameters();
        polr_modelparams.setTargetVariable(this.TargetVariableName);
        polr_modelparams.setNumFeatures(this.FeatureVectorSize);
        polr_modelparams.setUseBias(true);
        polr_modelparams.setTypeMap(predictorList, typeList);
        // lambda and learning rate are based on defaults - match command line
        polr_modelparams.setLambda(this.Lambda);
        polr_modelparams.setLearningRate(this.LearningRate);

        // setup record factory stuff here ---------

        if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY.equals(this.RecordFactoryClassname)) {

            this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");

        } else if (RecordFactory.RCV1_RECORDFACTORY.equals(this.RecordFactoryClassname)) {

            this.VectorFactory = new RCV1RecordFactory();

        } else {

            // need to rethink this: any other classname falls back to the CSV factory

            CSVBasedDatasetRecordFactory csvFactory = new CSVBasedDatasetRecordFactory(this.TargetVariableName,
                    polr_modelparams.getTypeMap());
            csvFactory.firstLine(this.ColumnHeaderNames);
            this.VectorFactory = csvFactory;

        }

        polr_modelparams.setTargetCategories(this.VectorFactory.getTargetCategories());

        // ----- this normally is generated from the POLRModelParams ------

        // NOTE: a UniformPrior is used here; an L1 prior variant was left
        // disabled by the original author.
        this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize,
                new UniformPrior()).alpha(1).stepOffset(1000).decayExponent(0.9).lambda(this.Lambda)
                        .learningRate(this.LearningRate);

        polr_modelparams.setPOLR(polr);
        this.bSetup = true;

    }

    /** Flags the driver as running. */
    public void Start() {

        this.bRunning = true;

    }

    /** Flags the driver as no longer running. */
    public void Stop() {

        this.bRunning = false;

    }

    /**
     * Appends a gradient message to the tail of the incoming queue; it is merged
     * later by {@link #RecvGradientMessage()}.
     * 
     * @param e
     *          gradient update produced by a worker
     */
    public void AddIncomingGradientMessageToQueue(GradientUpdateMessage e) {

        this.incoming_gradient_updates.add(e);

    }

    /**
     * Removes the oldest queued gradient message, raises the global max pass
     * count if the sending worker is further along, and accumulates the message's
     * gradient into the global parameter vector.
     * 
     * @throws IndexOutOfBoundsException
     *           if no gradient message is queued
     */
    public void RecvGradientMessage() {

        GradientUpdateMessage rcvd_msg = this.incoming_gradient_updates.remove(0);

        this.GlobalMaxPassCount = Math.max(this.GlobalMaxPassCount, rcvd_msg.SrcWorkerPassCount);

        // accumulate gradient
        this.MergeGradientUpdate(rcvd_msg.gradient);

    }

    /**
     * Queues an outgoing message carrying a copy of the current global parameter
     * vector and the current global max pass count.
     */
    public void GenerateGlobalUpdateVector() {

        // post message back to sender async
        this.SendParameterUpdateMessage(this.snapshotGlobalParameters());

    }

    /**
     * Averages the accumulated global parameter vector and then queues an
     * outgoing message carrying a copy of the result.
     * 
     * @param denominator
     *          value handed to the buffer's averaging step (presumably the number
     *          of accumulated worker updates - confirm against GradientBuffer)
     */
    public void AveragePVec_GenerateGlobalUpdateVector(int denominator) {

        this.global_parameter_vector.AverageAccumulations(denominator);

        // post message back to sender async
        this.SendParameterUpdateMessage(this.snapshotGlobalParameters());

    }

    /**
     * Builds an update message holding a clone of the current global parameter
     * vector, so later accumulation does not alter a message already queued.
     */
    private GlobalParameterVectorUpdateMessage snapshotGlobalParameters() {

        GlobalParameterVectorUpdateMessage response_msg = new GlobalParameterVectorUpdateMessage("",
                this.num_categories, this.FeatureVectorSize);
        response_msg.parameter_vector = this.global_parameter_vector.gamma.clone();
        response_msg.GlobalPassCount = this.GlobalMaxPassCount;
        return response_msg;

    }

    /**
     * Appends a parameter update message to the tail of the outgoing queue.
     * 
     * @param msg
     *          message to queue
     */
    public void SendParameterUpdateMessage(GlobalParameterVectorUpdateMessage msg) {

        this.outgoing_parameter_updates.add(msg);

    }

    /**
     * Used to collect gradient updates into buffer
     * 
     * @param incoming_buffer
     *          gradient received from a worker
     */
    private void MergeGradientUpdate(GradientBuffer incoming_buffer) {

        this.global_parameter_vector.Accumulate(incoming_buffer);

    }

    /**
     * Removes and returns the oldest queued outgoing parameter update. This is
     * mostly for debug.
     * 
     * @return the oldest queued outgoing message
     * @throws IndexOutOfBoundsException
     *           if no outgoing message is queued
     */
    public GlobalParameterVectorUpdateMessage GetNextGlobalUpdateMsgFromQueue() {

        return this.outgoing_parameter_updates.remove(0);

    }

    /**
     * Copies the current global parameter vector into the POLR model and writes
     * the model parameters to a file on the local filesystem.
     * 
     * TODO: how does this work differently than the other save method?
     * 
     * NOTE: This should only be used for durability purposes in checkpointing the
     * workers
     * 
     * @param outputFile
     *          path of the local file to write
     * @throws Exception
     *           if the file cannot be opened, written, or closed
     */
    public void SaveModelLocally(String outputFile) throws Exception {

        this.polr.SetBeta(this.global_parameter_vector.gamma);

        OutputStream modelOutput = new FileOutputStream(outputFile);
        // A failed close() on an output stream can mean buffered data never reached
        // disk, so it is only swallowed when an earlier exception is already
        // propagating; otherwise it is reported to the caller.
        boolean threw = true;
        try {
            polr_modelparams.saveTo(modelOutput);
            threw = false;
        } finally {
            Closeables.close(modelOutput, threw);
        }

    }

    /**
     * Writes the model parameters to the given path, overwriting any existing
     * file, using the filesystem resolved from the path and configuration.
     * 
     * [ needs to be checked ]
     * 
     * NOTE(review): unlike SaveModelLocally, this method does not copy the global
     * parameter vector into the POLR model before saving - confirm whether that
     * is intentional.
     * 
     * NOTE: This should only be used for durability purposes in checkpointing the
     * workers
     * 
     * @param outputFilename
     *          destination path
     * @param conf
     *          Hadoop configuration used to resolve the filesystem
     * @throws Exception
     *           if the file cannot be created, written, or closed
     */
    public void SaveModelToHDFS(String outputFilename, Configuration conf) throws Exception {

        Path path = new Path(outputFilename);
        FileSystem fs = path.getFileSystem(conf);
        FSDataOutputStream modelHDFSOutput = fs.create(path, true);

        // Report close() failures on success, but never let one mask an exception
        // already thrown by saveTo().
        boolean threw = true;
        try {
            polr_modelparams.saveTo(modelHDFSOutput);
            threw = false;
        } finally {
            Closeables.close(modelHDFSOutput, threw);
        }

    }

    /**
     * Prints the category count and feature vector size to standard output, then
     * delegates to the model parameters' own debug output.
     * 
     * @throws IOException
     *           declared for callers; may be raised by the delegated debug call
     */
    public void Debug() throws IOException {

        System.out.println("POLRMasterDriver --------------------------- ");

        System.out.println("> Num Categories: " + this.num_categories);
        System.out.println("> FeatureVecSize: " + this.FeatureVectorSize);

        this.polr_modelparams.Debug();

    }

}