com.winvector.logistic.demo.MapReduceLogisticTrain.java Source code

Introduction

Here is the source code for com.winvector.logistic.demo.MapReduceLogisticTrain.java, a Hadoop Tool that trains a logistic regression model with map/reduce passes over a tab-separated training file.

Source

package com.winvector.logistic.demo;

import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Random;
import java.util.regex.Pattern;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import com.winvector.Licenses;
import com.winvector.logistic.Formula;
import com.winvector.logistic.Model;
import com.winvector.logistic.SigmoidLossMultinomial;
import com.winvector.logistic.mr.MapRedAccuracy;
import com.winvector.logistic.mr.MapRedFn;
import com.winvector.logistic.mr.MapRedScan;
import com.winvector.logistic.mr.WritableUtils;
import com.winvector.logistic.mr.WritableVariableList;
import com.winvector.opt.def.LinUtil;
import com.winvector.opt.def.VEval;
import com.winvector.opt.def.VectorFn;
import com.winvector.opt.def.VectorOptimizer;
import com.winvector.opt.impl.Newton;
import com.winvector.opt.impl.NormPenalty;
import com.winvector.opt.impl.SumFn;
import com.winvector.util.HBurster;
import com.winvector.util.LineBurster;
import com.winvector.util.SerialUtils;
import com.winvector.variables.VariableEncodings;

// @SuppressWarnings("deprecation") // Tool deprecated in 0.21.0, a known issue: https://issues.apache.org/jira/browse/MAPREDUCE-2084
public class MapReduceLogisticTrain extends Configured implements Tool {

    @Override
    public int run(final String[] args) throws Exception {
        if (args.length != 3) {
            final Log log = LogFactory.getLog(MapReduceLogisticTrain.class);
            log.info(Licenses.licenses);
            log.fatal("use: MapReduceLogisticTrain trainFile formula resultFile");
            return -1;
        }
        final String trainFileName = args[0]; // training data: tab-separated columns with a header row naming the variables
        final String formulaStr = args[1]; // model formula (outcome ~ predictors), e.g. "rating ~ buying + maintenance + doors + persons + lug_boot + safety"
        final String resultFileName = args[2];
        run(trainFileName, formulaStr, null, resultFileName, 5);
        return 0;
    }

    public double run(final String trainFileName, final String formulaStr, final String weightKey,
            final String resultFileName, final int maxNewtonRounds) throws Exception {
        final Log log = LogFactory.getLog(MapReduceLogisticTrain.class);
        log.info("start");
        final Formula formula = new Formula(formulaStr); // force an early parse error if wrong
        final Random rand = new Random();
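        // unique prefix used to name temporary working paths for the map/reduce passes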
        final String tmpPrefix = "TMPLR_" + rand.nextLong();
        final Configuration mrConfig = getConf();
        final Path trainFile = new Path(trainFileName);
        final Path resultFile = new Path(resultFileName);
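        // read the header row of the training file and set up a tab-separated line parser;
        // the serialized burster is handed to the map/reduce jobs through the configuration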
        final String headerLine = WritableUtils.readFirstLine(mrConfig, trainFile);
        final Pattern sepPattern = Pattern.compile("\t");
        final LineBurster burster = new HBurster(sepPattern, headerLine, false);
        mrConfig.set(MapRedScan.BURSTERSERFIELD, SerialUtils.serializableToString(burster));
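        // first map/reduce pass: scan the training data to collect the variables and levels named by the formula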
        final WritableVariableList lConfig = MapRedScan.initialScan(tmpPrefix, mrConfig, trainFile, formulaStr);
        log.info("formula:\t" + formulaStr + "\n" + lConfig.formatState());
        final VariableEncodings defs = new VariableEncodings(lConfig, true, weightKey);
        //final WritableSigmoidLossBinomial underlying = new WritableSigmoidLossBinomial(defs.dim());
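        // multinomial logistic (sigmoid) loss over the encoded variables, evaluated by map/reduce over the training file,
        // summed with a small norm penalty to regularize the optimization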
        final SigmoidLossMultinomial underlying = new SigmoidLossMultinomial(defs.dim(), defs.noutcomes());
        final MapRedFn f = new MapRedFn(underlying, lConfig, defs.useIntercept(), tmpPrefix, mrConfig, trainFile);
        final ArrayList<VectorFn> fns = new ArrayList<VectorFn>();
        fns.add(f);
        fns.add(new NormPenalty(f.dim(), 1.0e-5, defs.adaptions));
        final VectorFn sl = new SumFn(fns);
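        // maximize the penalized objective with a bounded number of Newton steps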
        final VectorOptimizer nwt = new Newton();
        final VEval opt = nwt.maximize(sl, null, maxNewtonRounds);
        log.info("done training");
        log.info("soln vector:\n" + LinUtil.toString(opt.x));
        log.info("soln details:\n" + defs.formatSoln(opt.x));
        {
            final Model model = new Model();
            model.config = defs;
            model.coefs = opt.x;
            model.origFormula = formula;
            log.info("writing " + resultFile);
            final FSDataOutputStream fdo = resultFile.getFileSystem(mrConfig).create(resultFile, true);
            final ObjectOutputStream oos = new ObjectOutputStream(fdo);
            oos.writeObject(model);
            oos.close();
        }
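        // final map/reduce pass: count correct predictions on the training data and report training accuracy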
        final MapRedAccuracy sc = new MapRedAccuracy(underlying, lConfig, defs.useIntercept(), tmpPrefix, mrConfig,
                trainFile);
        final long[] trainAccuracy = sc.score(opt.x);
        final double accuracy = trainAccuracy[0] / (double) trainAccuracy[1];
        log.info("train accuracy: " + trainAccuracy[0] + "/" + trainAccuracy[1] + "\t" + accuracy);
        return accuracy;
    }

    // ~/Documents/workspace/hadoop-0.21.0/bin/hadoop jar ~/Documents/workspace/Logistic/WinVectorLogistic.jar logistictrain  ~/Documents/workspace/Logistic/test/com/winvector/logistic/uciCarTrain.tsv "rating ~ buying + maintenance + doors + persons + lug_boot + safety" model.ser
    public static void main(final String[] args) throws Exception {
        final MapReduceLogisticTrain mrl = new MapReduceLogisticTrain();
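        // a null Configuration lets ToolRunner supply a default one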
        final int code = ToolRunner.run(null, mrl, args);
        if (code != 0) {
            final Log log = LogFactory.getLog(MapReduceLogisticTrain.class);
            log.fatal("MapReduceLogistic error, return code: " + code);
            ToolRunner.printGenericCommandUsage(System.out);
        }
    }

}
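
The comment above main() shows the hadoop jar command line used to launch the trainer. The run(...) method can also be called directly from Java; below is a minimal sketch of such a call. The class name TrainDemo and the file paths are hypothetical, and the arguments simply mirror the defaults used by run(String[]).

import org.apache.hadoop.conf.Configuration;

import com.winvector.logistic.demo.MapReduceLogisticTrain;

public class TrainDemo {
    public static void main(final String[] args) throws Exception {
        final MapReduceLogisticTrain trainer = new MapReduceLogisticTrain();
        trainer.setConf(new Configuration()); // default Hadoop configuration (setConf comes from Configured)
        final double accuracy = trainer.run(
                "uciCarTrain.tsv",   // hypothetical path: tab-separated training data with a header row
                "rating ~ buying + maintenance + doors + persons + lug_boot + safety",
                null,                // no per-example weight column
                "model.ser",         // hypothetical path where the serialized Model is written
                5);                  // Newton rounds, matching run(String[]) above
        System.out.println("train accuracy: " + accuracy);
    }
}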