apps.quantification.LearnQuantificationSVMLight.java Source code

Java tutorial

Introduction

Here is the source code for apps.quantification.LearnQuantificationSVMLight.java

Source

/*
 * This file is part of JaTeCS.
 *
 * JaTeCS is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JaTeCS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with JaTeCS.  If not, see <http://www.gnu.org/licenses/>.
 *
 * The software has been mainly developed by (in alphabetical order):
 * - Andrea Esuli (andrea.esuli@isti.cnr.it)
 * - Tiziano Fagni (tiziano.fagni@isti.cnr.it)
 * - Alejandro Moreo Fernández (alejandro.moreo@isti.cnr.it)
 * Other past contributors were:
 * - Giacomo Berardi (giacomo.berardi@isti.cnr.it)
 */

package apps.quantification;

import gnu.trove.TShortDoubleHashMap;
import it.cnr.jatecs.classification.ClassificationMode;
import it.cnr.jatecs.classification.interfaces.IDataManager;
import it.cnr.jatecs.classification.svmlight.SvmLightClassifierCustomizer;
import it.cnr.jatecs.classification.svmlight.SvmLightDataManager;
import it.cnr.jatecs.classification.svmlight.SvmLightLearner;
import it.cnr.jatecs.classification.svmlight.SvmLightLearnerCustomizer;
import it.cnr.jatecs.evaluation.ContingencyTable;
import it.cnr.jatecs.evaluation.ContingencyTableSet;
import it.cnr.jatecs.indexes.DB.interfaces.IIndex;
import it.cnr.jatecs.indexes.DB.troveCompact.TroveClassificationDBType;
import it.cnr.jatecs.indexes.DB.troveCompact.TroveContentDBType;
import it.cnr.jatecs.indexes.DB.troveCompact.TroveReadWriteHelper;
import it.cnr.jatecs.io.FileSystemStorageManager;
import it.cnr.jatecs.quantification.QuantificationLearner;
import it.cnr.jatecs.quantification.interfaces.IQuantifier;
import it.cnr.jatecs.quantification.scalingFunctions.LogisticFunction;
import it.cnr.jatecs.utils.IOperationStatusListener;
import it.cnr.jatecs.utils.Os;
import it.cnr.jatecs.utils.TextualProgressBar;
import org.apache.commons.cli.*;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;

/**
 * This app learns a pool of quantification models out of a training set using
 * SVM_light as the base learner.
 * <p>
 * The SVM_light learn executable must be locally available.
 *
 * @author Andrea Esuli
 */
public class LearnQuantificationSVMLight {

    public static void main(String[] args) throws IOException {
        String cmdLineSyntax = LearnQuantificationSVMLight.class.getName()
                + " [OPTIONS] <path to svm_light_learn> <path to svm_light_classify> <trainingIndexDirectory> <outputDirectory>";

        Options options = new Options();

        OptionBuilder.withArgName("f");
        OptionBuilder.withDescription("Number of folds");
        OptionBuilder.withLongOpt("f");
        OptionBuilder.isRequired(true);
        OptionBuilder.hasArg();
        options.addOption(OptionBuilder.create());

        OptionBuilder.withArgName("c");
        OptionBuilder.withDescription("The c value for svm_light (default 1)");
        OptionBuilder.withLongOpt("c");
        OptionBuilder.isRequired(false);
        OptionBuilder.hasArg();
        options.addOption(OptionBuilder.create());

        OptionBuilder.withArgName("k");
        OptionBuilder.withDescription("Kernel type (default 0: linear, 1: polynomial, 2: RBF, 3: sigmoid)");
        OptionBuilder.withLongOpt("k");
        OptionBuilder.isRequired(false);
        OptionBuilder.hasArg();
        options.addOption(OptionBuilder.create());

        OptionBuilder.withArgName("t");
        OptionBuilder.withDescription("Path for temporary files");
        OptionBuilder.withLongOpt("t");
        OptionBuilder.isRequired(false);
        OptionBuilder.hasArg();
        options.addOption(OptionBuilder.create());

        OptionBuilder.withArgName("v");
        OptionBuilder.withDescription("Verbose output");
        OptionBuilder.withLongOpt("v");
        OptionBuilder.isRequired(false);
        OptionBuilder.hasArg(false);
        options.addOption(OptionBuilder.create());

        OptionBuilder.withArgName("s");
        OptionBuilder.withDescription("Don't delete temporary training file in svm_light format (default: delete)");
        OptionBuilder.withLongOpt("s");
        OptionBuilder.isRequired(false);
        OptionBuilder.hasArg(false);
        options.addOption(OptionBuilder.create());

        SvmLightLearnerCustomizer classificationLearnerCustomizer = null;
        SvmLightClassifierCustomizer classificationCustomizer = null;

        int folds = -1;

        GnuParser parser = new GnuParser();
        String[] remainingArgs = null;
        try {
            CommandLine line = parser.parse(options, args);

            remainingArgs = line.getArgs();

            classificationLearnerCustomizer = new SvmLightLearnerCustomizer(remainingArgs[0]);
            classificationCustomizer = new SvmLightClassifierCustomizer(remainingArgs[1]);

            folds = Integer.parseInt(line.getOptionValue("f"));

            if (line.hasOption("c"))
                classificationLearnerCustomizer.setC(Float.parseFloat(line.getOptionValue("c")));

            if (line.hasOption("k")) {
                System.out.println("Kernel type: " + line.getOptionValue("k"));
                classificationLearnerCustomizer.setKernelType(Integer.parseInt(line.getOptionValue("k")));
            }

            if (line.hasOption("v"))
                classificationLearnerCustomizer.printSvmLightOutput(true);

            if (line.hasOption("s"))
                classificationLearnerCustomizer.setDeleteTrainingFiles(false);

            if (line.hasOption("t")) {
                classificationLearnerCustomizer.setTempPath(line.getOptionValue("t"));
                classificationCustomizer.setTempPath(line.getOptionValue("t"));
            }

        } catch (Exception exp) {
            System.err.println("Parsing failed.  Reason: " + exp.getMessage());
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(cmdLineSyntax, options);
            System.exit(-1);
        }

        assert (classificationLearnerCustomizer != null);

        if (remainingArgs.length != 4) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(cmdLineSyntax, options);
            System.exit(-1);
        }

