pl.edu.icm.cermine.libsvm.training.SVMBodyBuilder.java Source code

Java tutorial

Introduction

Here is the source code for pl.edu.icm.cermine.libsvm.training.SVMBodyBuilder.java

Source

/**
 * This file is part of CERMINE project.
 * Copyright (c) 2011-2013 ICM-UW
 *
 * CERMINE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CERMINE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with CERMINE. If not, see <http://www.gnu.org/licenses/>.
 */

package pl.edu.icm.cermine.libsvm.training;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import libsvm.svm_parameter;
import org.apache.commons.cli.*;
import pl.edu.icm.cermine.content.filtering.ContentFilterTools;
import pl.edu.icm.cermine.exception.AnalysisException;
import pl.edu.icm.cermine.structure.model.BxPage;
import pl.edu.icm.cermine.structure.model.BxZone;
import pl.edu.icm.cermine.structure.model.BxZoneLabel;
import pl.edu.icm.cermine.structure.model.BxZoneLabelCategory;
import pl.edu.icm.cermine.tools.BxDocUtils.DocumentsIterator;
import pl.edu.icm.cermine.tools.classification.general.*;
import pl.edu.icm.cermine.tools.classification.svm.SVMZoneClassifier;

/**
 * Command-line tool that trains an SVM zone classifier on zones of the
 * {@link BxZoneLabelCategory#CAT_BODY} category and saves the resulting model.
 *
 * <p>Two modes are supported:
 * <ul>
 *   <li>default: a single model is trained from {@code -input} and saved to {@code -output};</li>
 *   <li>{@code -cross}: {@code -input} must contain the entries {@code 0} .. {@code 4}; for each
 *       fold {@code i} a model is trained on all other folds and saved to
 *       {@code <output>-<i>}.</li>
 * </ul>
 */
public class SVMBodyBuilder {

    /** Number of fold entries ({@code 0}, {@code 1}, ...) read from the input path in cross mode. */
    private static final int FOLD_COUNT = 5;

    /** Degree value handed to the classifier when the {@code -degree} option is not supplied. */
    private static final int UNSPECIFIED_DEGREE = -1;

    /** Fixed path to which {@link #getZoneClassifier} always writes the trained model. */
    private static final String DEFAULT_MODEL_PATH = "svm_body_classifier";

    private static final String USAGE =
            "Usage: SVMBodyBuilder [-kernel <kernel type>] [-degree <degree>] [-g <gamma>] "
            + "[-C <error cost>] [-ext <extension>] [-cross] -input <input dir> -output <path>";

    /**
     * Trains a C-SVC zone classifier on the body-category subset of the given samples.
     *
     * <p>Per-class penalty weights are taken from a {@link PenaltyCalculator} built over the
     * filtered samples. Besides returning the classifier, this method prints the sample count,
     * the degree and the classifier weights to standard output, and writes the model to the
     * fixed path {@code svm_body_classifier} in the working directory.
     *
     * @param trainingSamples samples to train on; elements outside the body category are dropped
     * @param kernelType one of the {@code svm_parameter} kernel constants
     * @param gamma kernel gamma parameter
     * @param C error cost parameter
     * @param degree polynomial kernel degree
     * @return the trained classifier
     * @throws IOException declared for the underlying classifier I/O
     */
    protected static SVMZoneClassifier getZoneClassifier(List<TrainingSample<BxZoneLabel>> trainingSamples,
            int kernelType, double gamma, double C, int degree) throws IOException {
        List<TrainingSample<BxZoneLabel>> bodySamples =
                ClassificationUtils.filterElements(trainingSamples, BxZoneLabelCategory.CAT_BODY);
        System.out.println("Training samples " + bodySamples.size());

        PenaltyCalculator pc = new PenaltyCalculator(bodySamples);
        int classCount = pc.getClasses().size();
        int[] intClasses = new int[classCount];
        double[] classesWeights = new double[classCount];

        // weight_label[i] / weight[i] are parallel arrays: label ordinal and its penalty weight.
        int labelIdx = 0;
        for (BxZoneLabel label : pc.getClasses()) {
            intClasses[labelIdx] = label.ordinal();
            classesWeights[labelIdx] = pc.getPenaltyWeigth(label);
            ++labelIdx;
        }

        FeatureVectorBuilder<BxZone, BxPage> featureVectorBuilder = ContentFilterTools.VECTOR_BUILDER;
        SVMZoneClassifier zoneClassifier = new SVMZoneClassifier(featureVectorBuilder);
        svm_parameter param = SVMZoneClassifier.getDefaultParam();
        param.svm_type = svm_parameter.C_SVC;
        param.gamma = gamma;
        param.C = C;
        System.out.println(degree);
        param.degree = degree;
        param.kernel_type = kernelType;
        param.weight = classesWeights;
        param.weight_label = intClasses;

        zoneClassifier.setParameter(param);
        zoneClassifier.buildClassifier(bodySamples);
        zoneClassifier.printWeigths(featureVectorBuilder);
        // NOTE(review): this fixed-path save happens on every call (so it is overwritten for each
        // fold in cross mode) in addition to the save performed by main(); kept for compatibility.
        zoneClassifier.saveModel(DEFAULT_MODEL_PATH);
        return zoneClassifier;
    }

