Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.regression.penalizedlinear; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.commons.cli2.util.HelpFormatter; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.math.VectorWritable; import org.apache.mahout.regression.feature.extractor.FeatureExtractUtility; import org.apache.mahout.regression.feature.extractor.FeatureExtractorKeySet; import org.apache.mahout.regression.feature.extractor.FeatureExtractorMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.util.Set; import java.util.TreeSet; public class LinearCrossValidation extends AbstractJob { private static final String DIRECTORY_CONTAINING_CONVERTED_INPUT = "data"; private static final int formatWidth = 8; private class LinearCrossValidationParameter { public int numOfCV; public String dependent; public String independent; public String interaction; public float alpha; public String lambda; public boolean intercept; } private LinearCrossValidationParameter parameter; private String featureNames; private String separator; private String input; private String output; private PenalizedLinearSolver solver; private LinearCrossValidation() { } private static final Logger log = LoggerFactory.getLogger(LinearCrossValidation.class); public static void main(String[] args) throws Exception { ToolRunner.run(new Configuration(), new LinearCrossValidation(), args); } private boolean validateParameter(LinearCrossValidationParameter parameter, String featureNames, String separator) { String pattern = FeatureExtractUtility.SeparatorToPattern(separator); if (parameter.dependent.equals("") && parameter.interaction.equals("")) { log.error("both of the dependent and interaction are empty!"); return false; } else { String[] features = featureNames.trim().split(pattern); Set<String> featureSet = new TreeSet<String>(); for (int i = 0; i < features.length; ++i) { featureSet.add(features[i]); } if (!parameter.independent.equals("")) { String[] independent = parameter.independent.split(","); for (int i = 0; i < independent.length; ++i) { if (!featureSet.contains(independent[i])) { return false; } } } if (!parameter.interaction.equals("")) { String[] interaction = parameter.interaction.split(","); for (int i = 0; i < interaction.length; ++i) { if ((!featureSet.contains(interaction[i].split(":")[0])) || (!featureSet.contains(interaction[i].split(":")[1]))) { return false; } } } return featureSet.contains(parameter.dependent); } } @Override public int run(String[] args) throws Exception { if (parseArgs(args)) { String[] inputPath = input.split("/"); String suffix = inputPath[inputPath.length - 1].split("\\.")[1]; separator = FeatureExtractUtility.ExtensionToSeparator(suffix); FileSystem fs = FileSystem.get(getConf()); BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(new Path(input)))); featureNames = br.readLine(); br.close(); if (!validateParameter(parameter, featureNames, separator)) { log.error("feature names provided are not correct!"); return 1; } run(); } return 0; } private void run() throws Exception { runFeatureExtractor(); runPenalizedLinear(); } private void runPenalizedLinear() throws IOException, InterruptedException, ClassNotFoundException { Configuration conf = getConf(); conf.setInt(PenalizedLinearKeySet.NUM_CV, parameter.numOfCV); conf.setFloat(PenalizedLinearKeySet.ALPHA, parameter.alpha); conf.set(PenalizedLinearKeySet.LAMBDA, parameter.lambda); conf.setBoolean(PenalizedLinearKeySet.INTERCEPT, parameter.intercept); Job job = new Job(conf, "Penalized Linear Regression Driver running over input: " + input); job.setInputFormatClass(SequenceFileInputFormat.class); job.setOutputFormatClass(SequenceFileOutputFormat.class); job.setMapperClass(PenalizedLinearMapper.class); job.setMapOutputKeyClass(Text.class); job.setMapOutputValueClass(VectorWritable.class); job.setReducerClass(PenalizedLinearReducer.class); job.setOutputKeyClass(Text.class); job.setOutputValueClass(VectorWritable.class); job.setCombinerClass(PenalizedLinearReducer.class); job.setNumReduceTasks(1); job.setJarByClass(LinearRegularizePath.class); FileInputFormat.addInputPath(job, new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT)); FileOutputFormat.setOutputPath(job, new Path(output, "output")); if (!job.waitForCompletion(true)) { throw new InterruptedException("Penalized Linear Regression Job failed processing " + input); } solver = new PenalizedLinearSolver(); solver.setAlpha(parameter.alpha); solver.setIntercept(parameter.intercept); solver.setLambdaString(parameter.lambda); solver.initSolver(new Path(output, "output"), getConf()); solver.crossValidate(); printInfo(parameter, solver); } private void printInfo(LinearCrossValidationParameter parameter, PenalizedLinearSolver solver) throws IOException { PenalizedLinearSolver.Coefficients coefficients = solver.getCoefficients()[0]; double[] lambdas = solver.getLambda(); String model = "model:"; model += " " + parameter.dependent + " ~"; if (parameter.intercept) { model += " " + String.format("%" + Integer.toString(formatWidth) + ".5f", coefficients.beta0) + " * intercept"; } else { model += " " + "0"; } int index = 0; if (!parameter.independent.equals("")) { String[] independent = parameter.independent.split(","); for (int i = 0; i < independent.length; ++i) { model += " + " + String.format("%" + Integer.toString(formatWidth) + ".5f", coefficients.beta[index++]) + " * " + independent[i]; } } if (!parameter.interaction.equals("")) { String[] interaction = parameter.interaction.split(","); for (int i = 0; i < interaction.length; ++i) { model += " + " + String.format("%" + Integer.toString(formatWidth) + ".5f", coefficients.beta[index++]) + " * " + interaction[i]; } } System.out.println(); System.out.println(model); System.out.println("Optimal lambda is: " + String.format("%" + Integer.toString(formatWidth) + ".5f", solver.getOptLambda()[0])); System.out.println("The training and (test) errors are in file: " + output + "/trainingTestingError.txt"); double[] trainError = solver.getTrainError(); double[] testError = solver.getTestError(); FileSystem fs = FileSystem.get(getConf()); BufferedWriter br = new BufferedWriter( new OutputStreamWriter(fs.create(new Path(output, "trainingTestingError.txt"), true))); for (int i = 0; i < lambdas.length; ++i) { String line = "" + lambdas[i] + " " + trainError[i] + " " + testError[i]; br.write(line + "\n"); } br.close(); } private void runFeatureExtractor() throws IOException, InterruptedException, ClassNotFoundException { Configuration conf = new Configuration(); conf.set("vector.implementation.class.name", "org.apache.mahout.math.RandomAccessSparseVector"); conf.set(FeatureExtractorKeySet.FEATURE_NAMES, featureNames); conf.set(FeatureExtractorKeySet.SELECTED_DEPENDENT, parameter.dependent); conf.set(FeatureExtractorKeySet.SELECTED_INDEPENDENT, parameter.independent); conf.set(FeatureExtractorKeySet.SELECTED_INTERACTION, parameter.interaction); conf.set(FeatureExtractorKeySet.SEPARATOR, separator); Job job = new Job(conf, "Input Driver running over input: " + input); job.setOutputKeyClass(Text.class); job.setOutputValueClass(VectorWritable.class); job.setOutputFormatClass(SequenceFileOutputFormat.class); job.setMapperClass(FeatureExtractorMapper.class); job.setNumReduceTasks(0); job.setJarByClass(LinearRegularizePath.class); FileInputFormat.addInputPath(job, new Path(input)); FileOutputFormat.setOutputPath(job, new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT)); boolean succeeded = job.waitForCompletion(true); if (!succeeded) { throw new IllegalStateException("Job failed!"); } } private boolean parseArgs(String[] args) { DefaultOptionBuilder builder = new DefaultOptionBuilder(); Option help = builder.withLongName("help").withDescription("print this list").create(); ArgumentBuilder argumentBuilder = new ArgumentBuilder(); Option inputFile = builder.withLongName("input").withRequired(true) .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) .withDescription("where to get training data (CSV or white-spaced TEXT file)").create(); Option outputFile = builder.withLongName("output").withRequired(true) .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) .withDescription("where to get results").create(); Option dependent = builder.withLongName("dependent").withRequired(true) .withArgument(argumentBuilder.withName("dependent").withMinimum(1).withMaximum(1).create()) .withDescription("the dependent features").create(); Option independent = builder.withLongName("independent").withRequired(true) .withArgument(argumentBuilder.withName("independent").create()) .withDescription("the independent features").create(); Option interaction = builder.withLongName("interaction").withRequired(true) .withArgument(argumentBuilder.withName("interaction").withMinimum(0).create()) .withDescription( "the interactions of features, the format is: feature1:feature2 (identical features are OK)") .create(); Option bias = builder.withLongName("bias").withDescription("include a bias term").create(); Option lambda = builder.withLongName("lambda") .withArgument(argumentBuilder.withName("lambda").withDefault("0").withMinimum(1).create()) .withDescription("an increasing positive sequence of penalty coefficient, " + "with length n >= 0; if lambda is not specified, the sequence is chosen by algorithm.") .create(); Option alpha = builder.withLongName("alpha") .withArgument( argumentBuilder.withName("alpha").withDefault("1").withMinimum(1).withMaximum(1).create()) .withDescription("the elastic-net coefficient with default value 1 (LASSO)").create(); Option numOfCV = builder.withLongName("numOfCV") .withArgument( argumentBuilder.withName("numOfCV").withDefault("5").withMinimum(0).withMaximum(1).create()) .withDescription("number of cross validation, the rule of thumb is 5 or 10").create(); Group normalArgs = new GroupBuilder().withOption(help).withOption(inputFile).withOption(outputFile) .withOption(dependent).withOption(independent).withOption(interaction).withOption(bias) .withOption(lambda).withOption(alpha).withOption(numOfCV).create(); Parser parser = new Parser(); parser.setHelpOption(help); parser.setHelpTrigger("--help"); parser.setGroup(normalArgs); parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); CommandLine cmdLine = parser.parseAndHelp(args); if (cmdLine == null) { return false; } parameter = new LinearCrossValidationParameter(); parameter.numOfCV = Integer.parseInt((String) cmdLine.getValue(numOfCV)); parameter.alpha = Float.parseFloat((String) cmdLine.getValue(alpha)); parameter.intercept = cmdLine.hasOption(bias); parameter.dependent = (String) cmdLine.getValue(dependent); String independentString = ""; for (Object x : cmdLine.getValues(independent)) { independentString += x.toString() + ","; } parameter.independent = independentString.substring(0, Math.max(independentString.length() - 1, 0)); String interactionString = ""; for (Object x : cmdLine.getValues(interaction)) { interactionString += x.toString() + ","; } parameter.interaction = interactionString.substring(0, Math.max(interactionString.length() - 1, 0)); if (!processLambda(parameter, cmdLine, lambda) || parameter.alpha < 0.0 || parameter.alpha > 1.0 || parameter.numOfCV < 1 || parameter.numOfCV > 20) { log.error( "please make sure the lambda sequence is positive and increasing, and 0.0 <= alphaValue <= 1.0 and 1 <= numofCV <= 20"); return false; } input = (String) cmdLine.getValue(inputFile); output = (String) cmdLine.getValue(outputFile); return true; } private boolean processLambda(LinearCrossValidationParameter parameter, CommandLine cmdLine, Option lambda) { String lambdaSeq = ""; double previous = Double.NEGATIVE_INFINITY; if (cmdLine.hasOption(lambda)) { for (Object x : cmdLine.getValues(lambda)) { double number = Double.parseDouble(x.toString()); if (previous >= number || number < 0) { return false; } lambdaSeq += x.toString() + ","; previous = number; } parameter.lambda = lambdaSeq.substring(0, lambdaSeq.length() - 1); return true; } else { parameter.lambda = ""; return true; } } }