// Java tutorial
/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * AddClassification.java * Copyright (C) 2006-2012 University of Waikato, Hamilton, New Zealand */ package weka.filters.supervised.attribute; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.ObjectInputStream; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.Vector; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.misc.InputMappedClassifier; import weka.core.*; import weka.filters.SimpleBatchFilter; /** * <!-- globalinfo-start --> A filter for adding the classification, the class * distribution and an error flag to a dataset with a classifier. The classifier * is either trained on the data itself or provided as serialized model. * <p/> * <!-- globalinfo-end --> * * <!-- options-start --> Valid options are: * <p/> * * <pre> * -D * Turns on output of debugging information. * </pre> * * <pre> * -W <classifier specification> * Full class name of classifier to use, followed * by scheme options. eg: * "weka.classifiers.bayes.NaiveBayes -D" * (default: weka.classifiers.rules.ZeroR) * </pre> * * <pre> * -serialized <file> * Instead of training a classifier on the data, one can also provide * a serialized model and use that for tagging the data. 
* </pre> * * <pre> * -classification * Adds an attribute with the actual classification. * (default: off) * </pre> * * <pre> * -remove-old-class * Removes the old class attribute. * (default: off) * </pre> * * <pre> * -distribution * Adds attributes with the distribution for all classes * (for numeric classes this will be identical to the attribute * output with '-classification'). * (default: off) * </pre> * * <pre> * -error * Adds an attribute indicating whether the classifier output * a wrong classification (for numeric classes this is the numeric * difference). * (default: off) * </pre> * * <!-- options-end --> * * @author fracpete (fracpete at waikato dot ac dot nz) * @version $Revision$ */ public class AddClassification extends SimpleBatchFilter implements WeightedAttributesHandler, WeightedInstancesHandler { /** for serialization. */ private static final long serialVersionUID = -1931467132568441909L; /** The classifier template used to do the classification. */ protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR(); /** The file from which to load a serialized classifier. */ protected File m_SerializedClassifierFile = new File(System.getProperty("user.dir")); /** The actual classifier used to do the classification. */ protected Classifier m_ActualClassifier = null; /** the header of the file the serialized classifier was trained with. */ protected Instances m_SerializedHeader = null; /** whether to output the classification. */ protected boolean m_OutputClassification = false; /** whether to remove the old class attribute. */ protected boolean m_RemoveOldClass = false; /** whether to output the class distribution. */ protected boolean m_OutputDistribution = false; /** whether to output the error flag. */ protected boolean m_OutputErrorFlag = false; /** * Returns a string describing this filter. 
* * @return a description of the filter suitable for displaying in the * explorer/experimenter gui */ @Override public String globalInfo() { return "A filter for adding the classification, the class distribution and " + "an error flag to a dataset with a classifier. The classifier is " + "either trained on the data itself or provided as serialized model."; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration<Option> listOptions() { Vector<Option> result = new Vector<Option>(); result.addElement(new Option("\tFull class name of classifier to use, followed\n" + "\tby scheme options. eg:\n" + "\t\t\"weka.classifiers.bayes.NaiveBayes -D\"\n" + "\t(default: weka.classifiers.rules.ZeroR)", "W", 1, "-W <classifier specification>")); result.addElement(new Option( "\tInstead of training a classifier on the data, one can also provide\n" + "\ta serialized model and use that for tagging the data.", "serialized", 1, "-serialized <file>")); result.addElement(new Option("\tAdds an attribute with the actual classification.\n" + "\t(default: off)", "classification", 0, "-classification")); result.addElement(new Option("\tRemoves the old class attribute.\n" + "\t(default: off)", "remove-old-class", 0, "-remove-old-class")); result.addElement(new Option( "\tAdds attributes with the distribution for all classes \n" + "\t(for numeric classes this will be identical to the attribute \n" + "\toutput with '-classification').\n" + "\t(default: off)", "distribution", 0, "-distribution")); result.addElement(new Option("\tAdds an attribute indicating whether the classifier output \n" + "\ta wrong classification (for numeric classes this is the numeric \n" + "\tdifference).\n" + "\t(default: off)", "error", 0, "-error")); result.addAll(Collections.list(super.listOptions())); return result.elements(); } /** * Parses the options for this object. 
* <p/> * * <!-- options-start --> Valid options are: * <p/> * * <pre> * -D * Turns on output of debugging information. * </pre> * * <pre> * -W <classifier specification> * Full class name of classifier to use, followed * by scheme options. eg: * "weka.classifiers.bayes.NaiveBayes -D" * (default: weka.classifiers.rules.ZeroR) * </pre> * * <pre> * -serialized <file> * Instead of training a classifier on the data, one can also provide * a serialized model and use that for tagging the data. * </pre> * * <pre> * -classification * Adds an attribute with the actual classification. * (default: off) * </pre> * * <pre> * -remove-old-class * Removes the old class attribute. * (default: off) * </pre> * * <pre> * -distribution * Adds attributes with the distribution for all classes * (for numeric classes this will be identical to the attribute * output with '-classification'). * (default: off) * </pre> * * <pre> * -error * Adds an attribute indicating whether the classifier output * a wrong classification (for numeric classes this is the numeric * difference). 
* (default: off) * </pre> * * <!-- options-end --> * * @param options the options to use * @throws Exception if setting of options fails */ @Override public void setOptions(String[] options) throws Exception { String tmpStr; String[] tmpOptions; File file; boolean serializedModel; setOutputClassification(Utils.getFlag("classification", options)); setRemoveOldClass(Utils.getFlag("remove-old-class", options)); setOutputDistribution(Utils.getFlag("distribution", options)); setOutputErrorFlag(Utils.getFlag("error", options)); serializedModel = false; tmpStr = Utils.getOption("serialized", options); if (tmpStr.length() != 0) { file = new File(tmpStr); if (!file.exists()) { throw new FileNotFoundException("File '" + file.getAbsolutePath() + "' not found!"); } if (file.isDirectory()) { throw new FileNotFoundException( "'" + file.getAbsolutePath() + "' points to a directory not a file!"); } setSerializedClassifierFile(file); serializedModel = true; } else { setSerializedClassifierFile(null); } if (!serializedModel) { tmpStr = Utils.getOption('W', options); if (tmpStr.length() == 0) { tmpStr = weka.classifiers.rules.ZeroR.class.getName(); } tmpOptions = Utils.splitOptions(tmpStr); if (tmpOptions.length == 0) { throw new Exception("Invalid classifier specification string"); } tmpStr = tmpOptions[0]; tmpOptions[0] = ""; setClassifier(AbstractClassifier.forName(tmpStr, tmpOptions)); } super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Gets the current settings of the classifier. 
* * @return an array of strings suitable for passing to setOptions */ @Override public String[] getOptions() { Vector<String> result = new Vector<String>(); if (getOutputClassification()) { result.add("-classification"); } if (getRemoveOldClass()) { result.add("-remove-old-class"); } if (getOutputDistribution()) { result.add("-distribution"); } if (getOutputErrorFlag()) { result.add("-error"); } File file = getSerializedClassifierFile(); if ((file != null) && (!file.isDirectory())) { result.add("-serialized"); result.add(file.getAbsolutePath()); } else { result.add("-W"); result.add(getClassifierSpec()); } Collections.addAll(result, super.getOptions()); return result.toArray(new String[result.size()]); } /** * resets the filter, i.e., m_ActualClassifier to null. * * @see #m_ActualClassifier */ @Override protected void reset() { super.reset(); m_ActualClassifier = null; m_SerializedHeader = null; } /** * Returns the actual classifier to use, either from the serialized model or * the one specified by the user. * * @return the classifier to use, null in case of an error */ protected Classifier getActualClassifier() { File file; ObjectInputStream ois; if (m_ActualClassifier == null) { try { file = getSerializedClassifierFile(); if (!file.isDirectory()) { // ois = new ObjectInputStream(new FileInputStream(file)); ois = SerializationHelper.getObjectInputStream(new FileInputStream(file)); m_ActualClassifier = (Classifier) ois.readObject(); m_SerializedHeader = null; // let's see whether there's an Instances header stored as well try { m_SerializedHeader = (Instances) ois.readObject(); } catch (Exception e) { // ignored m_SerializedHeader = null; } ois.close(); } else { m_ActualClassifier = AbstractClassifier.makeCopy(m_Classifier); } } catch (Exception e) { m_ActualClassifier = null; System.err.println("Failed to instantiate classifier:"); e.printStackTrace(); } } return m_ActualClassifier; } /** * Need to override this to deal with InputMappedClassifier case. 
   * (If InputMappedClassifier is applied to test data that is different in some important aspects: we
   * need to test capabilities with respect to format of data used to train the classifier.)
   *
   * @param instanceInfo the data to test
   * @throws Exception if the test fails
   */
  protected void testInputFormat(Instances instanceInfo) throws Exception {
    Classifier classifier = getActualClassifier();
    if (classifier instanceof InputMappedClassifier) {
      // test against the header the mapped classifier was trained with,
      // not against the incoming data's header
      Instances trainingData = ((InputMappedClassifier) classifier)
        .getModelHeader(new Instances(instanceInfo, 0));
      getCapabilities(trainingData).testWithFail(trainingData);
    } else {
      getCapabilities(instanceInfo).testWithFail(instanceInfo);
    }
  }

  /**
   * Returns the Capabilities of this filter.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result;

    if (getActualClassifier() == null) {
      // no classifier available - accept nothing
      result = super.getCapabilities();
      result.disableAll();
    } else {
      // the filter can handle whatever the classifier can handle
      result = getActualClassifier().getCapabilities();
    }

    // the filter itself needs no minimum number of instances
    result.setMinimumNumberInstances(0);

    return result;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String classifierTipText() {
    return "The classifier to use for classification.";
  }

  /**
   * Sets the classifier to classify instances with.
   *
   * @param value The classifier to be used (with its options set).
   */
  public void setClassifier(Classifier value) {
    m_Classifier = value;
  }

  /**
   * Gets the classifier used by the filter.
   *
   * @return The classifier to be used.
   */
  public Classifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Gets the classifier specification string, which contains the class name of
   * the classifier and any options to the classifier.
   *
   * @return the classifier string.
   */
  protected String getClassifierSpec() {
    String result;
    Classifier c;

    c = getClassifier();
    result = c.getClass().getName();
    if (c instanceof OptionHandler) {
      // append the scheme's own options to the class name
      result += " " + Utils.joinOptions(((OptionHandler) c).getOptions());
    }

    return result;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String serializedClassifierFileTipText() {
    return "A file containing the serialized model of a trained classifier.";
  }

  /**
   * Gets the file pointing to a serialized, trained classifier. If it is null
   * or pointing to a directory it will not be used.
   *
   * @return the file the serialized, trained classifier is located in
   */
  public File getSerializedClassifierFile() {
    return m_SerializedClassifierFile;
  }

  /**
   * Sets the file pointing to a serialized, trained classifier. If the argument
   * is null, doesn't exist or is pointing to a directory, then the value is
   * ignored (the property falls back to the current working directory).
   *
   * @param value the file pointing to the serialized, trained classifier
   */
  public void setSerializedClassifierFile(File value) {
    if ((value == null) || (!value.exists())) {
      value = new File(System.getProperty("user.dir"));
    }
    m_SerializedClassifierFile = value;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String outputClassificationTipText() {
    return "Whether to add an attribute with the actual classification.";
  }

  /**
   * Get whether the classification of the classifier is output.
   *
   * @return true if the classification of the classifier is output.
   */
  public boolean getOutputClassification() {
    return m_OutputClassification;
  }

  /**
   * Set whether the classification of the classifier is output.
   *
   * @param value whether the classification of the classifier is output.
   */
  public void setOutputClassification(boolean value) {
    m_OutputClassification = value;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String removeOldClassTipText() {
    return "Whether to remove the old class attribute.";
  }

  /**
   * Get whether the old class attribute is removed.
   *
   * @return true if the old class attribute is removed.
   */
  public boolean getRemoveOldClass() {
    return m_RemoveOldClass;
  }

  /**
   * Set whether the old class attribute is removed.
   *
   * @param value whether the old class attribute is removed.
   */
  public void setRemoveOldClass(boolean value) {
    m_RemoveOldClass = value;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String outputDistributionTipText() {
    return "Whether to add attributes with the distribution for all classes "
      + "(for numeric classes this will be identical to the attribute output "
      + "with 'outputClassification').";
  }

  /**
   * Get whether the distribution of the classifier is output.
   *
   * @return true if the distribution of the classifier is output.
   */
  public boolean getOutputDistribution() {
    return m_OutputDistribution;
  }

  /**
   * Set whether the distribution of the classifier is output.
   *
   * @param value whether the distribution of the classifier is output.
   */
  public void setOutputDistribution(boolean value) {
    m_OutputDistribution = value;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String outputErrorFlagTipText() {
    return "Whether to add an attribute indicating whether the classifier output "
      + "a wrong classification (for numeric classes this is the numeric "
      + "difference).";
  }

  /**
   * Get whether the error flag of the classifier is output.
   *
   * @return true if the error flag of the classifier is output.
   */
  public boolean getOutputErrorFlag() {
    return m_OutputErrorFlag;
  }

  /**
   * Set whether the error flag of the classifier is output.
   *
   * @param value whether the error flag of the classifier is output.
   */
  public void setOutputErrorFlag(boolean value) {
    m_OutputErrorFlag = value;
  }

  /**
   * Determines the output format based on the input format and returns this. In
   * case the output format cannot be returned immediately, i.e.,
   * immediateOutputFormat() returns false, then this method will be called from
   * batchFinished().
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   * @see #hasImmediateOutputFormat()
   * @see #batchFinished()
   */
  @Override
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
    Instances result;
    int i;
    ArrayList<String> values;
    int classindex;

    classindex = -1;

    // Need to get actual class attribute from saved model if we are working
    // with a saved model and it is an InputMappedClassifier.
    Attribute classAttribute = inputFormat.classIndex() >= 0 ? inputFormat.classAttribute() : null;
    Classifier classifier = getActualClassifier();
    if (!getSerializedClassifierFile().isDirectory()) {
      // serialized model in use
      if (classifier instanceof InputMappedClassifier) {
        classAttribute = ((InputMappedClassifier) classifier).getModelHeader(new Instances(inputFormat, 0))
          .classAttribute();
      }
    } else {
      // classifier will be trained on the data itself, so a class must be set
      if ((classAttribute == null) && (!(classifier instanceof InputMappedClassifier))) {
        throw new IllegalArgumentException(
          "AddClassification: class must be set if InputMappedClassifier is not used.");
      }
    }

    // copy old attributes
    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (i = 0; i < inputFormat.numAttributes(); i++) {
      // remove class?
      if ((i == inputFormat.classIndex()) && (getRemoveOldClass())) {
        continue;
      }
      // record class index
      if (i == inputFormat.classIndex()) {
        classindex = i;
      }
      atts.add((Attribute) inputFormat.attribute(i).copy());
    }

    // add new attributes
    // 1. classification?
    if (getOutputClassification()) {
      // if old class got removed, use this one as the new class attribute
      if (classindex == -1) {
        classindex = atts.size();
      }
      atts.add(classAttribute.copy("classification"));
    }

    // 2. distribution?
    if (getOutputDistribution()) {
      if (classAttribute.isNominal()) {
        // one attribute per class label
        for (i = 0; i < classAttribute.numValues(); i++) {
          atts.add(new Attribute("distribution_" + classAttribute.value(i)));
        }
      } else {
        atts.add(new Attribute("distribution"));
      }
    }

    // 3. error flag?
    if (getOutputErrorFlag()) {
      if (classAttribute.isNominal()) {
        values = new ArrayList<String>();
        values.add("no");
        values.add("yes");
        atts.add(new Attribute("error", values));
      } else {
        atts.add(new Attribute("error"));
      }
    }

    // generate new header
    result = new Instances(inputFormat.relationName(), atts, 0);
    result.setClassIndex(classindex);

    return result;
  }

  /**
   * Processes the given data (may change the provided dataset) and returns the
   * modified version. This method is called in batchFinished().
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   * @see #batchFinished()
   */
  @Override
  protected Instances process(Instances instances) throws Exception {
    Instances result;
    double[] newValues;
    double[] oldValues;
    int i;
    int n;
    Instance newInstance;
    Instance oldInstance;
    double[] distribution;

    // load or train classifier (first batch only)
    if (!isFirstBatchDone()) {
      getActualClassifier();
      if (!getSerializedClassifierFile().isDirectory()) {
        // same dataset format as the serialized model was trained with?
        if ((m_SerializedHeader != null) && (!m_SerializedHeader.equalHeaders(instances))
          && (!(m_ActualClassifier instanceof InputMappedClassifier))) {
          throw new WekaException("Training header of classifier and filter dataset don't match:\n"
            + m_SerializedHeader.equalHeadersMsg(instances));
        }
      } else {
        m_ActualClassifier.buildClassifier(instances);
      }
    }

    result = getOutputFormat();

    // traverse all instances
    for (i = 0; i < instances.numInstances(); i++) {
      oldInstance = instances.instance(i);
      oldValues = oldInstance.toDoubleArray();
      newValues = new double[result.numAttributes()];

      // copy values
      int start = 0;
      for (int j = 0; j < oldValues.length; j++) {
        // remove class?
        if ((j == inputFormatPeek().classIndex()) && (getRemoveOldClass())) {
          continue;
        }
        newValues[start++] = oldValues[j];
      }

      // add new values:
      // 1. classification?
      if (getOutputClassification()) {
        newValues[start] = m_ActualClassifier.classifyInstance(oldInstance);
        start++;
      }

      // 2. distribution?
      if (getOutputDistribution()) {
        distribution = m_ActualClassifier.distributionForInstance(oldInstance);
        for (n = 0; n < distribution.length; n++) {
          newValues[start] = distribution[n];
          start++;
        }
      }

      // 3. error flag?
      if (getOutputErrorFlag()) {
        Instance inst = oldInstance;
        // compare against the mapped instance if an InputMappedClassifier is used
        if (m_ActualClassifier instanceof InputMappedClassifier) {
          inst = ((InputMappedClassifier) m_ActualClassifier).constructMappedInstance(inst);
        }
        if (instances.classIndex() < 0) {
          // no class set: error cannot be determined
          newValues[start] = Utils.missingValue();
        } else if (result.classAttribute().isNominal()) {
          // nominal class: 0 = correct, 1 = wrong
          if (inst.classValue() == m_ActualClassifier.classifyInstance(oldInstance)) {
            newValues[start] = 0;
          } else {
            newValues[start] = 1;
          }
        } else {
          // numeric class: signed difference predicted - actual
          newValues[start] = m_ActualClassifier.classifyInstance(oldInstance) - inst.classValue();
        }
        start++;
      }

      // create new instance, preserving sparseness and weight
      if (oldInstance instanceof SparseInstance) {
        newInstance = new SparseInstance(oldInstance.weight(), newValues);
      } else {
        newInstance = new DenseInstance(oldInstance.weight(), newValues);
      }

      // copy string/relational values from input to output
      copyValues(newInstance, false, oldInstance.dataset(), outputFormatPeek());

      result.add(newInstance);
    }

    return result;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision$");
  }

  /**
   * runs the filter with the given arguments.
   *
   * @param args the commandline arguments
   */
  public static void main(String[] args) {
    runFilter(new AddClassification(), args);
  }
}