Java tutorial
/* * This file is part of ALOE. * * ALOE is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * ALOE is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * You should have received a copy of the GNU General Public License * along with ALOE. If not, see <http://www.gnu.org/licenses/>. * * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl) */ package etc.aloe.cscw2013; import etc.aloe.data.ExampleSet; import etc.aloe.processes.Training; import weka.classifiers.Classifier; import weka.classifiers.CostMatrix; import weka.classifiers.functions.SMO; import weka.classifiers.meta.CostSensitiveClassifier; import weka.core.Utils; /** * Performs basic training of a linear support vector machine classifier. * * @author Michael Brooks <mjbrooks@uw.edu> */ public class TrainingImpl implements Training { private static final String SMO_OPTIONS = "-C 1.0 -L 0.0010 -P 1.0E-12 -N 0 -V -1 -W 1 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\""; private boolean buildLogisticModel = false; private double falsePositiveCost = 1; private double falseNegativeCost = 1; private boolean useReweighting = false; private boolean useCostTraining = false; public TrainingImpl() { } public TrainingImpl(double falsePositiveCost, double falseNegativeCost, boolean useReweighting) { this.falsePositiveCost = falsePositiveCost; this.falseNegativeCost = falseNegativeCost; this.useReweighting = useReweighting; this.useCostTraining = true; } public boolean isBuildLogisticModel() { return buildLogisticModel; } public void setBuildLogisticModel(boolean buildLogisticModel) { this.buildLogisticModel = buildLogisticModel; } public double getFalsePositiveCost() { return falsePositiveCost; } public void setFalsePositiveCost(double falsePositiveCost) { this.falsePositiveCost = falsePositiveCost; } public double getFalseNegativeCost() { return falseNegativeCost; } public void setFalseNegativeCost(double falseNegativeCost) { this.falseNegativeCost = falseNegativeCost; } public boolean isUseReweighting() { return useReweighting; } public void setUseReweighting(boolean useReweighting) { this.useReweighting = useReweighting; } public boolean isUseCostTraining() { return useCostTraining; } public void setUseCostTraining(boolean useCostTraining) { this.useCostTraining = useCostTraining; } @Override public WekaModel train(ExampleSet examples) { System.out.println("SMO Options: " + SMO_OPTIONS); SMO smo = new SMO(); try { smo.setOptions(Utils.splitOptions(SMO_OPTIONS)); } catch (Exception ex) { System.err.println("Unable to configure SMO."); System.err.println("\t" + ex.getMessage()); return null; } //Build logistic models if desired smo.setBuildLogisticModels(isBuildLogisticModel()); Classifier classifier = smo; if (useCostTraining) { CostSensitiveClassifier cost = new CostSensitiveClassifier(); cost.setClassifier(smo); CostMatrix matrix = new CostMatrix(2); matrix.setElement(0, 0, 0); matrix.setElement(0, 1, falsePositiveCost); matrix.setElement(1, 0, falseNegativeCost); matrix.setElement(1, 1, 0); cost.setCostMatrix(matrix); classifier = cost; System.out.print("Wrapping SMO in CostSensitiveClassifier " + matrix.toMatlab()); if (useReweighting) { cost.setMinimizeExpectedCost(false); System.out.println(" using re-weighting."); } else { cost.setMinimizeExpectedCost(true); System.out.println(" using min-cost criterion."); } } try { System.out.print("Training SMO on " + examples.size() + " examples... "); classifier.buildClassifier(examples.getInstances()); System.out.println("done."); WekaModel model = new WekaModel(classifier); return model; } catch (Exception ex) { System.err.println("Unable to train SMO."); System.err.println("\t" + ex.getMessage()); return null; } } }