com.imolinfo.offline.CrossFoldValidation.java Source code


Introduction

Here is the source code for com.imolinfo.offline.CrossFoldValidation.java, a Spark driver that evaluates a pluggable document classifier with stratified 10-fold cross-validation.
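
The class splits the corpus per label into ten random portions, unites the i-th portion of every label into the i-th fold (so each fold preserves the corpus-wide label distribution), and then uses each fold in turn as the test set while the rest of the corpus trains the classifier.

Everything pluggable is resolved by reflection from a runtime.properties file in the working directory. Only the four key names below appear in the code; the values are illustrative assumptions (SVMOneVsAll and DocumentPersistProvider are other classes from the same com.imolinfo project, but any DocumentClassifier and DocumentProvider implementations should work):

# runtime.properties - a sketch; the values here are hypothetical
ClassificationModel=com.imolinfo.plug.clm.SVMOneVsAll
sourceClass=com.imolinfo.plug.impl.DocumentPersistProvider
splitDatasetPath=/path/to/datasets
corpus=myCorpus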

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.imolinfo.offline;

import com.imolinfo.model.Document;
import com.imolinfo.plug.iface.DocumentClassifier;
import com.imolinfo.plug.iface.DocumentProvider;
import com.imolinfo.plug.iface.DocumentToLabeledPoint;
import com.imolinfo.plug.impl.DocumentStandardCleaner;
import com.imolinfo.plug.impl.DocumentToTFIDFLabeledPoint;
import com.imolinfo.util.GlobalVariable;
import com.imolinfo.util.TestUtils;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;

/**
 * Stratified 10-fold cross-validation driver: loads a document corpus,
 * builds label-stratified folds, and trains and tests the classifier
 * configured in runtime.properties on each fold in turn.
 *
 * @author renzo
 */
public class CrossFoldValidation {

    public static void main(String[] args)
            throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {

        // Load the runtime configuration; try-with-resources closes the stream.
        Properties p = new Properties();
        try (FileInputStream in = new FileInputStream("runtime.properties")) {
            p.load(in);
        }
        GlobalVariable.getInstance().setProperties(p);
        // Instantiate the configured classifier once, only to label the Spark application.
        String classifierName = Class.forName(p.getProperty("ClassificationModel")).newInstance().toString();

        SparkConf conf = new SparkConf().setAppName("CrossFoldValidation: " + classifierName);
        final JavaSparkContext jsc = new JavaSparkContext(conf);

        invokePipeline(jsc);
        jsc.stop();
    }

    public static void invokePipeline(JavaSparkContext jsc)
            throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {

        JavaRDD<Document> trainingSet, testSet;
        Logger.getLogger("org").setLevel(Level.OFF);
        Logger.getLogger("akka").setLevel(Level.OFF);

        // Load the dataset and partition it into many small pieces.
        Properties prop = GlobalVariable.getInstance().getProperties();

        DocumentProvider tp = (DocumentProvider) Class.forName(prop.getProperty("sourceClass")).newInstance();
        JavaRDD<Document> corpus = tp.getTextFromDs(jsc,
                prop.getProperty("splitDatasetPath") + "/" + prop.getProperty("corpus"));
        DocumentStandardCleaner tc = new DocumentStandardCleaner();
        corpus = tc.cleanData(corpus);
        corpus.cache();
        // Split the corpus by label so that the folds can be stratified.
        ArrayList<JavaRDD<Document>[]> labelPortionSplits = new ArrayList<JavaRDD<Document>[]>();
        for (final String label : GlobalVariable.getInstance().getIntLabelMap().values()) {
            JavaRDD<Document> labelPortion = corpus.filter(new Function<Document, Boolean>() {

                @Override
                public Boolean call(Document arg0) throws Exception {
                    return arg0.getLabel().equals(label);
                }
            });
            labelPortion.cache();
            // Ten equal weights give the ten folds of the cross-validation.
            JavaRDD<Document>[] labelPortionSplit = labelPortion
                    .randomSplit(new double[] { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1 });
            labelPortionSplits.add(labelPortionSplit);
        }

        // Cache all the small per-label datasets, each already partitioned
        // according to the K factor.
        for (JavaRDD<Document>[] line : labelPortionSplits) {
            for (int j = 0; j < line.length; j++) {
                line[j].cache();
            }
        }
        JavaRDD<Document>[] splitDataset = labelPortionSplits.get(0);
        ArrayList<JavaRDD<Document>> otherSplits = new ArrayList<JavaRDD<Document>>();

        String result = "";
        // Combine the per-label splits so that every partition is stratified
        // (its label mix matches the label mix of the whole corpus).
        for (int i = 0; i < splitDataset.length; i++) {
            otherSplits.clear();
            for (int j = 1; j < labelPortionSplits.size(); j++) {
                JavaRDD<Document>[] target = labelPortionSplits.get(j);
                target[i].cache();
                otherSplits.add(target[i]);
            }
            splitDataset[i] = jsc.union(splitDataset[i], otherSplits);
            splitDataset[i].cache();
        }
        // Use each stratified partition in turn as the test set and the
        // rest of the corpus as the training set.
        List<JavaRDD<Document>> splitDocuments = Arrays.asList(splitDataset);
        for (int i = 0; i < splitDocuments.size(); i++) {
            testSet = splitDocuments.get(i);
            trainingSet = corpus.subtract(testSet);
            result = result + "\n" + trainAndTest(jsc, trainingSet, testSet);
        }
        System.out.println(result);

    }

    private static String trainAndTest(JavaSparkContext jsc, JavaRDD<Document> trainingSet,
            JavaRDD<Document> testSet)
            throws ClassNotFoundException, InstantiationException, IllegalAccessException {

        Properties prop = GlobalVariable.getInstance().getProperties();
        // Fit a TF-IDF model on the training set and turn it into labeled points.
        DocumentToLabeledPoint tl = new DocumentToTFIDFLabeledPoint();
        trainingSet = tl.vectorize(trainingSet);
        JavaRDD<LabeledPoint> features = tl.convert(trainingSet);

        // Instantiate and train the classifier named in runtime.properties.
        DocumentClassifier d = (DocumentClassifier) Class.forName(prop.getProperty("ClassificationModel"))
                .newInstance();
        d.train(jsc, features);

        // Vectorize the test set with the IDF model fitted on the training set,
        // so no term statistics leak from the test fold into the features.
        DocumentToLabeledPoint tlt = new DocumentToTFIDFLabeledPoint();
        tlt.setIDFModel(tl.getIDFModel());
        JavaRDD<Document> vectorizedTestSet = tlt.vectorize(testSet);
        JavaRDD<LabeledPoint> featureDataTest = tlt.convert(vectorizedTestSet);

        TestUtils.analyze(featureDataTest, d);
        System.out.println("TRAINING: " + trainingSet.count() + " TEST: " + testSet.count() + " (" + d.toString()
                + ") " + Arrays.asList(TestUtils.getOverallStats()) + " | " + Arrays.asList(TestUtils.getWstats()));
        return "TRAINING:" + trainingSet.count() + " TEST:" + testSet.count() + " (" + d.toString() + ") "
                + Arrays.asList(TestUtils.getOverallStats()) + " | " + Arrays.asList(TestUtils.getWstats());
    }

}
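
The interesting part of the listing is the fold construction: randomSplit partitions each label's documents into ten disjoint pieces, and uniting the i-th piece of every label yields a fold whose label mix tracks the whole corpus. Below is a minimal, self-contained sketch of the same technique on a toy RDD of "label:text" strings; the class name, toy data, and fold count are ours, not part of the project above.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class StratifiedFoldsSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("StratifiedFoldsSketch").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Toy corpus: "label:text" records with two labels.
        JavaRDD<String> corpus = jsc.parallelize(Arrays.asList(
                "spam:buy now", "spam:free offer", "ham:meeting at ten",
                "ham:see you soon", "spam:click here", "ham:lunch tomorrow"));
        corpus.cache();

        int k = 3; // number of folds (the listing above uses 10)
        double[] weights = new double[k];
        Arrays.fill(weights, 1.0 / k);

        List<String> labels = Arrays.asList("spam", "ham");
        JavaRDD<String>[] folds = null;
        for (String label : labels) {
            // Split this label's documents into k random portions...
            JavaRDD<String>[] portions = corpus
                    .filter(doc -> doc.startsWith(label + ":"))
                    .randomSplit(weights);
            if (folds == null) {
                folds = portions;
            } else {
                // ...and merge the i-th portion of every label into fold i,
                // so each fold keeps roughly the corpus-wide label mix.
                for (int i = 0; i < k; i++) {
                    folds[i] = folds[i].union(portions[i]);
                }
            }
        }

        // Each fold in turn is the test set; the rest of the corpus trains.
        for (int i = 0; i < k; i++) {
            JavaRDD<String> testSet = folds[i];
            JavaRDD<String> trainingSet = corpus.subtract(testSet);
            System.out.println("fold " + i + ": train=" + trainingSet.count()
                    + " test=" + testSet.count());
        }
        jsc.stop();
    }
}

As in the listing, the training set comes from corpus.subtract(testSet) rather than from uniting the other k-1 folds; the two are effectively equivalent because randomSplit assigns every input element to exactly one of the disjoint splits.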