        String indexFile = remainingArgs[2];

        File file = new File(indexFile);

        String indexName = file.getName();
        String indexPath = file.getParent();

        String outputPath = remainingArgs[3];

        SvmLightLearner classificationLearner = new SvmLightLearner();

        classificationLearner.setRuntimeCustomizer(classificationLearnerCustomizer);

        FileSystemStorageManager fssm = new FileSystemStorageManager(indexPath, false);
        fssm.open();

        IIndex training = TroveReadWriteHelper.readIndex(fssm, indexName, TroveContentDBType.Full,
                TroveClassificationDBType.Full);

        final TextualProgressBar progressBar = new TextualProgressBar("Learning the quantifiers");

        IOperationStatusListener status = new IOperationStatusListener() {

            @Override
            public void operationStatus(double percentage) {
                progressBar.signal((int) percentage);
            }
        };

        QuantificationLearner quantificationLearner = new QuantificationLearner(folds, classificationLearner,
                classificationLearnerCustomizer, classificationCustomizer, ClassificationMode.PER_CATEGORY,
                new LogisticFunction(), status);

        IQuantifier[] quantifiers = quantificationLearner.learn(training);

        File executableFile = new File(classificationLearnerCustomizer.getSvmLightLearnPath());
        IDataManager classifierDataManager = new SvmLightDataManager(new SvmLightClassifierCustomizer(
                executableFile.getParentFile().getAbsolutePath() + Os.pathSeparator() + "svm_light_classify"));
        String description = "_SVMLight_C-" + classificationLearnerCustomizer.getC() + "_K-"
                + classificationLearnerCustomizer.getKernelType();
        if (classificationLearnerCustomizer.getAdditionalParameters().length() > 0)
            description += "_" + classificationLearnerCustomizer.getAdditionalParameters();
        String quantifierPrefix = indexName + "_Quantifier-" + folds + description;

        FileSystemStorageManager fssmo = new FileSystemStorageManager(
                outputPath + File.separatorChar + quantifierPrefix, true);
        fssmo.open();
        QuantificationLearner.write(quantifiers, fssmo, classifierDataManager);
        fssmo.close();

        BufferedWriter bfs = new BufferedWriter(
                new FileWriter(outputPath + File.separatorChar + quantifierPrefix + "_rates.txt"));
        TShortDoubleHashMap simpleTPRs = quantificationLearner.getSimpleTPRs();
        TShortDoubleHashMap simpleFPRs = quantificationLearner.getSimpleFPRs();
        TShortDoubleHashMap scaledTPRs = quantificationLearner.getScaledTPRs();
        TShortDoubleHashMap scaledFPRs = quantificationLearner.getScaledFPRs();

        ContingencyTableSet contingencyTableSet = quantificationLearner.getContingencyTableSet();

        short[] cats = simpleTPRs.keys();
        for (int i = 0; i < cats.length; ++i) {
            short cat = cats[i];
            String catName = training.getCategoryDB().getCategoryName(cat);
            ContingencyTable contingencyTable = contingencyTableSet.getCategoryContingencyTable(cat);
            double simpleTPR = simpleTPRs.get(cat);
            double simpleFPR = simpleFPRs.get(cat);
            double scaledTPR = scaledTPRs.get(cat);
            double scaledFPR = scaledFPRs.get(cat);
            String line = quantifierPrefix + "\ttrain\tsimple\t" + catName + "\t" + cat + "\t"
                    + contingencyTable.tp() + "\t" + contingencyTable.fp() + "\t" + contingencyTable.fn() + "\t"
                    + contingencyTable.tn() + "\t" + simpleTPR + "\t" + simpleFPR + "\n";
            bfs.write(line);
            line = quantifierPrefix + "\ttrain\tscaled\t" + catName + "\t" + cat + "\t" + contingencyTable.tp()
                    + "\t" + contingencyTable.fp() + "\t" + contingencyTable.fn() + "\t" + contingencyTable.tn()
                    + "\t" + scaledTPR + "\t" + scaledFPR + "\n";
            bfs.write(line);
        }
        bfs.close();
    }
}