org.apache.mahout.regression.penalizedlinear.PenalizedLinearDriver.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.mahout.regression.penalizedlinear.PenalizedLinearDriver.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.regression.penalizedlinear;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

/**
 * Run the penalized linear regression with cross validation.
 * The input file should be Mahout sequence file with VectorWritable.
 * In each line, the first element of VectorWritable is response; the rest are predictors
 */
public class PenalizedLinearDriver extends AbstractJob {

    private static final Logger log = LoggerFactory.getLogger(PenalizedLinearDriver.class);

    /** Column width for every printed numeric value. */
    private static final int FORMAT_WIDTH = 8;

    /**
     * printf-style pattern for one fixed-width numeric column ("%8.5f").
     * Hoisted so the pattern is built once instead of at every call site.
     */
    private static final String NUM_FORMAT = "%" + FORMAT_WIDTH + ".5f";

    private PenalizedLinearParameter parameter;
    private String input;
    private String output;

    /**
     * Value holder for the parsed command-line options.
     * Declared {@code static}: it never reads the enclosing driver instance,
     * so a non-static inner class would only carry a hidden outer reference.
     */
    private static class PenalizedLinearParameter {
        private int numOfCV;       // number of cross-validation folds
        private float alpha;       // elastic-net mixing coefficient, 0.0..1.0
        private String lambda;     // comma-separated increasing penalty sequence, "" = auto
        private boolean intercept; // whether to fit a bias term

        public int getNumOfCV() {
            return numOfCV;
        }

        public void setNumOfCV(int numOfCV) {
            this.numOfCV = numOfCV;
        }

        public float getAlpha() {
            return alpha;
        }

        public void setAlpha(float alpha) {
            this.alpha = alpha;
        }

        public String getLambda() {
            return lambda;
        }

        public void setLambda(String lambda) {
            this.lambda = lambda;
        }

        public boolean isIntercept() {
            return intercept;
        }

        public void setIntercept(boolean intercept) {
            this.intercept = intercept;
        }
    }

    public static void main(String[] args) throws Exception {
        // NOTE(review): the exit status of ToolRunner.run is discarded, so the
        // process exits 0 even when the job fails — kept for compatibility.
        ToolRunner.run(new Configuration(), new PenalizedLinearDriver(), args);
    }

    /**
     * Parses the arguments, runs the MapReduce aggregation job, then solves the
     * penalized linear model twice: once over the full regularization path and
     * once with cross validation, printing the results of each.
     *
     * @param args raw command-line arguments
     * @return always 0 (even on argument-parse failure; kept for compatibility)
     */
    @Override
    public int run(String[] args) throws Exception {
        if (parseArgs(args)) {
            buildRegressionModelMR(parameter, new Path(input), new Path(output));

            // First pass: trace the full regularization path.
            PenalizedLinearSolver solver = new PenalizedLinearSolver();
            solver.setAlpha(parameter.alpha);
            solver.setIntercept(parameter.intercept);
            solver.setLambdaString(parameter.lambda);
            solver.initSolver(new Path(output), getConf());
            solver.regularizePath(solver.getLambda());
            printInfo(parameter, solver, "path");

            // Second pass on a fresh solver: cross-validate to pick lambda.
            solver = new PenalizedLinearSolver();
            solver.setAlpha(parameter.alpha);
            solver.setIntercept(parameter.intercept);
            solver.setLambdaString(parameter.lambda);
            solver.initSolver(new Path(output), getConf());
            solver.crossValidate();
            printInfo(parameter, solver, "CV");
        }
        return 0;
    }

    /**
     * Dispatches result printing: "path" prints the regularization path,
     * anything else prints the cross-validation report.
     */
    private void printInfo(PenalizedLinearParameter parameter, PenalizedLinearSolver solver, String method) {
        if ("path".equals(method)) {
            printRegularizationPath(solver);
        } else {
            printCrossValidation(parameter, solver);
        }
    }

