etc.aloe.cscw2013.TrainingImpl.java Source code

Java tutorial

Introduction

Here is the source code for etc.aloe.cscw2013.TrainingImpl.java

Source

/*
 * This file is part of ALOE.
 *
 * ALOE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * ALOE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
 */
package etc.aloe.cscw2013;

import etc.aloe.data.ExampleSet;
import etc.aloe.processes.Training;
import weka.classifiers.Classifier;
import weka.classifiers.CostMatrix;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.CostSensitiveClassifier;
import weka.core.Utils;

/**
 * Trains a Weka {@link SMO} support vector machine on an example set and
 * returns it wrapped in a {@link WekaModel}.
 *
 * <p>The SMO is always configured from the fixed {@code SMO_OPTIONS} string
 * (a {@code PolyKernel} with exponent 1.0). When cost training is enabled the
 * SMO is additionally wrapped in a {@link CostSensitiveClassifier} that uses a
 * 2x2 cost matrix built from the configured false positive and false negative
 * costs.
 *
 * <p>Progress is reported on standard output and failures on standard error;
 * {@link #train(ExampleSet)} returns {@code null} rather than throwing when
 * configuration or training fails.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class TrainingImpl implements Training {

    /** Option string handed to {@link SMO#setOptions(String[])} for every training run. */
    private static final String SMO_OPTIONS = "-C 1.0 -L 0.0010 -P 1.0E-12 -N 0 -V -1 -W 1 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\"";

    /** Forwarded to {@link SMO#setBuildLogisticModels(boolean)}. */
    private boolean buildLogisticModel = false;

    /** Cost stored in cell (0, 1) of the cost matrix. */
    private double falsePositiveCost = 1;

    /** Cost stored in cell (1, 0) of the cost matrix. */
    private double falseNegativeCost = 1;

    /**
     * When cost training is on: {@code true} selects re-weighting,
     * {@code false} selects the minimum expected cost criterion.
     */
    private boolean useReweighting = false;

    /** Whether the SMO is wrapped in a {@link CostSensitiveClassifier}. */
    private boolean useCostTraining = false;

    /**
     * Creates a trainer with cost training disabled and unit costs.
     */
    public TrainingImpl() {
    }

    /**
     * Creates a trainer with cost training enabled.
     *
     * @param falsePositiveCost cost placed in cell (0, 1) of the cost matrix
     * @param falseNegativeCost cost placed in cell (1, 0) of the cost matrix
     * @param useReweighting    {@code true} to use re-weighting, {@code false}
     *                          to use the minimum expected cost criterion
     */
    public TrainingImpl(double falsePositiveCost, double falseNegativeCost, boolean useReweighting) {
        this.falsePositiveCost = falsePositiveCost;
        this.falseNegativeCost = falseNegativeCost;
        this.useReweighting = useReweighting;
        this.useCostTraining = true;
    }

    /**
     * @return whether the SMO is asked to build logistic models
     */
    public boolean isBuildLogisticModel() {
        return buildLogisticModel;
    }

    /**
     * @param buildLogisticModel whether the SMO should build logistic models
     */
    public void setBuildLogisticModel(boolean buildLogisticModel) {
        this.buildLogisticModel = buildLogisticModel;
    }

    /**
     * @return the cost used for cell (0, 1) of the cost matrix
     */
    public double getFalsePositiveCost() {
        return falsePositiveCost;
    }

    /**
     * Sets the false positive cost. This does not by itself enable cost
     * training; see {@link #setUseCostTraining(boolean)}.
     *
     * @param falsePositiveCost the cost for cell (0, 1) of the cost matrix
     */
    public void setFalsePositiveCost(double falsePositiveCost) {
        this.falsePositiveCost = falsePositiveCost;
    }

    /**
     * @return the cost used for cell (1, 0) of the cost matrix
     */
    public double getFalseNegativeCost() {
        return falseNegativeCost;
    }

    /**
     * Sets the false negative cost. This does not by itself enable cost
     * training; see {@link #setUseCostTraining(boolean)}.
     *
     * @param falseNegativeCost the cost for cell (1, 0) of the cost matrix
     */
    public void setFalseNegativeCost(double falseNegativeCost) {
        this.falseNegativeCost = falseNegativeCost;
    }

    /**
     * @return {@code true} if cost training uses re-weighting rather than the
     *         minimum expected cost criterion
     */
    public boolean isUseReweighting() {
        return useReweighting;
    }

    /**
     * @param useReweighting {@code true} for re-weighting, {@code false} for
     *                       the minimum expected cost criterion
     */
    public void setUseReweighting(boolean useReweighting) {
        this.useReweighting = useReweighting;
    }

    /**
     * @return whether the SMO will be wrapped in a cost-sensitive classifier
     */
    public boolean isUseCostTraining() {
        return useCostTraining;
    }

    /**
     * @param useCostTraining whether to wrap the SMO in a cost-sensitive classifier
     */
    public void setUseCostTraining(boolean useCostTraining) {
        this.useCostTraining = useCostTraining;
    }

    /**
     * Configures an SMO (optionally cost-sensitive) and builds it on the given
     * examples.
     *
     * @param examples the examples to train on
     * @return the trained model, or {@code null} if the SMO could not be
     *         configured or training threw an exception
     */
    @Override
    public WekaModel train(ExampleSet examples) {
        System.out.println("SMO Options: " + SMO_OPTIONS);

        final SMO svm = newConfiguredSmo();
        if (svm == null) {
            // Configuration failure has already been reported on stderr.
            return null;
        }

        final Classifier learner = useCostTraining ? wrapWithCosts(svm) : svm;

        try {
            System.out.print("Training SMO on " + examples.size() + " examples... ");
            learner.buildClassifier(examples.getInstances());
            System.out.println("done.");

            return new WekaModel(learner);
        } catch (Exception failure) {
            System.err.println("Unable to train SMO.");
            System.err.println("\t" + failure.getMessage());
            return null;
        }
    }

    /**
     * Creates an SMO set up from {@code SMO_OPTIONS} and the logistic model flag.
     *
     * @return the configured SMO, or {@code null} if the options were rejected
     */
    private SMO newConfiguredSmo() {
        final SMO svm = new SMO();
        try {
            svm.setOptions(Utils.splitOptions(SMO_OPTIONS));
        } catch (Exception failure) {
            System.err.println("Unable to configure SMO.");
            System.err.println("\t" + failure.getMessage());
            return null;
        }

        // Applied after setOptions so the explicit setting is the one in effect.
        svm.setBuildLogisticModels(isBuildLogisticModel());
        return svm;
    }

    /**
     * Wraps the SMO in a cost-sensitive classifier using the configured costs
     * and reports the chosen strategy on standard output.
     *
     * @param svm the configured base classifier
     * @return the cost-sensitive wrapper around {@code svm}
     */
    private Classifier wrapWithCosts(SMO svm) {
        final CostSensitiveClassifier wrapper = new CostSensitiveClassifier();
        wrapper.setClassifier(svm);

        // Zero cost on the diagonal, configured costs off the diagonal.
        final CostMatrix costs = new CostMatrix(2);
        costs.setElement(0, 0, 0);
        costs.setElement(0, 1, falsePositiveCost);
        costs.setElement(1, 0, falseNegativeCost);
        costs.setElement(1, 1, 0);
        wrapper.setCostMatrix(costs);

        System.out.print("Wrapping SMO in CostSensitiveClassifier " + costs.toMatlab());

        // Re-weighting and minimum expected cost are mutually exclusive modes.
        wrapper.setMinimizeExpectedCost(!useReweighting);
        System.out.println(useReweighting ? " using re-weighting." : " using min-cost criterion.");

        return wrapper;
    }
}