Java tutorial
package edu.umbc.cs.maple.utils; //import org.apache.commons.math.stat.descriptive.SummaryStatistics; //import org.apache.commons.math.stat.descriptive.SummaryStatisticsImpl; import weka.classifiers.Classifier; import weka.core.Attribute; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; /** Various utility functions for Weka. * <p> * Copyright (c) 2008 Eric Eaton * <p> * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * <p> * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * <p> * You should have received a copy of the GNU General Public License * along with this program. If not, see http://www.gnu.org/licenses/. * * @author Eric Eaton (EricEaton@umbc.edu) <br> * University of Maryland Baltimore County * * @version 0.1 * */ public class WekaUtils { /** Take a certain percentage of a set of instances. * @param instances * @param percentage * @return a reduced set of instances according to the given percentage */ public static Instances trimInstances(Instances instances, double percentage) { int numInstancesToKeep = (int) Math.ceil(instances.numInstances() * percentage); return trimInstances(instances, numInstancesToKeep); } /** Take a certain number of a set of instances. * @param instances * @param numInstances the number of instances to keep * @return a reduced set of instances according to the given number to keep */ public static Instances trimInstances(Instances instances, int numInstances) { Instances trimmedInstances = new Instances(instances); for (int i = trimmedInstances.numInstances() - 1; i >= numInstances; i--) { trimmedInstances.delete(i); } return trimmedInstances; } /** Extract a particular subset of the instances. * @param instances * @param startIdx the start instance index * @param numInstancesToRetrieve the number of instances to retrieve * @return the specified subset of the instances. */ public static Instances subsetInstances(Instances instances, int startIdx, int numInstancesToRetrieve) { double possibleNumInstancesToRetrieve = instances.numInstances() - startIdx; if (numInstancesToRetrieve > possibleNumInstancesToRetrieve) { throw new IllegalArgumentException( "Cannot retrieve more than " + possibleNumInstancesToRetrieve + " instances."); } int endIdx = startIdx + numInstancesToRetrieve - 1; // delete all instance indices outside of [startIdx, endIdx] Instances subset = new Instances(instances); for (int i = subset.numInstances() - 1; i >= 0; i--) { if (i < startIdx || i > endIdx) subset.delete(i); } return subset; } /** Merge two instance sets. * @param instances1 * @param instances2 * @return the merged instance sets */ public static Instances mergeInstances(Instances instances1, Instances instances2) { if (instances1 == null) return instances2; if (instances2 == null) return instances1; if (!instances1.checkInstance(instances2.firstInstance())) throw new IllegalArgumentException("The instance sets are incompatible."); Instances mergedInstances = new Instances(instances1); Instances tempInstances = new Instances(instances2); for (int i = 0; i < tempInstances.numInstances(); i++) { mergedInstances.add(tempInstances.instance(i)); } return mergedInstances; } /** * Converts an instance to a feature vector excluding the class attribute. * @param instance The instance. * @return A vector representation of the instance excluding the class attribute */ public static double[] instanceToDoubleArray(Instance instance) { double[] vector = new double[(instance.classIndex() != -1) ? instance.numAttributes() - 1 : instance.numAttributes()]; double[] instanceDoubleArray = instance.toDoubleArray(); int attIdx = 0; for (int i = 0; i < vector.length; i++) { if (i == instance.classIndex()) { attIdx++; } vector[i] = instanceDoubleArray[attIdx++]; } return vector; } /** * Converts a set of instances to an array of vectors * @param instances The set of instances. * @return The array of feature vectors. */ public static double[][] instancesToDoubleArrays(Instances instances) { double[][] vectors = new double[instances.numInstances()][]; for (int instIdx = 0; instIdx < instances.numInstances(); instIdx++) { vectors[instIdx] = instanceToDoubleArray(instances.instance(instIdx)); } return vectors; } /** Uses the given model to predict the classes of the data. * @param model * @param data * @return An array of the class predictions. */ public static int[] predictClasses(Classifier model, Instances data) { int[] predictions = new int[data.numInstances()]; int numInstances = data.numInstances(); for (int instIdx = 0; instIdx < numInstances; instIdx++) { try { predictions[instIdx] = (int) model.classifyInstance(data.instance(instIdx)); } catch (Exception e) { predictions[instIdx] = -1; } } return predictions; } /** Gets the class labels for a set of instances. * @param data * @return a vector of the class labels for the data set, with one entry per instance */ public static int[] getLabels(Instances data) { int[] classLabels = new int[data.numInstances()]; for (int instIdx = 0; instIdx < classLabels.length; instIdx++) { classLabels[instIdx] = (int) data.instance(instIdx).classValue(); } return classLabels; } /** Gets the class values for a set of instances. * @param data * @return a vector of the class values for the data set, with one entry per instance */ public static double[] getClassValues(Instances data) { double[] classLabels = new double[data.numInstances()]; for (int instIdx = 0; instIdx < classLabels.length; instIdx++) { classLabels[instIdx] = data.instance(instIdx).classValue(); } return classLabels; } /* public static SummaryStatistics getPredictionStats(Instances data, int[] predictedClasses) { SummaryStatistics stats = new SummaryStatisticsImpl(); for (int instIdx = 0; instIdx < data.numInstances(); instIdx++) { Instance instance = data.instance(instIdx); if (instance.classValue() == (double) predictedClasses[instIdx]) stats.addValue(1); else stats.addValue(0); } return stats; } */ /** Converts the instances in the given dataset to binary, setting the specified labels to positive. * Note this method is destructive to data, directly modifying its contents. * @param data the multiclass dataset to be converted to binary. * @param positiveClassValue the class value to treat as positive. */ public static void convertMulticlassToBinary(Instances data, String positiveClassValue) { // ensure that data is nominal if (!data.classAttribute().isNominal()) throw new IllegalArgumentException("Instances must have a nominal class."); // create the new class attribute FastVector newClasses = new FastVector(2); newClasses.addElement("Y"); newClasses.addElement("N"); Attribute newClassAttribute = new Attribute("class", newClasses); // alter the class attribute to be binary int newClassAttIdx = data.classIndex(); data.insertAttributeAt(newClassAttribute, newClassAttIdx); int classAttIdx = data.classIndex(); // set the instances classes to be binary, with the labels [Y,N] (indices 0 and 1 respectively) int numInstances = data.numInstances(); for (int instIdx = 0; instIdx < numInstances; instIdx++) { Instance inst = data.instance(instIdx); if (inst.stringValue(classAttIdx).equals(positiveClassValue)) { inst.setValue(newClassAttIdx, 0); // set it to the first class, which will be Y } else { inst.setValue(newClassAttIdx, 1); // set it to the second class, which will be 0 } } // switch the class index to the new class and delete the old class data.setClassIndex(newClassAttIdx); data.deleteAttributeAt(classAttIdx); // alter the dataset name data.setRelationName(data.relationName() + "-" + positiveClassValue); } /** Determines whether a data set has equal class priors. * @param data * @return whether the data set has equal class priors */ public static boolean equalClassPriors(Instances data) { double[] counts = new double[data.numClasses()]; int numInstances = data.numInstances(); for (int i = 0; i < numInstances; i++) { Instance inst = data.instance(i); int classValueIdx = (int) Math.round(inst.classValue()); counts[classValueIdx] = counts[classValueIdx] + 1; } // compute the mean double meanCount = MathUtils.sum(counts) / counts.length; double[] meanArray = new double[counts.length]; for (int i = 0; i < meanArray.length; i++) { meanArray[i] = meanCount; } // compute the rmse double rmse = MathUtils.rmse(meanArray, counts); // compute 2.5% of the possible double deviationAllowed = Math.ceil(0.025 * meanCount); if (rmse <= deviationAllowed) return true; else return false; } /** Gets the weights of each instance in a dataset as an array. * @param data the dataset of instances * @return the weights of the instances as an array. */ public static double[] getWeights(Instances data) { int numInstances = data.numInstances(); double[] weights = new double[numInstances]; for (int instIdx = 0; instIdx < numInstances; instIdx++) { weights[instIdx] = data.instance(instIdx).weight(); } return weights; } /** Defines the format of the SVMLight labels */ public enum SVMLightLabelFormat { CLASSIFICATION, REGRESSION } /** Converts a set of instances to svm-light format * @param data the weka instances * @return the weka instances in svm-light format */ public static String arffToSVMLight(Instances data, SVMLightLabelFormat labelFormat) { if (labelFormat == SVMLightLabelFormat.CLASSIFICATION && data.numClasses() != 2) { throw new IllegalArgumentException( "SVM-light classification label format requires that the data contain only two classes."); } String str = ""; String endline = System.getProperty("line.separator"); int numInstances = data.numInstances(); int numAttributes = data.numAttributes(); int classAttIdx = data.classIndex(); for (int instIdx = 0; instIdx < numInstances; instIdx++) { Instance inst = data.instance(instIdx); // convert the instance label if (labelFormat == SVMLightLabelFormat.CLASSIFICATION) { str += (inst.classValue() == 0) ? "-1" : "1"; } else { str += inst.classValue(); } str += " "; // convert each feature for (int attIdx = 0; attIdx < numAttributes; attIdx++) { // skip the class attribute if (attIdx == classAttIdx) continue; str += (attIdx + 1) + ":" + inst.value(attIdx) + " "; } // append the instance info string str += "# " + instIdx; str += endline; } return str; } /** Converts a set of instances to svm-light format * @param data the weka instances * @return the weka instances in svm-light format */ public static String arffToSVMLight(Instance data, SVMLightLabelFormat labelFormat) { if (labelFormat == SVMLightLabelFormat.CLASSIFICATION && data.numClasses() != 2) { throw new IllegalArgumentException( "SVM-light classification label format requires that the data contain only two classes."); } String str = ""; String endline = System.getProperty("line.separator"); int numAttributes = data.numAttributes(); int classAttIdx = data.classIndex(); // convert the instance label if (labelFormat == SVMLightLabelFormat.CLASSIFICATION) { str += (data.classValue() == 0) ? "-1" : "1"; } else { str += data.classValue(); } str += " "; // convert each feature for (int attIdx = 0; attIdx < numAttributes; attIdx++) { // skip the class attribute if (attIdx == classAttIdx) continue; str += (attIdx + 1) + ":" + data.value(attIdx) + " "; } // append the instance info string str += "#"; str += endline; return str; } }