    /** Prints one line per lambda: the lambda value, the intercept, then every coefficient. */
    private void printRegularizationPath(PenalizedLinearSolver solver) {
        PenalizedLinearSolver.Coefficients[] coefficients = solver.getCoefficients();
        double[] lambdas = solver.getLambda();
        System.out.println("The training path: lambda    beta0    beta");
        for (int i = 0; i < coefficients.length; ++i) {
            StringBuilder line = new StringBuilder(String.format(NUM_FORMAT, lambdas[i]));
            line.append(' ').append(String.format(NUM_FORMAT, coefficients[i].beta0));
            for (double beta : coefficients[i].beta) {
                line.append(' ').append(String.format(NUM_FORMAT, beta));
            }
            System.out.println(line);
        }
    }

    /**
     * Prints the per-lambda training/test errors, the fitted model coefficients
     * (with or without the intercept column), and the optimal lambda.
     */
    private void printCrossValidation(PenalizedLinearParameter parameter, PenalizedLinearSolver solver) {
        PenalizedLinearSolver.Coefficients coefficients = solver.getCoefficients()[0];
        double[] trainError = solver.getTrainError();
        double[] testError = solver.getTestError();
        double[] lambdas = solver.getLambda();

        System.out.println("Training and Test Error: lambda    training    testing");
        for (int i = 0; i < lambdas.length; ++i) {
            System.out.println(String.format(NUM_FORMAT, lambdas[i]) + ' '
                    + String.format(NUM_FORMAT, trainError[i]) + ' '
                    + String.format(NUM_FORMAT, testError[i]));
        }

        StringBuilder model = new StringBuilder("model: ");
        StringBuilder coefficientString = new StringBuilder();
        if (parameter.intercept) {
            model.append("beta0    beta");
            // beta0 leads the line without a separator; betas follow space-prefixed.
            coefficientString.append(String.format(NUM_FORMAT, coefficients.beta0));
        } else {
            model.append("beta");
        }
        for (double beta : coefficients.beta) {
            coefficientString.append(' ').append(String.format(NUM_FORMAT, beta));
        }
        System.out.println(model);
        System.out.println(coefficientString);

        System.out.println("Optimal lambda is: " + String.format(NUM_FORMAT, solver.getOptLambda()[0]));
    }

