Java tutorial: training an MP-Boost classifier with sparkboost

The listing below is the complete command-line learner from the sparkboost package. It parses the command-line options, configures a Spark context, trains an MP-Boost model for the requested number of boosting iterations, and saves the resulting classifier to disk.
/*
 * ****************
 * This file is part of sparkboost software package (https://github.com/tizfa/sparkboost).
 *
 * Copyright 2016 Tiziano Fagni (tiziano.fagni@isti.cnr.it)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ******************
 */

package it.tizianofagni.sparkboost;

import org.apache.commons.cli.*;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

/**
 * Command-line entry point that trains an MP-Boost classifier with Spark.
 *
 * @author Tiziano Fagni (tiziano.fagni@isti.cnr.it)
 */
public class MPBoostLearnerExe {

    public static void main(String[] args) {

        // Declare the available command-line options.
        Options options = new Options();
        options.addOption("b", "binaryProblem", false, "Indicate if the input dataset contains a binary problem and not a multilabel one");
        options.addOption("z", "labels0based", false, "Indicate if the label IDs in the dataset are already assigned in the range [0, numLabels-1] inclusive");
        options.addOption("l", "enableSparkLogging", false, "Enable logging messages of Spark");
        options.addOption("w", "windowsLocalModeFix", true, "Set the directory containing the winutils.exe command");
        options.addOption("dp", "documentPartitions", true, "The number of document partitions");
        options.addOption("fp", "featurePartitions", true, "The number of feature partitions");
        options.addOption("lp", "labelPartitions", true, "The number of label partitions");

        // Parse the command line; the three positional arguments are mandatory.
        CommandLineParser parser = new BasicParser();
        CommandLine cmd = null;
        String[] remainingArgs = null;
        try {
            cmd = parser.parse(options, args);
            remainingArgs = cmd.getArgs();
            if (remainingArgs.length != 3)
                throw new ParseException("You need to specify all mandatory parameters");
        } catch (ParseException e) {
            System.out.println("Parsing failed. Reason: " + e.getMessage());
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(MPBoostLearnerExe.class.getSimpleName() + " [OPTIONS] <inputFile> <outputFile> <numIterations>", options);
            System.exit(-1);
        }

        boolean binaryProblem = cmd.hasOption("b");
        boolean labels0Based = cmd.hasOption("z");
        boolean enablingSparkLogging = cmd.hasOption("l");

        // On Windows in local mode, Hadoop needs a directory containing winutils.exe.
        if (cmd.hasOption("w")) {
            System.setProperty("hadoop.home.dir", cmd.getOptionValue("w"));
        }

        String inputFile = remainingArgs[0];
        String outputFile = remainingArgs[1];
        int numIterations = Integer.parseInt(remainingArgs[2]);

        long startTime = System.currentTimeMillis();

        // Disable Spark logging unless explicitly requested.
        if (!enablingSparkLogging) {
            Logger.getLogger("org").setLevel(Level.OFF);
            Logger.getLogger("akka").setLevel(Level.OFF);
        }

        // Create and configure the Spark context.
        SparkConf conf = new SparkConf().setAppName("Spark MPBoost learner");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Create and configure the learner.
        MpBoostLearner learner = new MpBoostLearner(sc);
        learner.setNumIterations(numIterations);
        if (cmd.hasOption("dp")) {
            learner.setNumDocumentsPartitions(Integer.parseInt(cmd.getOptionValue("dp")));
        }
        if (cmd.hasOption("fp")) {
            learner.setNumFeaturesPartitions(Integer.parseInt(cmd.getOptionValue("fp")));
        }
        if (cmd.hasOption("lp")) {
            learner.setNumLabelsPartitions(Integer.parseInt(cmd.getOptionValue("lp")));
        }

        // Build the classifier with the MP-Boost learner.
        BoostClassifier classifier = learner.buildModel(inputFile, labels0Based, binaryProblem);

        // Save the classifier to disk.
        DataUtils.saveModel(sc, classifier, outputFile);

        long endTime = System.currentTimeMillis();
        System.out.println("Execution time: " + (endTime - startTime) + " milliseconds.");

        // Release Spark resources.
        sc.close();
    }
}
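For reference, a driver like this is normally launched through spark-submit. The following invocation is only a sketch: the JAR name, master URL, file paths, and parameter values are placeholder assumptions, not values taken from the project. The three positional arguments map to <inputFile> <outputFile> <numIterations>, and the -dp/-fp/-lp options set the document, feature, and label partition counts defined above.

    spark-submit --class it.tizianofagni.sparkboost.MPBoostLearnerExe \
        --master spark://master-host:7077 \
        sparkboost.jar \
        -dp 50 -fp 50 -lp 4 \
        training-data.libsvm model.bin 200

Omitting -dp, -fp, and -lp leaves the learner's default partitioning in place, since the code only overrides those settings when the corresponding option is present.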