Java tutorial
/*
 *    ReductionClassifier.java
 *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package moa.reduction.core;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.WekaToSamoaInstanceConverter;

import moa.classifiers.AbstractClassifier;
import moa.classifiers.bayes.NaiveBayes;
import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import moa.classifiers.functions.SGDMultiClass;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.AutoExpandVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.core.TimingUtils;
import moa.reduction.bayes.IDAdiscretize;
import moa.reduction.bayes.IFFDdiscretize;
import moa.reduction.bayes.IncrInfoThAttributeEval;
import moa.reduction.bayes.LOFDiscretizer;
import moa.reduction.bayes.OCdiscretize;
import moa.reduction.bayes.OFSGDAttributeEval;
import moa.reduction.bayes.PIDdiscretize;

import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.Ranker;
import weka.core.Attribute;

/**
 * Wrapper classifier with several preprocessing methods.
 *
 * <p>Performs classic Bayesian prediction or multinomial SGD-based linear classification.
 *
 * @author Sergio Ramirez (sramirez@decsai.ugr.es)
 * @version $Revision: 2 $
 */
public class ReductionClassifier extends AbstractClassifier {

    private static final long serialVersionUID = 1L;

    @Override
    public String getPurposeString() {
        return "Wrapper classifier with several preprocessing methods: to date, only "
                + "multinomial NB, SGD logistic regression and Hoeffding Trees are considered as base learners.";
    }

    protected static AttributeSelection selector = null;

    protected AutoExpandVector<AttributeClassObserver> attributeObservers;

    public static IntOption numFeaturesOption = new IntOption("numFeatures", 'f',
            "The number of features to select", 10, 1, Integer.MAX_VALUE);

    public static IntOption fsmethodOption = new IntOption("fsMethod", 'm',
            "Information-theoretic method to be used in feature selection: "
                    + "0. No method 1. InfoGain 2. Symmetrical Uncertainty 3. OFSGD", 0, 0, 3);

    public static IntOption discmethodOption = new IntOption("discMethod", 'd',
            "Discretization method to be used: "
                    + "0. No method 1. PiD 2. IFFD 3. Online Chi-Merge 4. IDA 5. RebDiscretize", 5, 0, 5);

    public static IntOption winSizeOption = new IntOption("winSize", 'w',
            "Window size for model updates", 5000, 1, Integer.MAX_VALUE);

    public static IntOption thresholdOption = new IntOption("threshold", 't',
            "Threshold for initialization", 10000, 1, Integer.MAX_VALUE);

    public static IntOption decimalsOption = new IntOption("decimals", 'e',
            "Number of decimals to round", 3, 0, Integer.MAX_VALUE);

    public static IntOption maxLabelsOption = new IntOption("maxLabels", 'l',
            "Number of different labels to use in discretization", 10000, 10, Integer.MAX_VALUE);

    public IntOption numClassesOption = new IntOption("numClasses", 'c',
            "Number of classes for this problem (Online Chi-Merge)", 100, 1, Integer.MAX_VALUE);

    public IntOption baseClassifier = new IntOption("baseClassifier", 'b',
            "Base classifier to be used: 0. NB 1. LR (SGD Multiclass) 2. Hoeffding Tree", 2, 0, 2);

    protected static MOAAttributeEvaluator fselector = null;

    protected static MOADiscretize discretizer = null;

    protected int totalCount = 0, classified = 0, correctlyClassified = 0;

    protected Set<Integer> selectedFeatures = new HashSet<Integer>();

    protected AbstractClassifier wrapperClassifier;

    //private double sumTime, sumTime2;

    public ReductionClassifier() {
        // Instantiate the base learner chosen through the baseClassifier option
        if (baseClassifier.getValue() == 0) {
            wrapperClassifier = new NaiveBayes();
            wrapperClassifier.resetLearningImpl();
        } else if (baseClassifier.getValue() == 1) {
            SGDMultiClass tmp = new SGDMultiClass();
            tmp.setLossFunction(1); // log loss, i.e. multinomial logistic regression
            wrapperClassifier = tmp;
            wrapperClassifier.resetLearningImpl();
        } else {
            wrapperClassifier = new HoeffdingTree();
            wrapperClassifier.resetLearningImpl();
        }
    }

    @Override
    public void resetLearningImpl() {
        this.attributeObservers = new AutoExpandVector<AttributeClassObserver>();
        totalCount = 0;
        classified = 0;
        correctlyClassified = 0;
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        Instance rinst = inst.copy();

        // Update the FS evaluator (no selection is applied here)
        if (fsmethodOption.getValue() != 0) {
            if (fselector == null) {
                if (fsmethodOption.getValue() == 3) {
                    fselector = new OFSGDAttributeEval(numFeaturesOption.getValue());
                } else if (fsmethodOption.getValue() == 2 || fsmethodOption.getValue() == 1) {
                    fselector = new IncrInfoThAttributeEval(fsmethodOption.getValue());
                }
            }
            try {
                fselector.updateEvaluator(inst);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        // Update the discretization scheme, and apply it to the given instance
        //long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
        if (discmethodOption.getValue() != 0) {
            if (discretizer == null) {
                if (discmethodOption.getValue() == 1) {
                    discretizer = new PIDdiscretize();
                } else if (discmethodOption.getValue() == 2) {
                    discretizer = new IFFDdiscretize();
                } else if (discmethodOption.getValue() == 3) {
                    discretizer = new OCdiscretize(this.numClassesOption.getValue());
                } else if (discmethodOption.getValue() == 4) {
                    discretizer = new IDAdiscretize();
                } else {
                    discretizer = new LOFDiscretizer(winSizeOption.getValue(), thresholdOption.getValue(),
                            decimalsOption.getValue(), maxLabelsOption.getValue());
                }
            } else {
                discretizer.updateEvaluator(inst);
                // Once the initialization threshold is passed, restart the wrapped model
                if (totalCount == thresholdOption.getValue() + 1)
                    wrapperClassifier.resetLearningImpl();
            }
            //System.out.println("Number of new intervals: " + discretizer.getNumberIntervals());
            rinst = discretizer.applyDiscretization(inst);
        }
        //sumTime += TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - evaluateStartTime);

        // Replace any -1 attribute values left after preprocessing with 0
        for (int i = 0; i < rinst.numAttributes() - 1; i++) {
            if (rinst.value(i) == -1) {
                System.out.println("Value changed");
                rinst.setValue(i, 0);
            }
        }

        wrapperClassifier.trainOnInstance(rinst);
        totalCount++;
        //if (totalCount == 50000)
        //    System.out.println("Total time: " + sumTime);
    }

    @Override
    public double[] getVotesForInstance(Instance inst) {
        // Apply the feature selection and discretization schemes learned so far
        Instance sinst = inst.copy();
        if (fsmethodOption.getValue() != 0 && fselector != null)
            sinst = performFS(sinst);
        if (discmethodOption.getValue() != 0 && discretizer != null)
            sinst = discretizer.applyDiscretization(sinst);

        double[] finalVotes = wrapperClassifier.getVotesForInstance(sinst);

        // Track prequential accuracy: compare the predicted class with the true class value
        double maxValue = Double.NEGATIVE_INFINITY;
        int maxIndex = -1;
        for (int i = 0; i < finalVotes.length; i++) {
            if (finalVotes[i] > maxValue) {
                maxIndex = i;
                maxValue = finalVotes[i];
            }
        }
        if (maxIndex == (int) inst.classValue())
            correctlyClassified++;
        classified++;

        return finalVotes;
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    @Override
    public void getModelDescription(StringBuilder result, int indent) {
        StringUtils.appendIndented(result, indent, toString());
        StringUtils.appendNewline(result);
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    private Instance performFS(Instance rinst) {
        // Feature selection process performed before prediction
        weka.core.Instance winst = new weka.core.DenseInstance(rinst.weight(), rinst.toDoubleArray());
        if (fselector != null) {
            if (fselector.isUpdated() && totalCount % winSizeOption.getValue() == 0) {
                fselector.applySelection();
                selector = new AttributeSelection();
                Ranker ranker = new Ranker();
                ranker.setNumToSelect(Math.min(numFeaturesOption.getValue(), winst.numAttributes() - 1));
                selector.setEvaluator((ASEvaluation) fselector);
                selector.setSearch(ranker);

                // Build a single-instance Weka data set so the Weka selector can be applied
                ArrayList<Attribute> list = new ArrayList<Attribute>();
                for (int i = 0; i < rinst.numAttributes(); i++)
                    list.add(new Attribute(rinst.attribute(i).name(), i));
                weka.core.Instances single = new weka.core.Instances("single", list, 1);
                single.setClassIndex(rinst.classIndex());
                single.add(winst);

                try {
                    selector.SelectAttributes(single);
                    System.out.println("Selected features: " + selector.toResultsString());
                    selectedFeatures.clear();
                    for (int att : selector.selectedAttributes())
                        selectedFeatures.add(att);
                    WekaToSamoaInstanceConverter convWS = new WekaToSamoaInstanceConverter();
                    return convWS.samoaInstance(selector.reduceDimensionality(winst));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
        return rinst;
    }
}
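
To see the wrapper in action, here is a minimal test-then-train sketch assuming MOA's standard stream API. The RandomRBFGenerator stream, the option values and the ReductionClassifierExample class name are illustrative choices, not part of the class above.

package moa.reduction.core;

import com.yahoo.labs.samoa.instances.Instance;

import moa.streams.generators.RandomRBFGenerator;

/**
 * Minimal prequential (test-then-train) loop for ReductionClassifier.
 * Illustrative sketch only: stream and option values are assumptions.
 */
public class ReductionClassifierExample {

    public static void main(String[] args) {
        // Synthetic numeric stream; any MOA instance stream could be used instead
        RandomRBFGenerator stream = new RandomRBFGenerator();
        stream.prepareForUse();

        // Illustrative configuration: LOF-based discretization, no feature selection
        ReductionClassifier.discmethodOption.setValue(5);
        ReductionClassifier.fsmethodOption.setValue(0);

        ReductionClassifier learner = new ReductionClassifier();
        learner.setModelContext(stream.getHeader());
        learner.prepareForUse();

        int numInstances = 10000, correct = 0;
        for (int i = 0; i < numInstances && stream.hasMoreInstances(); i++) {
            Instance inst = stream.nextInstance().getData();
            // Test first, then train on the labelled instance
            if (learner.correctlyClassifies(inst)) {
                correct++;
            }
            learner.trainOnInstance(inst);
        }
        System.out.println("Prequential accuracy: " + (100.0 * correct / numInstances) + "%");
    }
}

Note that fselector and discretizer are static fields, so the fsMethod and discMethod options should be set before the first call to trainOnInstance, as done above.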