    /**
     * Runs the single-reducer MapReduce job that aggregates the training data
     * into the statistics the solver consumes.
     *
     * @param parameter parsed CLI options forwarded to mapper/reducer via the job conf
     * @param input     Mahout sequence file of VectorWritable (response first, then predictors)
     * @param output    directory the job writes its aggregated statistics to
     * @throws InterruptedException if the job does not complete successfully
     *                              (kept as the failure exception for compatibility)
     */
    private void buildRegressionModelMR(PenalizedLinearParameter parameter, Path input, Path output)
            throws IOException, InterruptedException, ClassNotFoundException {

        Job job = prepareJob(input, output, SequenceFileInputFormat.class, PenalizedLinearMapper.class, Text.class,
                VectorWritable.class, PenalizedLinearReducer.class, Text.class, VectorWritable.class,
                SequenceFileOutputFormat.class);
        job.setJobName("Penalized Linear Regression Driver running over input: " + input);
        job.setNumReduceTasks(1); // all partial statistics must meet in one reducer
        job.setJarByClass(PenalizedLinearDriver.class);

        // Configuration must be set before the job is submitted below.
        Configuration conf = job.getConfiguration();
        conf.setInt(PenalizedLinearKeySet.NUM_CV, parameter.getNumOfCV());
        conf.setFloat(PenalizedLinearKeySet.ALPHA, parameter.getAlpha());
        conf.set(PenalizedLinearKeySet.LAMBDA, parameter.getLambda());
        conf.setBoolean(PenalizedLinearKeySet.INTERCEPT, parameter.isIntercept());

        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Penalized Linear Regression Job failed processing " + input);
        }
    }

    /**
     * Parses and validates the command line, populating {@link #parameter},
     * {@link #input} and {@link #output}.
     *
     * @return true if parsing and validation succeeded, false otherwise
     *         (help requested, bad lambda sequence, alpha outside [0,1],
     *         or numOfCV outside [1,20]); note that malformed numbers
     *         propagate as NumberFormatException
     */
    private boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();

        Option help = builder.withLongName("help").withDescription("print this list").create();

        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        Option inputFile = builder.withLongName("input").withRequired(true)
                .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
                .withDescription(
                        "where to get training data (Mahout sequence file of VectorWritable); in each line, the first element is response; rest are predictors.")
                .create();

        Option outputFile = builder.withLongName("output").withRequired(true)
                .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
                .withDescription("where to get results").create();

        Option lambda = builder.withLongName("lambda")
                .withArgument(argumentBuilder.withName("lambda").withDefault("0").withMinimum(1).create())
                .withDescription("an increasing positive sequence of penalty coefficient, "
                        + "with length n >= 0; if lambda is not specified, the sequence is chosen by algorithm.")
                .create();

        Option alpha = builder.withLongName("alpha")
                .withArgument(
                        argumentBuilder.withName("alpha").withDefault("1").withMinimum(1).withMaximum(1).create())
                .withDescription("the elastic-net coefficient with default value 1 (LASSO)").create();

        Option bias = builder.withLongName("bias").withDescription("include a bias term").create();

        Option numOfCV = builder.withLongName("numOfCV")
                .withArgument(
                        argumentBuilder.withName("numOfCV").withDefault("5").withMinimum(0).withMaximum(1).create())
                .withDescription("number of cross validation, the rule of thumb is 5 or 10").create();

        Group normalArgs = new GroupBuilder().withOption(help).withOption(inputFile).withOption(outputFile)
                .withOption(lambda).withOption(alpha).withOption(bias).withOption(numOfCV).create();

        Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
        CommandLine cmdLine = parser.parseAndHelp(args);
        if (cmdLine == null) {
            // Parse failure or --help; the parser has already printed usage.
            return false;
        }

        parameter = new PenalizedLinearParameter();
        parameter.setNumOfCV(Integer.parseInt((String) cmdLine.getValue(numOfCV)));
        parameter.setAlpha(Float.parseFloat((String) cmdLine.getValue(alpha)));
        parameter.setIntercept(cmdLine.hasOption(bias));

        if (!processLambda(parameter, cmdLine, lambda) || parameter.alpha < 0.0 || parameter.alpha > 1.0
                || parameter.numOfCV < 1 || parameter.numOfCV > 20) {
            log.error(
                    "please make sure the lambda sequence is positive and increasing, and 0.0 <= alphaValue <= 1.0 and 1 <= numOfCV <= 20");
            return false;
        }

        input = (String) cmdLine.getValue(inputFile);
        output = (String) cmdLine.getValue(outputFile);

        return true;
    }

    /**
     * Validates the --lambda values (each non-negative, strictly increasing)
     * and stores them on {@code parameter} as a comma-separated string
     * ("" when the option is absent).
     *
     * @return false if the sequence is not strictly increasing or contains a
     *         negative value; true otherwise. Non-numeric values propagate as
     *         NumberFormatException.
     */
    boolean processLambda(PenalizedLinearParameter parameter, CommandLine cmdLine, Option lambda) {
        if (!cmdLine.hasOption(lambda)) {
            parameter.setLambda("");
            return true;
        }
        StringBuilder lambdaSeq = new StringBuilder();
        double previous = Double.NEGATIVE_INFINITY;
        for (Object x : cmdLine.getValues(lambda)) {
            double number = Double.parseDouble(x.toString());
            if (previous >= number || number < 0) {
                return false;
            }
            // Join with commas as we go; the original appended a trailing comma
            // and called substring(0, length - 1), which threw
            // StringIndexOutOfBoundsException on an empty value list.
            if (lambdaSeq.length() > 0) {
                lambdaSeq.append(',');
            }
            lambdaSeq.append(x.toString());
            previous = number;
        }
        parameter.setLambda(lambdaSeq.toString());
        return true;
    }
}