    /**
     * Entry point. Parses the command line, loads the training samples and trains either a
     * single model or one model per fold (see the class description).
     *
     * <p>Exits with status 1 when {@code -input} or {@code -output} is missing, or when the
     * polynomial kernel is selected without {@code -degree}.
     *
     * @param args command-line arguments
     * @throws ParseException if the command line cannot be parsed
     * @throws AnalysisException declared for the sample loading step
     * @throws IOException declared for sample loading and model saving
     * @throws IllegalArgumentException if {@code -kernel} is not one of 0, 1, 2, 3
     */
    public static void main(String[] args) throws ParseException, AnalysisException, IOException {
        Options options = new Options();
        options.addOption("input", true, "input path");
        options.addOption("output", true, "output model path");
        options.addOption("kernel", true, "kernel type");
        options.addOption("g", true, "gamma");
        options.addOption("C", true, "C");
        options.addOption("degree", true, "degree");
        options.addOption("cross", false, "train one model per fold, leaving that fold out");
        options.addOption("ext", true, "input file extension");

        CommandLineParser parser = new GnuParser();
        CommandLine line = parser.parse(options, args);
        if (!line.hasOption("input") || !line.hasOption("output")) {
            System.err.println(USAGE);
            System.exit(1);
        }

        double C = 8.0;
        if (line.hasOption("C")) {
            C = Double.parseDouble(line.getOptionValue("C"));
        }
        double gamma = 0.5;
        if (line.hasOption("g")) {
            gamma = Double.parseDouble(line.getOptionValue("g"));
        }
        String inDir = line.getOptionValue("input");
        String outFile = line.getOptionValue("output");

        String degreeStr = line.getOptionValue("degree");
        boolean degreeSpecified = degreeStr != null && !degreeStr.isEmpty();
        int degree = degreeSpecified ? Integer.parseInt(degreeStr) : UNSPECIFIED_DEGREE;

        int kernelType = svm_parameter.RBF;
        if (line.hasOption("kernel")) {
            kernelType = parseKernelType(line.getOptionValue("kernel"));
        }
        // The degree defaults to a sentinel rather than null, so the presence of the option
        // itself has to be checked to detect a polynomial kernel without a degree.
        if (kernelType == svm_parameter.POLY && !degreeSpecified) {
            System.err.println("Polynomial kernel requires the -degree option to be specified");
            System.exit(1);
        }

        String ext = "cxml";
        if (line.hasOption("ext")) {
            ext = line.getOptionValue("ext");
        }

        Map<BxZoneLabel, BxZoneLabel> labelMap = buildLabelMap();

        if (!line.hasOption("cross")) {
            List<TrainingSample<BxZoneLabel>> trainingSamples = loadBodySamples(inDir, ext, labelMap);
            SVMZoneClassifier classifier = getZoneClassifier(trainingSamples, kernelType, gamma, C, degree);
            classifier.saveModel(outFile);
        } else {
            List<List<TrainingSample<BxZoneLabel>>> folds =
                    new ArrayList<List<TrainingSample<BxZoneLabel>>>(FOLD_COUNT);
            for (int i = 0; i < FOLD_COUNT; i++) {
                folds.add(loadBodySamples(inDir + "/" + i, ext, labelMap));
            }

            for (int i = 0; i < FOLD_COUNT; i++) {
                // Train on every fold except fold i.
                List<TrainingSample<BxZoneLabel>> trainingSamples = new ArrayList<TrainingSample<BxZoneLabel>>();
                for (int j = 0; j < FOLD_COUNT; j++) {
                    if (i != j) {
                        trainingSamples.addAll(folds.get(j));
                    }
                }

                SVMZoneClassifier classifier = getZoneClassifier(trainingSamples, kernelType, gamma, C, degree);
                classifier.saveModel(outFile + "-" + i);
            }
        }
    }

