Usage examples for org.apache.commons.cli2.builder.DefaultOptionBuilder
public DefaultOptionBuilder()
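Before the full examples, here is a minimal, self-contained sketch of the pattern they all share: build options with DefaultOptionBuilder and ArgumentBuilder, collect them into a Group, and hand the group to a Parser. The class name DefaultOptionBuilderDemo and the --input option are illustrative, not taken from the sources below.

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;

public class DefaultOptionBuilderDemo {
    public static void main(String[] args) throws Exception {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        // A required --input/-i option taking exactly one argument.
        Option inputOpt = obuilder.withLongName("input").withShortName("i").withRequired(true)
            .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
            .withDescription("Input path").create();

        Group group = gbuilder.withName("Options").withOption(inputOpt).create();

        Parser parser = new Parser();
        parser.setGroup(group);
        CommandLine cmdLine = parser.parse(args); // throws OptionException on invalid input
        System.out.println("input = " + cmdLine.getValue(inputOpt));
    }
}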
From source file:org.apache.mahout.classifier.Classify.java
public static void main(String[] args) throws Exception {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option pathOpt = obuilder.withLongName("path").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("The local file system path").withShortName("m").create();
    Option classifyOpt = obuilder.withLongName("classify").withRequired(true)
        .withArgument(abuilder.withName("classify").withMinimum(1).withMaximum(1).create())
        .withDescription("The doc to classify").withShortName("").create();
    Option encodingOpt = obuilder.withLongName("encoding").withRequired(true)
        .withArgument(abuilder.withName("encoding").withMinimum(1).withMaximum(1).create())
        .withDescription("The file encoding. Default: UTF-8").withShortName("e").create();
    Option analyzerOpt = obuilder.withLongName("analyzer").withRequired(true)
        .withArgument(abuilder.withName("analyzer").withMinimum(1).withMaximum(1).create())
        .withDescription("The Analyzer to use").withShortName("a").create();
    Option defaultCatOpt = obuilder.withLongName("defaultCat").withRequired(true)
        .withArgument(abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create())
        .withDescription("The default category").withShortName("d").create();
    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true)
        .withArgument(abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create())
        .withDescription("Size of the n-gram").withShortName("ng").create();
    Option typeOpt = obuilder.withLongName("classifierType").withRequired(true)
        .withArgument(abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create())
        .withDescription("Type of classifier").withShortName("type").create();
    Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(true)
        .withArgument(abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create())
        .withDescription("Location of model: hdfs").withShortName("source").create();

    Group options = gbuilder.withName("Options").withOption(pathOpt).withOption(classifyOpt)
        .withOption(encodingOpt).withOption(analyzerOpt).withOption(defaultCatOpt).withOption(gramSizeOpt)
        .withOption(typeOpt).withOption(dataSourceOpt).create();

    Parser parser = new Parser();
    parser.setGroup(options);
    CommandLine cmdLine = parser.parse(args);

    int gramSize = 1;
    if (cmdLine.hasOption(gramSizeOpt)) {
        gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
    }

    BayesParameters params = new BayesParameters();
    params.setGramSize(gramSize);
    String modelBasePath = (String) cmdLine.getValue(pathOpt);
    params.setBasePath(modelBasePath);
    log.info("Loading model from: {}", params.print());

    Algorithm algorithm;
    Datastore datastore;
    String classifierType = (String) cmdLine.getValue(typeOpt);
    String dataSource = (String) cmdLine.getValue(dataSourceOpt);
    if ("hdfs".equals(dataSource)) {
        if ("bayes".equalsIgnoreCase(classifierType)) {
            log.info("Using Bayes Classifier");
            algorithm = new BayesAlgorithm();
            datastore = new InMemoryBayesDatastore(params);
        } else if ("cbayes".equalsIgnoreCase(classifierType)) {
            log.info("Using Complementary Bayes Classifier");
            algorithm = new CBayesAlgorithm();
            datastore = new InMemoryBayesDatastore(params);
        } else {
            throw new IllegalArgumentException("Unrecognized classifier type: " + classifierType);
        }
    } else {
        throw new IllegalArgumentException("Unrecognized dataSource type: " + dataSource);
    }

    ClassifierContext classifier = new ClassifierContext(algorithm, datastore);
    classifier.initialize();

    String defaultCat = "unknown";
    if (cmdLine.hasOption(defaultCatOpt)) {
        defaultCat = (String) cmdLine.getValue(defaultCatOpt);
    }
    File docPath = new File((String) cmdLine.getValue(classifyOpt));
    String encoding = "UTF-8";
    if (cmdLine.hasOption(encodingOpt)) {
        encoding = (String) cmdLine.getValue(encodingOpt);
    }
    Analyzer analyzer = null;
    if (cmdLine.hasOption(analyzerOpt)) {
        analyzer = ClassUtils.instantiateAs((String) cmdLine.getValue(analyzerOpt), Analyzer.class);
    }
    if (analyzer == null) {
        analyzer = new StandardAnalyzer(Version.LUCENE_31);
    }

    log.info("Converting input document to proper format");
    String[] document = BayesFileFormatter.readerToDocument(analyzer,
        Files.newReader(docPath, Charset.forName(encoding)));
    StringBuilder line = new StringBuilder();
    for (String token : document) {
        line.append(token).append(' ');
    }

    List<String> doc = new NGrams(line.toString(), gramSize).generateNGramsWithoutLabel();
    log.info("Done converting");
    log.info("Classifying document: {}", docPath);
    ClassifierResult category = classifier.classifyDocument(doc.toArray(new String[doc.size()]), defaultCat);
    log.info("Category for {} is {}", docPath, category);
}
From source file:org.apache.mahout.classifier.df.BreimanExample.java
@Override
public int run(String[] args) throws IOException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("Data path").create();
    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
        .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
        .withDescription("Dataset path").create();
    Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true)
        .withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
        .withDescription("Number of trees to grow, each iteration").create();
    Option nbItersOpt = obuilder.withLongName("iterations").withShortName("i").withRequired(true)
        .withArgument(abuilder.withName("numIterations").withMinimum(1).withMaximum(1).create())
        .withDescription("Number of times to repeat the test").create();
    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();

    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
        .withOption(nbItersOpt).withOption(nbtreesOpt).withOption(helpOpt).create();

    Path dataPath;
    Path datasetPath;
    int nbTrees;
    int nbIterations;
    try {
        Parser parser = new Parser();
        parser.setGroup(group);
        CommandLine cmdLine = parser.parse(args);

        if (cmdLine.hasOption("help")) {
            CommandLineUtil.printHelp(group);
            return -1;
        }

        String dataName = cmdLine.getValue(dataOpt).toString();
        String datasetName = cmdLine.getValue(datasetOpt).toString();
        nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
        nbIterations = Integer.parseInt(cmdLine.getValue(nbItersOpt).toString());

        dataPath = new Path(dataName);
        datasetPath = new Path(datasetName);
    } catch (OptionException e) {
        log.error("Error while parsing options", e);
        CommandLineUtil.printHelp(group);
        return -1;
    }

    // load the data
    FileSystem fs = dataPath.getFileSystem(new Configuration());
    Dataset dataset = Dataset.load(getConf(), datasetPath);
    Data data = DataLoader.loadData(dataset, fs, dataPath);

    // take m to be the first integer less than log2(M) + 1, where M is the number of inputs
    int m = (int) Math.floor(FastMath.log(2.0, data.getDataset().nbAttributes()) + 1);

    Random rng = RandomUtils.getRandom();
    for (int iteration = 0; iteration < nbIterations; iteration++) {
        log.info("Iteration {}", iteration);
        runIteration(rng, data, m, nbTrees);
    }

    log.info("********************************************");
    log.info("Random Input Test Error : {}", sumTestErrM / nbIterations);
    log.info("Single Input Test Error : {}", sumTestErrOne / nbIterations);
    log.info("Mean Random Input Time : {}", DFUtils.elapsedTime(sumTimeM / nbIterations));
    log.info("Mean Single Input Time : {}", DFUtils.elapsedTime(sumTimeOne / nbIterations));
    log.info("Mean Random Input Num Nodes : {}", numNodesM / nbIterations);
    log.info("Mean Single Input Num Nodes : {}", numNodesOne / nbIterations);

    return 0;
}
From source file:org.apache.mahout.classifier.df.mapreduce.Resampling.java
public int run(String[] args) throws Exception, ClassNotFoundException, InterruptedException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("Data path").create();
    Option dataPreprocessingOpt = obuilder.withLongName("dataPreprocessing").withShortName("dp").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("Data Preprocessing path").create();
    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
        .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
        .withDescription("Dataset path").create();
    Option timeOpt = obuilder.withLongName("time").withShortName("tm").withRequired(false)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("Time path").create();
    Option helpOpt = obuilder.withLongName("help").withShortName("h").withDescription("Print out help").create();
    Option resamplingOpt = obuilder.withLongName("resampling").withShortName("rs").withRequired(true)
        .withArgument(abuilder.withName("resampling").withMinimum(1).withMaximum(1).create())
        .withDescription("The resampling technique (oversampling (overs), undersampling (unders) or SMOTE (smote))")
        .create();
    Option nbpartitionsOpt = obuilder.withLongName("nbpartitions").withShortName("p").withRequired(true)
        .withArgument(abuilder.withName("nbpartitions").withMinimum(1).withMaximum(1).create())
        .withDescription("Number of partitions").create();
    Option nposOpt = obuilder.withLongName("npos").withShortName("npos").withRequired(true)
        .withArgument(abuilder.withName("npos").withMinimum(1).withMaximum(1).create())
        .withDescription("Number of instances of the positive class").create();
    Option nnegOpt = obuilder.withLongName("nneg").withShortName("nneg").withRequired(true)
        .withArgument(abuilder.withName("nneg").withMinimum(1).withMaximum(1).create())
        .withDescription("Number of instances of the negative class").create();
    Option negclassOpt = obuilder.withLongName("negclass").withShortName("negclass").withRequired(true)
        .withArgument(abuilder.withName("negclass").withMinimum(1).withMaximum(1).create())
        .withDescription("Name of the negative class").create();
    Option posclassOpt = obuilder.withLongName("posclass").withShortName("posclass").withRequired(true)
        .withArgument(abuilder.withName("posclass").withMinimum(1).withMaximum(1).create())
        .withDescription("Name of the positive class").create();

    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(timeOpt)
        .withOption(helpOpt).withOption(resamplingOpt).withOption(dataPreprocessingOpt)
        .withOption(nbpartitionsOpt).withOption(nposOpt).withOption(nnegOpt).withOption(negclassOpt)
        .withOption(posclassOpt).create();

    try {
        Parser parser = new Parser();
        parser.setGroup(group);
        CommandLine cmdLine = parser.parse(args);

        if (cmdLine.hasOption("help")) {
            CommandLineUtil.printHelp(group);
            return -1;
        }

        dataName = cmdLine.getValue(dataOpt).toString();
        String datasetName = cmdLine.getValue(datasetOpt).toString();
        dataPreprocessing = cmdLine.getValue(dataPreprocessingOpt).toString();
        String resampling = cmdLine.getValue(resamplingOpt).toString();
        partitions = Integer.parseInt(cmdLine.getValue(nbpartitionsOpt).toString());
        npos = Integer.parseInt(cmdLine.getValue(nposOpt).toString());
        nneg = Integer.parseInt(cmdLine.getValue(nnegOpt).toString());
        negclass = cmdLine.getValue(negclassOpt).toString();
        posclass = cmdLine.getValue(posclassOpt).toString();

        if (resampling.equalsIgnoreCase("overs")) {
            withOversampling = true;
        } else if (resampling.equalsIgnoreCase("unders")) {
            withUndersampling = true;
        } else if (resampling.equalsIgnoreCase("smote")) {
            withSmote = true;
        }

        if (cmdLine.hasOption(timeOpt)) {
            preprocessingTimeIsStored = true;
            timeName = cmdLine.getValue(timeOpt).toString();
        }

        if (log.isDebugEnabled()) {
            log.debug("data : {}", dataName);
            log.debug("dataset : {}", datasetName);
            log.debug("time : {}", timeName);
            log.debug("Oversampling : {}", withOversampling);
            log.debug("Undersampling : {}", withUndersampling);
            log.debug("SMOTE : {}", withSmote);
        }

        dataPath = new Path(dataName);
        datasetPath = new Path(datasetName);
        dataPreprocessingPath = new Path(dataPreprocessing);
        if (preprocessingTimeIsStored) {
            timePath = new Path(timeName);
        }
    } catch (OptionException e) {
        log.error("Exception", e);
        CommandLineUtil.printHelp(group);
        return -1;
    }

    if (withOversampling) {
        overSampling();
    } else if (withUndersampling) {
        underSampling();
    } else if (withSmote) {
        smote();
    }

    return 0;
}
From source file:org.apache.mahout.classifier.df.tools.ForestVisualizer.java
public static void main(String[] args) {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
        .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
        .withDescription("Dataset path").create();
    Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("Path to the Decision Forest").create();
    Option attrNamesOpt = obuilder.withLongName("names").withShortName("n").withRequired(false)
        .withArgument(abuilder.withName("names").withMinimum(1).create())
        .withDescription("Optional, Attribute names").create();
    Option helpOpt = obuilder.withLongName("help").withShortName("h").withDescription("Print out help").create();

    Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt)
        .withOption(attrNamesOpt).withOption(helpOpt).create();

    try {
        Parser parser = new Parser();
        parser.setGroup(group);
        CommandLine cmdLine = parser.parse(args);

        if (cmdLine.hasOption("help")) {
            CommandLineUtil.printHelp(group);
            return;
        }

        String datasetName = cmdLine.getValue(datasetOpt).toString();
        String modelName = cmdLine.getValue(modelOpt).toString();
        String[] attrNames = null;
        if (cmdLine.hasOption(attrNamesOpt)) {
            Collection<String> names = (Collection<String>) cmdLine.getValues(attrNamesOpt);
            if (!names.isEmpty()) {
                attrNames = new String[names.size()];
                names.toArray(attrNames);
            }
        }

        print(modelName, datasetName, attrNames);
    } catch (Exception e) {
        log.error("Exception", e);
        CommandLineUtil.printHelp(group);
    }
}
From source file:org.apache.mahout.classifier.df.tools.Frequencies.java
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("Data path").create();
    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
        .withArgument(abuilder.withName("path").withMinimum(1).create())
        .withDescription("dataset path").create();
    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();

    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
        .withOption(helpOpt).create();

    try {
        Parser parser = new Parser();
        parser.setGroup(group);
        CommandLine cmdLine = parser.parse(args);

        if (cmdLine.hasOption(helpOpt)) {
            CommandLineUtil.printHelp(group);
            return 0;
        }

        String dataPath = cmdLine.getValue(dataOpt).toString();
        String datasetPath = cmdLine.getValue(datasetOpt).toString();

        log.debug("Data path : {}", dataPath);
        log.debug("Dataset path : {}", datasetPath);

        runTool(dataPath, datasetPath);
    } catch (OptionException e) {
        log.warn(e.toString(), e);
        CommandLineUtil.printHelp(group);
    }
    return 0;
}
From source file:org.apache.mahout.classifier.df.tools.UDistrib.java
/**
 * Launch the uniform distribution tool. Requires the following command line arguments:<br>
 *
 * data : data path
 * dataset : dataset path
 * numpartitions : num partitions
 * output : output path
 *
 * @throws java.io.IOException
 */
public static void main(String[] args) throws IOException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true)
        .withArgument(abuilder.withName("data").withMinimum(1).withMaximum(1).create())
        .withDescription("Data path").create();
    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
        .withArgument(abuilder.withName("dataset").withMinimum(1).create())
        .withDescription("Dataset path").create();
    Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true)
        .withArgument(abuilder.withName("output").withMinimum(1).withMaximum(1).create())
        .withDescription("Path to generated files").create();
    Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
        .withArgument(abuilder.withName("numparts").withMinimum(1).withMaximum(1).create())
        .withDescription("Number of partitions to create").create();
    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();

    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt)
        .withOption(datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();

    try {
        Parser parser = new Parser();
        parser.setGroup(group);
        CommandLine cmdLine = parser.parse(args);

        if (cmdLine.hasOption(helpOpt)) {
            CommandLineUtil.printHelp(group);
            return;
        }

        String data = cmdLine.getValue(dataOpt).toString();
        String dataset = cmdLine.getValue(datasetOpt).toString();
        int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
        String output = cmdLine.getValue(outputOpt).toString();

        runTool(data, dataset, output, numPartitions);
    } catch (OptionException e) {
        log.warn(e.toString(), e);
        CommandLineUtil.printHelp(group);
    }
}
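As a quick usage sketch, the arguments described in the Javadoc map onto an invocation like the following. The paths and the partition count are hypothetical, not taken from the Mahout sources.

// Hypothetical invocation of the tool above; paths and partition count are made up.
String[] args = {
    "--data", "/mahout/in/data.csv",
    "--dataset", "/mahout/in/dataset.info",
    "--numpartitions", "4",
    "--output", "/mahout/out/partitions"
};
UDistrib.main(args);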
From source file:org.apache.mahout.classifier.mlp.RunMultilayerPerceptron.java
/**
 * Parse the arguments.
 *
 * @param args The input arguments.
 * @param parameters The parameters to be filled.
 * @return true if the arguments were parsed successfully, false otherwise.
 * @throws Exception
 */
private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
    // build the options
    log.info("Validate and parse arguments...");
    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
    GroupBuilder groupBuilder = new GroupBuilder();
    ArgumentBuilder argumentBuilder = new ArgumentBuilder();

    Option inputFileFormatOption = optionBuilder.withLongName("format").withShortName("f")
        .withArgument(argumentBuilder.withName("file type").withDefault("csv")
            .withMinimum(1).withMaximum(1).create())
        .withDescription("type of input file, currently support 'csv'").create();

    List<Integer> columnRangeDefault = Lists.newArrayList();
    columnRangeDefault.add(0);
    columnRangeDefault.add(Integer.MAX_VALUE);

    Option skipHeaderOption = optionBuilder.withLongName("skipHeader").withShortName("sh").withRequired(false)
        .withDescription("whether to skip the first row of the input file").create();
    Option inputColumnRangeOption = optionBuilder.withLongName("columnRange").withShortName("cr")
        .withDescription("the column range of the input file, start from 0")
        .withArgument(argumentBuilder.withName("range").withMinimum(2).withMaximum(2)
            .withDefaults(columnRangeDefault).create())
        .create();

    Group inputFileTypeGroup = groupBuilder.withOption(skipHeaderOption).withOption(inputColumnRangeOption)
        .withOption(inputFileFormatOption).create();

    Option inputOption = optionBuilder.withLongName("input").withShortName("i").withRequired(true)
        .withArgument(argumentBuilder.withName("file path").withMinimum(1).withMaximum(1).create())
        .withDescription("the file path of unlabelled dataset").withChildren(inputFileTypeGroup).create();

    Option modelOption = optionBuilder.withLongName("model").withShortName("mo").withRequired(true)
        .withArgument(argumentBuilder.withName("model file").withMinimum(1).withMaximum(1).create())
        .withDescription("the file path of the model").create();

    Option labelsOption = optionBuilder.withLongName("labels").withShortName("labels")
        .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
        .withDescription("an ordered list of label names").create();

    Group labelsGroup = groupBuilder.withOption(labelsOption).create();

    Option outputOption = optionBuilder.withLongName("output").withShortName("o").withRequired(true)
        .withArgument(argumentBuilder.withConsumeRemaining("file path").withMinimum(1).withMaximum(1).create())
        .withDescription("the file path of labelled results").withChildren(labelsGroup).create();

    // parse the input
    Parser parser = new Parser();
    Group normalOption = groupBuilder.withOption(inputOption).withOption(modelOption)
        .withOption(outputOption).create();
    parser.setGroup(normalOption);
    CommandLine commandLine = parser.parseAndHelp(args);
    if (commandLine == null) {
        return false;
    }

    // obtain the arguments
    parameters.inputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, inputOption);
    parameters.inputFileFormat = TrainMultilayerPerceptron.getString(commandLine, inputFileFormatOption);
    parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
    parameters.modelFilePathStr = TrainMultilayerPerceptron.getString(commandLine, modelOption);
    parameters.outputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, outputOption);

    List<?> columnRange = commandLine.getValues(inputColumnRangeOption);
    parameters.columnStart = Integer.parseInt(columnRange.get(0).toString());
    parameters.columnEnd = Integer.parseInt(columnRange.get(1).toString());

    return true;
}
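Note the withChildren calls in the example above: in CLI2, a child group attaches sub-options that are only meaningful in the context of the parent option (here, --format, --skipHeader and --columnRange belong to --input). A minimal sketch of the same pattern, with illustrative option names not taken from the source:

// --verbose is a child of --log: it is only accepted alongside --log (illustrative names).
DefaultOptionBuilder ob = new DefaultOptionBuilder();
GroupBuilder gb = new GroupBuilder();

Option verboseOpt = ob.withLongName("verbose").withShortName("v").create();
Group logChildren = gb.withOption(verboseOpt).create();

Option logOpt = ob.withLongName("log").withChildren(logChildren).create();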
From source file:org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron.java
/**
 * Parse the input arguments.
 *
 * @param args The input arguments.
 * @param parameters The parameters parsed.
 * @return Whether the input arguments are valid.
 * @throws Exception
 */
private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
    // build the options
    log.info("Validate and parse arguments...");
    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
    GroupBuilder groupBuilder = new GroupBuilder();
    ArgumentBuilder argumentBuilder = new ArgumentBuilder();

    // whether to skip the first row of the input file
    Option skipHeaderOption = optionBuilder.withLongName("skipHeader").withShortName("sh").create();
    Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create();

    Option inputOption = optionBuilder.withLongName("input").withShortName("i").withRequired(true)
        .withChildren(skipHeaderGroup)
        .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1).create())
        .withDescription("the file path of training dataset").create();

    Option labelsOption = optionBuilder.withLongName("labels").withShortName("labels").withRequired(true)
        .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
        .withDescription("label names").create();

    Option updateOption = optionBuilder.withLongName("update").withShortName("u")
        .withDescription("whether to incrementally update model if the model exists").create();
    Group modelUpdateGroup = groupBuilder.withOption(updateOption).create();

    Option modelOption = optionBuilder.withLongName("model").withShortName("mo").withRequired(true)
        .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
        .withDescription("the path to store the trained model").withChildren(modelUpdateGroup).create();

    Option layerSizeOption = optionBuilder.withLongName("layerSize").withShortName("ls").withRequired(true)
        .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
        .withDescription("the size of each layer").create();

    Option squashingFunctionOption = optionBuilder.withLongName("squashingFunction").withShortName("sf")
        .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1)
            .withDefault("Sigmoid").create())
        .withDescription("the name of squashing function (currently only supports Sigmoid)").create();

    Option learningRateOption = optionBuilder.withLongName("learningRate").withShortName("l")
        .withArgument(argumentBuilder.withName("learning rate").withMaximum(1).withMinimum(1)
            .withDefault(NeuralNetwork.DEFAULT_LEARNING_RATE).create())
        .withDescription("learning rate").create();

    Option momemtumOption = optionBuilder.withLongName("momemtumWeight").withShortName("m")
        .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1).withMinimum(1)
            .withDefault(NeuralNetwork.DEFAULT_MOMENTUM_WEIGHT).create())
        .withDescription("momemtum weight").create();

    Option regularizationOption = optionBuilder.withLongName("regularizationWeight").withShortName("r")
        .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1).withMinimum(1)
            .withDefault(NeuralNetwork.DEFAULT_REGULARIZATION_WEIGHT).create())
        .withDescription("regularization weight").create();

    // parse the input
    Parser parser = new Parser();
    Group normalOptions = groupBuilder.withOption(inputOption).withOption(skipHeaderOption)
        .withOption(updateOption).withOption(labelsOption).withOption(modelOption)
        .withOption(layerSizeOption).withOption(squashingFunctionOption).withOption(learningRateOption)
        .withOption(momemtumOption).withOption(regularizationOption).create();
    parser.setGroup(normalOptions);
    CommandLine commandLine = parser.parseAndHelp(args);
    if (commandLine == null) {
        return false;
    }

    parameters.learningRate = getDouble(commandLine, learningRateOption);
    parameters.momemtumWeight = getDouble(commandLine, momemtumOption);
    parameters.regularizationWeight = getDouble(commandLine, regularizationOption);
    parameters.inputFilePath = getString(commandLine, inputOption);
    parameters.skipHeader = commandLine.hasOption(skipHeaderOption);

    List<String> labelsList = getStringList(commandLine, labelsOption);
    int currentIndex = 0;
    for (String label : labelsList) {
        parameters.labelsIndex.put(label, currentIndex++);
    }

    parameters.modelFilePath = getString(commandLine, modelOption);
    parameters.updateModel = commandLine.hasOption(updateOption);
    parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption);
    parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption);

    System.out.printf(
        "Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f,"
            + " Momemtum weight: %f, Regularization Weight: %f\n",
        parameters.inputFilePath, parameters.modelFilePath, parameters.updateModel,
        Arrays.toString(parameters.layerSizeList.toArray()), parameters.squashingFunctionName,
        parameters.learningRate, parameters.momemtumWeight, parameters.regularizationWeight);

    return true;
}
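Several options above declare defaults through ArgumentBuilder.withDefault, and the code then reads them unconditionally with getDouble/getString rather than checking hasOption first. A small sketch of that behavior, assuming CLI2 applies argument defaults during parse as the Mahout code above relies on (the --rate option and its value are illustrative):

// Illustrative: --rate falls back to "0.01" when absent from the command line.
DefaultOptionBuilder ob = new DefaultOptionBuilder();
ArgumentBuilder ab = new ArgumentBuilder();
GroupBuilder gb = new GroupBuilder();

Option rateOpt = ob.withLongName("rate")
    .withArgument(ab.withName("rate").withMinimum(1).withMaximum(1).withDefault("0.01").create())
    .create();
Group group = gb.withOption(rateOpt).create();

Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(new String[0]);
System.out.println(cmdLine.getValue(rateOpt)); // expected to print the default, 0.01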
From source file:org.apache.mahout.classifier.rbm.test.TestRBMClassifierJob.java
@Override
public int run(String[] args) throws Exception {
    addInputOption();
    addOption("model", "m", "The path to the model built during training", true);
    addOption("labelcount", "lc", "total count of labels existent in the training set", true);
    addOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION, "max",
        "least number of stable iterations in classification layer when classifying", "10");
    addOption(new DefaultOptionBuilder().withLongName(DefaultOptionCreator.MAPREDUCE_METHOD)
        .withRequired(false).withDescription("Run tests with map/reduce").withShortName("mr").create());

    Map<String, String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
        return -1;
    }

    int labelcount = Integer.parseInt(getOption("labelcount"));
    iterations = Integer.parseInt(getOption("maxIter"));

    // check the model's existence
    Path model = new Path(parsedArgs.get("--model"));
    if (!model.getFileSystem(getConf()).exists(model)) {
        log.error("Model file does not exist!");
        return -1;
    }

    // create the list of all labels
    List<String> lables = new ArrayList<String>();
    for (int i = 0; i < labelcount; i++) {
        lables.add(String.valueOf(i));
    }

    FileSystem fs = getInputPath().getFileSystem(getConf());
    ResultAnalyzer analyzer = new ResultAnalyzer(lables, "-1");

    // initialize the paths to the test batches
    Path[] batches;
    if (fs.isFile(getInputPath())) {
        batches = new Path[] { getInputPath() };
    } else {
        FileStatus[] stati = fs.listStatus(getInputPath());
        batches = new Path[stati.length];
        for (int i = 0; i < stati.length; i++) {
            batches[i] = stati[i].getPath();
        }
    }

    if (hasOption("mapreduce")) {
        HadoopUtil.delete(getConf(), getTempPath("testresults"));
    }

    for (Path input : batches) {
        if (hasOption("mapreduce")) {
            HadoopUtil.cacheFiles(model, getConf());
            // the output key is the expected value, the output value are the scores for all the labels
            Job testJob = prepareJob(input, getTempPath("testresults"), SequenceFileInputFormat.class,
                TestRBMClassifierMapper.class, IntWritable.class, VectorWritable.class,
                SequenceFileOutputFormat.class);
            testJob.getConfiguration().set("maxIter", String.valueOf(iterations));
            testJob.waitForCompletion(true);

            // loop over the results and create the confusion matrix
            SequenceFileDirIterable<IntWritable, VectorWritable> dirIterable =
                new SequenceFileDirIterable<IntWritable, VectorWritable>(getTempPath("testresults"),
                    PathType.LIST, PathFilters.partFilter(), getConf());
            analyzeResults(dirIterable, analyzer);
        } else {
            // run the test job locally
            runTestsLocally(model, analyzer, input);
        }
    }

    // output the result of the tests
    log.info("RBMClassifier Results: {}", analyzer);

    // stop all running threads
    if (executor != null) {
        executor.shutdownNow();
    }
    return 0;
}
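This example and the next one show a different style from the earlier tools: inside Mahout's AbstractJob, a hand-built DefaultOptionBuilder option (here an argument-less flag) is mixed with the simpler addOption(String, ...) helpers, and its presence is later queried by long name via hasOption. A minimal sketch of just the flag construction, with an illustrative name not taken from the source:

// A flag-style option: no argument, not required; presence is checked with hasOption("dryRun").
Option dryRunOpt = new DefaultOptionBuilder().withLongName("dryRun").withRequired(false)
    .withDescription("Parse and validate only, do not run the job").withShortName("dr").create();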
From source file:org.apache.mahout.classifier.rbm.training.RBMClassifierTrainingJob.java
@Override
public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    addOption("epochs", "e", "number of training epochs through the trainingset", true);
    addOption("structure", "s", "comma-separated list of layer sizes", false);
    addOption("labelcount", "lc", "total count of labels existent in the training set", true);
    addOption("learningrate", "lr", "learning rate at the beginning of training", "0.005");
    addOption("momentum", "m", "momentum of learning at the beginning", "0.5");
    addOption("rbmnr", "nr", "rbm to train, < 0 means train all", "-1");
    addOption("nrgibbs", "gn", "number of gibbs sampling used in contrastive divergence", "5");
    addOption(new DefaultOptionBuilder().withLongName(DefaultOptionCreator.MAPREDUCE_METHOD)
        .withRequired(false).withDescription("Run training with map/reduce").withShortName("mr").create());
    addOption(new DefaultOptionBuilder().withLongName("nogreedy").withRequired(false)
        .withDescription("Don't run greedy pre training").withShortName("ng").create());
    addOption(new DefaultOptionBuilder().withLongName("nofinetuning").withRequired(false)
        .withDescription("Don't run fine tuning at the end").withShortName("nf").create());
    addOption(new DefaultOptionBuilder().withLongName("nobiases").withRequired(false)
        .withDescription("Don't initialize biases").withShortName("nb").create());
    addOption(new DefaultOptionBuilder().withLongName("monitor").withRequired(false)
        .withDescription("If present, errors can be monitored in console").withShortName("mon").create());
    addOption(DefaultOptionCreator.overwriteOption().create());

    Map<String, String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
        return -1;
    }

    Path input = getInputPath();
    Path output = getOutputPath();
    FileSystem fs = FileSystem.get(output.toUri(), getConf());
    labelcount = Integer.parseInt(getOption("labelcount"));

    boolean local = !hasOption("mapreduce");
    monitor = hasOption("monitor");
    initbiases = !hasOption("nobiases");
    finetuning = !hasOption("nofinetuning");
    greedy = !hasOption("nogreedy");

    if (fs.isFile(input)) {
        batches = new Path[] { input };
    } else {
        FileStatus[] stati = fs.listStatus(input);
        batches = new Path[stati.length];
        for (int i = 0; i < stati.length; i++) {
            batches[i] = stati[i].getPath();
        }
    }

    epochs = Integer.valueOf(getOption("epochs"));
    learningrate = Double.parseDouble(getOption("learningrate"));
    momentum = Double.parseDouble(getOption("momentum"));
    rbmNrtoTrain = Integer.parseInt(getOption("rbmnr"));
    nrGibbsSampling = Integer.parseInt(getOption("nrgibbs"));

    boolean initialize = hasOption(DefaultOptionCreator.OVERWRITE_OPTION) || !fs.exists(output)
        || fs.listStatus(output).length <= 0;

    if (initialize) {
        String structure = getOption("structure");
        if (structure == null || structure.isEmpty()) {
            return -1;
        }
        String[] layers = structure.split(",");
        if (layers.length < 2) {
            return -1;
        }
        int[] actualLayerSizes = new int[layers.length];
        for (int i = 0; i < layers.length; i++) {
            actualLayerSizes[i] = Integer.parseInt(layers[i]);
        }
        rbmCl = new RBMClassifier(labelcount, actualLayerSizes);
        logger.info("New model initialized!");
    } else {
        rbmCl = RBMClassifier.materialize(output, getConf());
        logger.info("Model found and materialized!");
    }

    HadoopUtil.setSerializations(getConf());
    lastUpdate = new Matrix[rbmCl.getDbm().getRbmCount()];

    if (initbiases) {
        // init biases!
        Vector biases = null;
        int counter = 0;
        for (Path batch : batches) {
            for (Pair<IntWritable, VectorWritable> record
                    : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) {
                if (biases == null) {
                    biases = record.getSecond().get().clone();
                } else {
                    // Vector.plus returns a new vector, so the result must be reassigned
                    biases = biases.plus(record.getSecond().get());
                }
                counter++;
            }
        }
        if (biases == null) {
            logger.info("No training data found!");
            return -1;
        }
        rbmCl.getDbm().getLayer(0).setBiases(biases.divide(counter));
        logger.info("Biases initialized");
    }

    // greedy pre-training with gradually decreasing learning rates
    if (greedy) {
        if (!local) {
            rbmCl.serialize(output, getConf());
        }
        double tempLearningrate = learningrate;
        if (rbmNrtoTrain < 0) {
            // train all rbms
            for (int rbmNr = 0; rbmNr < rbmCl.getDbm().getRbmCount(); rbmNr++) {
                tempLearningrate = learningrate;

                // double weights if dbm was materialized, because it was halved after greedy pretraining
                if (!initialize && rbmNrtoTrain > 0 && rbmNrtoTrain < rbmCl.getDbm().getRbmCount() - 1) {
                    ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNr)).setWeightMatrix(
                        ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNr)).getWeightMatrix().times(2));
                }

                for (int j = 0; j < epochs; j++) {
                    logger.info("Greedy training, epoch " + (j + 1)
                        + "\nCurrent learningrate: " + tempLearningrate);
                    for (int b = 0; b < batches.length; b++) {
                        tempLearningrate -= learningrate / (epochs * batches.length + epochs);
                        if (local) {
                            if (!trainGreedySeq(rbmNr, batches[b], j, tempLearningrate)) {
                                return -1;
                            }
                        } else if (!trainGreedyMR(rbmNr, batches[b], j, tempLearningrate)) {
                            return -1;
                        }
                        if (monitor && (batches.length > 19) && (b + 1) % (batches.length / 20) == 0) {
                            logger.info(rbmNr + "-RBM: "
                                + Math.round(((double) b + 1) / batches.length * 100.0) + "% in epoch done!");
                        }
                    }
                    logger.info(Math.round(((double) j + 1) / epochs * 100)
                        + "% of training on rbm number " + rbmNr + " is done!");

                    if (monitor) {
                        double error = rbmError(batches[0], rbmNr);
                        logger.info("Average reconstruction error on batch " + batches[0].getName() + ": " + error);
                    }

                    rbmCl.serialize(output, getConf());
                }

                // weight normalization to avoid double counting
                if (rbmNr > 0 && rbmNr < rbmCl.getDbm().getRbmCount() - 1) {
                    ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix(
                        ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(0.5));
                }
            }
        } else {
            // double weights if dbm was materialized, because it was halved after greedy pretraining
            if (!initialize && rbmNrtoTrain > 0 && rbmNrtoTrain < rbmCl.getDbm().getRbmCount() - 1) {
                ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix(
                    ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(2));
            }
            // train just the wanted rbm
            for (int j = 0; j < epochs; j++) {
                logger.info("Greedy training, epoch " + (j + 1)
                    + "\nCurrent learningrate: " + tempLearningrate);
                for (int b = 0; b < batches.length; b++) {
                    tempLearningrate -= learningrate / (epochs * batches.length + epochs);
                    if (local) {
                        if (!trainGreedySeq(rbmNrtoTrain, batches[b], j, tempLearningrate)) {
                            return -1;
                        }
                    } else if (!trainGreedyMR(rbmNrtoTrain, batches[b], j, tempLearningrate)) {
                        return -1;
                    }
                    if (monitor && (batches.length > 19) && (b + 1) % (batches.length / 20) == 0) {
                        logger.info(rbmNrtoTrain + "-RBM: "
                            + Math.round(((double) b + 1) / batches.length * 100.0) + "% in epoch done!");
                    }
                }
                logger.info(Math.round(((double) j + 1) / epochs * 100) + "% of training is done!");

                if (monitor) {
                    double error = rbmError(batches[0], rbmNrtoTrain);
                    logger.info("Average reconstruction error on batch " + batches[0].getName() + ": " + error);
                }
            }

            // weight normalization to avoid double counting
            if (rbmNrtoTrain > 0 && rbmNrtoTrain < rbmCl.getDbm().getRbmCount() - 1) {
                ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix(
                    ((SimpleRBM) rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(0.5));
            }
        }

        rbmCl.serialize(output, getConf());
        logger.info("Pretraining done and model written to output");
    }

    if (finetuning) {
        DeepBoltzmannMachine multiLayerDbm = null;
        double tempLearningrate = learningrate;
        // finetuning job
        for (int j = 0; j < epochs; j++) {
            for (int b = 0; b < batches.length; b++) {
                multiLayerDbm = rbmCl.initializeMultiLayerNN();
                logger.info("Finetuning on batch " + batches[b].getName()
                    + "\nCurrent learningrate: " + tempLearningrate);
                tempLearningrate -= learningrate / (epochs * batches.length + epochs);
                if (local) {
                    if (!finetuneSeq(batches[b], j, multiLayerDbm, tempLearningrate)) {
                        return -1;
                    }
                } else if (!fintuneMR(batches[b], j, tempLearningrate)) {
                    return -1;
                }
                logger.info("Finetuning: " + Math.round(((double) b + 1) / batches.length * 100.0)
                    + "% in epoch done!");
            }
            logger.info(Math.round(((double) j + 1) / epochs * 100) + "% of training is done!");

            if (monitor) {
                double error = feedForwardError(multiLayerDbm, batches[0]);
                logger.info("Average discriminative error on batch " + batches[0].getName() + ": " + error);
            }
        }

        // final serialization
        rbmCl.serialize(output, getConf());
        logger.info("RBM finetuning done and model written to output");
    }

    if (executor != null) {
        executor.shutdownNow();
    }
    return 0;
}
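The greedy and finetuning loops above decay the learning rate linearly: every batch subtracts learningrate / (epochs * batches.length + epochs). A quick sketch of that schedule with illustrative numbers (algebraically, after epochs * nBatches steps the rate ends at learningrate / (nBatches + 1)):

public class DecayDemo {
    public static void main(String[] args) {
        double learningrate = 0.005; // the job's default
        int epochs = 10, nBatches = 5; // illustrative values
        double step = learningrate / (epochs * nBatches + epochs); // per-batch decrement, 0.005 / 60
        double finalRate = learningrate - (double) epochs * nBatches * step; // == learningrate / (nBatches + 1)
        System.out.printf("step = %.6f, final rate = %.6f%n", step, finalRate);
    }
}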