Java tutorial: stratified cross-fold validation with Spark MLlib
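The CrossFoldValidation driver below performs stratified 10-fold cross-validation over a document corpus with Spark MLlib: it cleans the corpus, splits each label's documents into ten random folds, recombines the folds so that every partition preserves the corpus-wide label proportions, and then trains and evaluates the configured classifier once per fold.

Everything is wired together through a runtime.properties file loaded from the working directory. A minimal sketch of such a file, with hypothetical paths and with the plug-in classes taken from the program's own imports (sourceClass names the DocumentProvider implementation, ClassificationModel the DocumentClassifier), might look like:

    # Hypothetical example; adjust paths and classes to your deployment
    sourceClass=com.imolinfo.plug.impl.DocumentPersistProvider
    splitDatasetPath=/data/corpora
    corpus=mycorpus
    ClassificationModel=com.imolinfo.plug.clm.SVMOneVsAll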
package com.imolinfo.offline;

import com.imolinfo.model.Document;
import com.imolinfo.plug.iface.DocumentClassifier;
import com.imolinfo.plug.iface.DocumentProvider;
import com.imolinfo.plug.iface.DocumentToLabeledPoint;
import com.imolinfo.plug.impl.DocumentStandardCleaner;
import com.imolinfo.plug.impl.DocumentToTFIDFLabeledPoint;
import com.imolinfo.util.GlobalVariable;
import com.imolinfo.util.TestUtils;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;

/**
 * Stratified 10-fold cross-validation driver: builds label-balanced folds of
 * the corpus, then trains and evaluates the configured classifier on each.
 *
 * @author renzo
 */
public class CrossFoldValidation {

    public static void main(String[] args) throws IOException, ClassNotFoundException,
            InstantiationException, IllegalAccessException {
        Properties p = new Properties();
        p.load(new FileInputStream("runtime.properties"));
        GlobalVariable.getInstance().setProperties(p);

        // Instantiate the configured classifier once, only to label the Spark application
        String classificatorName = Class.forName(p.getProperty("ClassificationModel"))
                .newInstance().toString();
        SparkConf conf = new SparkConf().setAppName("CrossFoldValidation: " + classificatorName);
        final JavaSparkContext jsc = new JavaSparkContext(conf);
        invokePipeline(jsc);
    }

    public static void invokePipeline(JavaSparkContext jsc) throws IOException,
            ClassNotFoundException, InstantiationException, IllegalAccessException {
        JavaRDD<Document> trainingSet, testSet;
        Logger.getLogger("org").setLevel(Level.OFF);
        Logger.getLogger("akka").setLevel(Level.OFF);

        // Load the corpus through the configured DocumentProvider and clean it
        Properties prop = GlobalVariable.getInstance().getProperties();
        DocumentProvider tp = (DocumentProvider) Class.forName(prop.getProperty("sourceClass")).newInstance();
        JavaRDD<Document> corpus = tp.getTextFromDs(jsc,
                prop.getProperty("splitDatasetPath") + "/" + prop.getProperty("corpus"));
        DocumentStandardCleaner tc = new DocumentStandardCleaner();
        corpus = tc.cleanData(corpus);
        corpus.cache();

        // Split each label's documents into ten equally weighted random folds
        ArrayList<JavaRDD<Document>[]> labelPortionSplits = new ArrayList<JavaRDD<Document>[]>();
        for (final String label : GlobalVariable.getInstance().getIntLabelMap().values()) {
            JavaRDD<Document> labelPortion = corpus.filter(new Function<Document, Boolean>() {
                @Override
                public Boolean call(Document doc) throws Exception {
                    return doc.getLabel().equals(label);
                }
            });
            labelPortion.cache();
            JavaRDD<Document>[] labelPortionSplit = labelPortion.randomSplit(
                    new double[]{0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1});
            labelPortionSplits.add(labelPortionSplit);
        }
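        /*
         * Each entry of labelPortionSplits now holds the ten random folds of a
         * single label's documents. Taking the i-th fold of every label and
         * unioning them (below) produces ten stratified partitions: each one
         * preserves, in expectation, the label proportions of the whole
         * corpus, so every test fold contains examples of every class.
         */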
        // Cache all the small per-label datasets, already partitioned
        // according to the fold factor K
        for (JavaRDD<Document>[] line : labelPortionSplits) {
            for (int j = 0; j < line.length; j++) {
                line[j].cache();
            }
        }

        JavaRDD<Document>[] splitDataset = labelPortionSplits.get(0);
        ArrayList<JavaRDD<Document>> otherSplits = new ArrayList<JavaRDD<Document>>();
        String result = "";

        // Combine the per-label folds so that every partition is stratified
        for (int i = 0; i < splitDataset.length; i++) {
            otherSplits.clear();
            for (int j = 1; j < labelPortionSplits.size(); j++) {
                JavaRDD<Document>[] target = labelPortionSplits.get(j);
                target[i].cache();
                otherSplits.add(target[i]);
            }
            splitDataset[i] = jsc.union(splitDataset[i], otherSplits);
            splitDataset[i].cache();
        }

        // Rotate over the stratified partitions: each one serves in turn as
        // the test set, with the rest of the corpus as the training set
        List<JavaRDD<Document>> splitDocuments = Arrays.asList(splitDataset);
        for (int i = 0; i < splitDocuments.size(); i++) {
            testSet = splitDocuments.get(i);
            trainingSet = corpus.subtract(testSet);
            result = result + "\n" + trainAndTest(jsc, trainingSet, testSet);
        }
        System.out.println(result);
    }

    private static String trainAndTest(JavaSparkContext jsc, JavaRDD<Document> trainingSet,
            JavaRDD<Document> testSet) throws ClassNotFoundException, InstantiationException,
            IllegalAccessException {
        Properties prop = GlobalVariable.getInstance().getProperties();

        // Vectorize the training set with TF-IDF and train the configured classifier
        DocumentToLabeledPoint tl = new DocumentToTFIDFLabeledPoint();
        trainingSet = tl.vectorize(trainingSet);
        JavaRDD<LabeledPoint> features = tl.convert(trainingSet);
        DocumentClassifier d = (DocumentClassifier) Class.forName(
                prop.getProperty("ClassificationModel")).newInstance();
        d.train(jsc, features);

        // Vectorize the test set with the IDF model fitted on the training
        // set, then evaluate the trained classifier
        DocumentToLabeledPoint tlt = new DocumentToTFIDFLabeledPoint();
        tlt.setIDFModel(tl.getIDFModel());
        JavaRDD<Document> vectorizedTestSet = tlt.vectorize(testSet);
        JavaRDD<LabeledPoint> featureDataTest = tlt.convert(vectorizedTestSet);
        TestUtils.analyze(featureDataTest, d);

        String report = "TRAINING: " + trainingSet.count() + " TEST: " + testSet.count()
                + " (" + d.toString() + ") " + Arrays.asList(TestUtils.getOverallStats())
                + " | " + Arrays.asList(TestUtils.getWstats());
        System.out.println(report);
        return report;
    }
}
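The fold construction is the subtle part, so here is the same technique in isolation: a self-contained toy sketch (hypothetical labels, a local master, and none of the com.imolinfo classes) that splits each label's items into K folds with randomSplit, unions the i-th pieces into a stratified test fold, and derives the matching training set with subtract. Two caveats the sketch reflects: randomSplit weights are proportions, so fold sizes are only approximate, and subtract removes every occurrence of a matching element, which is why the toy items are made distinct.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class StratifiedFoldsSketch {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf()
                .setAppName("StratifiedFoldsSketch")
                .setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Toy corpus: 80 "sport" items and 20 "finance" items, each distinct
        List<String> data = new ArrayList<String>();
        for (int i = 0; i < 80; i++) data.add("sport-" + i);
        for (int i = 0; i < 20; i++) data.add("finance-" + i);
        JavaRDD<String> corpus = jsc.parallelize(data).cache();

        int k = 5;
        double[] weights = new double[k];
        Arrays.fill(weights, 1.0 / k);

        // Split each label's portion of the corpus into k random folds
        List<JavaRDD<String>[]> perLabelFolds = new ArrayList<JavaRDD<String>[]>();
        for (final String label : Arrays.asList("sport", "finance")) {
            JavaRDD<String> portion = corpus.filter(doc -> doc.startsWith(label));
            perLabelFolds.add(portion.randomSplit(weights));
        }

        // Union the i-th fold of every label: each test fold keeps roughly
        // the corpus-wide 80/20 label proportions
        for (int i = 0; i < k; i++) {
            JavaRDD<String> testFold = perLabelFolds.get(0)[i];
            for (int j = 1; j < perLabelFolds.size(); j++) {
                testFold = testFold.union(perLabelFolds.get(j)[i]);
            }
            JavaRDD<String> trainFold = corpus.subtract(testFold);
            System.out.println("fold " + i + ": test=" + testFold.count()
                    + " train=" + trainFold.count());
        }
        jsc.stop();
    }
}

On a run of this sketch each test fold holds roughly 20 of the 100 items, with about the original 80/20 label mix. Assuming the project jar is built (jar name hypothetical), the real driver above would be launched with something like spark-submit --class com.imolinfo.offline.CrossFoldValidation offline.jar, with runtime.properties present in the submission directory.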