    /** Translates the numeric {@code -kernel} option value (0-3) into an {@code svm_parameter} kernel constant. */
    private static int parseKernelType(String value) {
        switch (Integer.parseInt(value)) {
        case 0:
            return svm_parameter.LINEAR;
        case 1:
            return svm_parameter.POLY;
        case 2:
            return svm_parameter.RBF;
        case 3:
            return svm_parameter.SIGMOID;
        default:
            throw new IllegalArgumentException("Invalid kernel value provided");
        }
    }

    /** Builds the label mapping that collapses specific body labels into BODY_CONTENT or BODY_JUNK. */
    private static Map<BxZoneLabel, BxZoneLabel> buildLabelMap() {
        Map<BxZoneLabel, BxZoneLabel> map = new EnumMap<BxZoneLabel, BxZoneLabel>(BxZoneLabel.class);
        map.put(BxZoneLabel.BODY_ACKNOWLEDGMENT, BxZoneLabel.BODY_CONTENT);
        map.put(BxZoneLabel.BODY_CONFLICT_STMT, BxZoneLabel.BODY_CONTENT);
        map.put(BxZoneLabel.BODY_EQUATION, BxZoneLabel.BODY_JUNK);
        map.put(BxZoneLabel.BODY_FIGURE, BxZoneLabel.BODY_JUNK);
        map.put(BxZoneLabel.BODY_GLOSSARY, BxZoneLabel.BODY_JUNK);
        map.put(BxZoneLabel.BODY_TABLE, BxZoneLabel.BODY_JUNK);
        return map;
    }

    /**
     * Loads training samples from {@code path} and keeps only body-category elements.
     * A directory is read as documents with the given extension (using the label map);
     * anything else is loaded through {@code SVMZoneClassifier.loadProblem}.
     */
    private static List<TrainingSample<BxZoneLabel>> loadBodySamples(String path, String ext,
            Map<BxZoneLabel, BxZoneLabel> labelMap) throws AnalysisException, IOException {
        FeatureVectorBuilder<BxZone, BxPage> featureVectorBuilder = ContentFilterTools.VECTOR_BUILDER;
        List<TrainingSample<BxZoneLabel>> samples;
        if (new File(path).isDirectory()) {
            DocumentsIterator it = new DocumentsIterator(path, ext);
            samples = BxDocsToTrainingSamplesConverter.getZoneTrainingSamples(it.iterator(),
                    featureVectorBuilder, labelMap);
        } else {
            samples = SVMZoneClassifier.loadProblem(path, featureVectorBuilder);
        }
        return ClassificationUtils.filterElements(samples, BxZoneLabelCategory.CAT_BODY);
    }

}