edu.umbc.cs.maple.utils.WekaUtils.java Source code

Java tutorial

Introduction

Here is the source code for edu.umbc.cs.maple.utils.WekaUtils.java

Source

package edu.umbc.cs.maple.utils;

//import org.apache.commons.math.stat.descriptive.SummaryStatistics;
//import org.apache.commons.math.stat.descriptive.SummaryStatisticsImpl;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/** Various utility functions for Weka.
 * <p>
 * Copyright (c) 2008 Eric Eaton
 * <p>
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * <p>
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * <p>
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see http://www.gnu.org/licenses/.
 * 
 * @author Eric Eaton (EricEaton@umbc.edu) <br>
 *          University of Maryland Baltimore County
 * 
 * @version 0.1
 *
 */
public class WekaUtils {

    /** Take a certain percentage of a set of instances.
     * <p>
     * The first {@code ceil(numInstances * percentage)} instances are kept, in their original
     * order. A percentage of 1 or more keeps every instance; a percentage of 0 or less keeps
     * none. The input dataset is not modified.
     * @param instances the instances to trim
     * @param percentage the fraction of the instances to keep, normally in [0, 1]
     * @return a reduced set of instances according to the given percentage
     */
    public static Instances trimInstances(Instances instances, double percentage) {
        int numInstancesToKeep = (int) Math.ceil(instances.numInstances() * percentage);
        return trimInstances(instances, numInstancesToKeep);
    }

    /** Take a certain number of a set of instances.
     * <p>
     * The first {@code numInstances} instances are copied into a new dataset with the same
     * header as {@code instances}; the input dataset is not modified. A value larger than the
     * dataset size keeps every instance, and a value of 0 or less keeps none.
     * @param instances the instances to trim
     * @param numInstances the number of instances to keep
     * @return a reduced set of instances according to the given number to keep
     */
    public static Instances trimInstances(Instances instances, int numInstances) {
        int numToKeep = Math.max(0, Math.min(numInstances, instances.numInstances()));
        // Copy only the header, then add the retained prefix. Instances.add copies each
        // instance, so the result is independent of the input (as with new Instances(instances)),
        // without first copying and then deleting the instances we do not want.
        Instances trimmedInstances = new Instances(instances, numToKeep);
        for (int i = 0; i < numToKeep; i++) {
            trimmedInstances.add(instances.instance(i));
        }
        return trimmedInstances;
    }

    /** Extract a particular subset of the instances.
     * <p>
     * Returns copies of the instances at indices
     * {@code [startIdx, startIdx + numInstancesToRetrieve)} in a new dataset with the same
     * header as {@code instances}. The input dataset is not modified.
     * @param instances the source instances
     * @param startIdx the start instance index (0-based)
     * @param numInstancesToRetrieve the number of instances to retrieve
     * @return the specified subset of the instances.
     * @throws IllegalArgumentException if {@code startIdx} is outside {@code [0, numInstances]},
     *         if {@code numInstancesToRetrieve} is negative, or if the requested range runs past
     *         the end of the dataset
     */
    public static Instances subsetInstances(Instances instances, int startIdx, int numInstancesToRetrieve) {
        int numAvailable = instances.numInstances();
        if (startIdx < 0 || startIdx > numAvailable) {
            throw new IllegalArgumentException(
                    "Start index " + startIdx + " is out of range [0, " + numAvailable + "].");
        }
        if (numInstancesToRetrieve < 0) {
            throw new IllegalArgumentException(
                    "Cannot retrieve a negative number of instances: " + numInstancesToRetrieve);
        }
        int possibleNumInstancesToRetrieve = numAvailable - startIdx;
        if (numInstancesToRetrieve > possibleNumInstancesToRetrieve) {
            throw new IllegalArgumentException(
                    "Cannot retrieve more than " + possibleNumInstancesToRetrieve + " instances.");
        }

        // copy only the requested range; Instances.add copies each instance
        Instances subset = new Instances(instances, numInstancesToRetrieve);
        int endIdxExclusive = startIdx + numInstancesToRetrieve;
        for (int i = startIdx; i < endIdxExclusive; i++) {
            subset.add(instances.instance(i));
        }
        return subset;
    }

    /** Merge two instance sets.
     * <p>
     * If either argument is {@code null}, the other argument is returned as-is (not copied).
     * Otherwise a new dataset is returned, containing copies of all instances of
     * {@code instances1} followed by copies of all instances of {@code instances2}.
     * Compatibility is checked only against the first instance of {@code instances2}
     * (via {@link Instances#checkInstance}); the check is skipped when {@code instances2} is empty.
     * @param instances1 the first instance set (may be {@code null})
     * @param instances2 the second instance set (may be {@code null})
     * @return the merged instance sets
     * @throws IllegalArgumentException if the first instance of {@code instances2} is not
     *         compatible with {@code instances1}
     */
    public static Instances mergeInstances(Instances instances1, Instances instances2) {
        if (instances1 == null)
            return instances2;
        if (instances2 == null)
            return instances1;
        // firstInstance() fails on an empty set, so only check when there is something to check
        if (instances2.numInstances() > 0 && !instances1.checkInstance(instances2.firstInstance()))
            throw new IllegalArgumentException("The instance sets are incompatible.");
        Instances mergedInstances = new Instances(instances1);
        int numToAdd = instances2.numInstances();
        for (int i = 0; i < numToAdd; i++) {
            // Instances.add copies the instance, so instances2 is left untouched
            mergedInstances.add(instances2.instance(i));
        }
        return mergedInstances;
    }

    /**
     * Converts an instance to a feature vector excluding the class attribute.
     * <p>
     * If the instance has no class attribute set ({@code classIndex() == -1}), every
     * attribute value is included. Attribute values appear in attribute order.
     * @param instance The instance.
     * @return A vector representation of the instance excluding the class attribute
     */
    public static double[] instanceToDoubleArray(Instance instance) {
        int classIdx = instance.classIndex();
        double[] values = instance.toDoubleArray();
        double[] vector = new double[(classIdx != -1) ? values.length - 1 : values.length];
        int outIdx = 0;
        for (int attIdx = 0; attIdx < values.length; attIdx++) {
            if (attIdx != classIdx) {
                vector[outIdx++] = values[attIdx];
            }
        }
        return vector;
    }

    /**
     * Converts a set of instances to an array of vectors, one per instance, each excluding
     * the class attribute (see {@link #instanceToDoubleArray(Instance)}).
     * @param instances The set of instances.
     * @return The array of feature vectors.
     */
    public static double[][] instancesToDoubleArrays(Instances instances) {
        double[][] vectors = new double[instances.numInstances()][];
        for (int instIdx = 0; instIdx < instances.numInstances(); instIdx++) {
            vectors[instIdx] = instanceToDoubleArray(instances.instance(instIdx));
        }
        return vectors;
    }

    /** Uses the given model to predict the classes of the data.
     * <p>
     * This is best-effort: if the model throws while classifying an instance, or returns
     * NaN (no prediction), the prediction for that instance is {@code -1}.
     * @param model the trained classifier
     * @param data the instances to classify
     * @return An array of the class predictions (class value indices, or -1 on failure).
     */
    public static int[] predictClasses(Classifier model, Instances data) {
        int numInstances = data.numInstances();
        int[] predictions = new int[numInstances];
        for (int instIdx = 0; instIdx < numInstances; instIdx++) {
            try {
                double prediction = model.classifyInstance(data.instance(instIdx));
                // (int) NaN would silently become class 0, so treat NaN as a failed prediction
                predictions[instIdx] = Double.isNaN(prediction) ? -1 : (int) prediction;
            } catch (Exception e) {
                // deliberately best-effort: a failure for one instance does not abort the rest
                predictions[instIdx] = -1;
            }
        }
        return predictions;
    }

    /** Gets the class labels for a set of instances.
     * <p>
     * Each label is the instance's class value cast to {@code int}. Note that a missing
     * class value (NaN) casts to 0.
     * @param data the instances (must have a class attribute set)
     * @return a vector of the class labels for the data set, with one entry per instance
     */
    public static int[] getLabels(Instances data) {
        int[] classLabels = new int[data.numInstances()];
        for (int instIdx = 0; instIdx < classLabels.length; instIdx++) {
            classLabels[instIdx] = (int) data.instance(instIdx).classValue();
        }
        return classLabels;
    }

    /** Gets the class values for a set of instances.
     * @param data the instances (must have a class attribute set)
     * @return a vector of the class values for the data set, with one entry per instance
     */
    public static double[] getClassValues(Instances data) {
        double[] classValues = new double[data.numInstances()];
        for (int instIdx = 0; instIdx < classValues.length; instIdx++) {
            classValues[instIdx] = data.instance(instIdx).classValue();
        }
        return classValues;
    }

    /*
       public static SummaryStatistics getPredictionStats(Instances data, int[] predictedClasses) {
          SummaryStatistics stats = new SummaryStatisticsImpl();
          for (int instIdx = 0; instIdx < data.numInstances(); instIdx++) {
     Instance instance = data.instance(instIdx);
     if (instance.classValue() == (double) predictedClasses[instIdx])
        stats.addValue(1);
     else
        stats.addValue(0);
          }
          return stats;
       }
    */

    /** Converts the instances in the given dataset to binary, setting the specified labels to positive.
     * Note this method is destructive to data, directly modifying its contents.
     * <p>
     * The old nominal class attribute is replaced (at the same index) by a new nominal class
     * attribute named "class" with labels [Y, N]. Instances whose old class equals
     * {@code positiveClassValue} become Y, and all others become N. Instances with a missing
     * class keep a missing class. The relation name gets {@code "-" + positiveClassValue}
     * appended.
     * @param data the multiclass dataset to be converted to binary.
     * @param positiveClassValue the class value to treat as positive.
     * @throws IllegalArgumentException if the class attribute is not nominal
     */
    public static void convertMulticlassToBinary(Instances data, String positiveClassValue) {

        // ensure that data is nominal
        if (!data.classAttribute().isNominal())
            throw new IllegalArgumentException("Instances must have a nominal class.");

        // create the new class attribute with the labels [Y,N] (indices 0 and 1 respectively)
        FastVector newClasses = new FastVector(2);
        newClasses.addElement("Y");
        newClasses.addElement("N");
        // Some Weka versions reject duplicate attribute names on insertion, and the existing
        // class attribute is often itself called "class", so insert under a free name first.
        String newClassAttName = unusedAttributeName(data, "class");
        Attribute newClassAttribute = new Attribute(newClassAttName, newClasses);

        // insert the new class attribute at the old class's position; its values start out missing
        int newClassAttIdx = data.classIndex();
        data.insertAttributeAt(newClassAttribute, newClassAttIdx);
        // the insertion shifted the old class attribute (and the class index) one position right
        int classAttIdx = data.classIndex();

        int numInstances = data.numInstances();
        for (int instIdx = 0; instIdx < numInstances; instIdx++) {
            Instance inst = data.instance(instIdx);
            if (inst.isMissing(classAttIdx)) {
                // leave the new class missing rather than deriving a label from a missing value
                continue;
            }
            if (inst.stringValue(classAttIdx).equals(positiveClassValue)) {
                inst.setValue(newClassAttIdx, 0); // set it to the first class, which will be Y
            } else {
                inst.setValue(newClassAttIdx, 1); // set it to the second class, which will be N
            }
        }

        // switch the class index to the new class and delete the old class
        // (the class attribute itself cannot be deleted, so the switch must come first)
        data.setClassIndex(newClassAttIdx);
        data.deleteAttributeAt(classAttIdx);

        // now that the old class attribute is gone, restore the conventional name if it is free
        if (!newClassAttName.equals("class") && data.attribute("class") == null) {
            data.renameAttribute(newClassAttIdx, "class");
        }

        // alter the dataset name
        data.setRelationName(data.relationName() + "-" + positiveClassValue);
    }

    /** Returns {@code baseName}, extended with underscores as needed, so that no attribute
     * of {@code data} currently uses it. */
    private static String unusedAttributeName(Instances data, String baseName) {
        String name = baseName;
        while (data.attribute(name) != null) {
            name += "_";
        }
        return name;
    }

    /** Determines whether a data set has equal class priors.
     * <p>
     * Counts the instances per class (ignoring instances whose class is missing) and compares
     * the counts to their mean using {@code MathUtils.rmse}. The priors are considered equal
     * when that value is at most {@code ceil(0.025 * meanCount)}.
     * @param data the instances (must have a class attribute set)
     * @return whether the data set has equal class priors
     */
    public static boolean equalClassPriors(Instances data) {
        double[] counts = new double[data.numClasses()];
        int numInstances = data.numInstances();
        for (int i = 0; i < numInstances; i++) {
            Instance inst = data.instance(i);
            // Math.round(NaN) is 0, so a missing class would otherwise be counted as the first class
            if (inst.classIsMissing())
                continue;
            int classValueIdx = (int) Math.round(inst.classValue());
            counts[classValueIdx]++;
        }

        // compute the mean count per class
        double meanCount = MathUtils.sum(counts) / counts.length;
        double[] meanArray = new double[counts.length];
        for (int i = 0; i < meanArray.length; i++) {
            meanArray[i] = meanCount;
        }

        // compute the rmse of the class counts against the mean count
        double rmse = MathUtils.rmse(meanArray, counts);

        // allow a deviation of up to 2.5% of the mean class count (rounded up)
        double deviationAllowed = Math.ceil(0.025 * meanCount);

        return rmse <= deviationAllowed;
    }

    /** Gets the weights of each instance in a dataset as an array.
     * @param data the dataset of instances
     * @return the weights of the instances as an array.
     */
    public static double[] getWeights(Instances data) {
        int numInstances = data.numInstances();
        double[] weights = new double[numInstances];
        for (int instIdx = 0; instIdx < numInstances; instIdx++) {
            weights[instIdx] = data.instance(instIdx).weight();
        }
        return weights;
    }

    /** Defines the format of the SVMLight labels */
    public enum SVMLightLabelFormat {
        /** binary labels: class value 0 becomes -1, other class values become 1, a missing class becomes 0 */
        CLASSIFICATION,
        /** the raw class value is written as the target */
        REGRESSION
    }

    /** Converts a set of instances to svm-light format.
     * <p>
     * Each instance becomes one line: the label, then {@code index:value} pairs for every
     * non-class attribute (the index is the 1-based attribute index), then {@code "# "} followed
     * by the instance's 0-based position in {@code data}. Lines end with the platform line separator.
     * @param data the weka instances
     * @param labelFormat how the class value is written as the svm-light target
     * @return the weka instances in svm-light format
     * @throws IllegalArgumentException if {@code labelFormat} is CLASSIFICATION and the data
     *         does not have exactly two classes
     */
    public static String arffToSVMLight(Instances data, SVMLightLabelFormat labelFormat) {

        if (labelFormat == SVMLightLabelFormat.CLASSIFICATION && data.numClasses() != 2) {
            throw new IllegalArgumentException(
                    "SVM-light classification label format requires that the data contain only two classes.");
        }

        String endline = System.getProperty("line.separator");
        StringBuilder sb = new StringBuilder();

        int numInstances = data.numInstances();
        for (int instIdx = 0; instIdx < numInstances; instIdx++) {
            appendSVMLightLabelAndFeatures(sb, data.instance(instIdx), labelFormat);
            // append the instance info string
            sb.append("# ").append(instIdx).append(endline);
        }

        return sb.toString();
    }

    /** Converts a single instance to svm-light format.
     * <p>
     * Produces one line: the label, then {@code index:value} pairs for every non-class
     * attribute (the index is the 1-based attribute index), then {@code "#"} and the platform
     * line separator.
     * @param data the weka instance
     * @param labelFormat how the class value is written as the svm-light target
     * @return the weka instance in svm-light format
     * @throws IllegalArgumentException if {@code labelFormat} is CLASSIFICATION and the data
     *         does not have exactly two classes
     */
    public static String arffToSVMLight(Instance data, SVMLightLabelFormat labelFormat) {

        if (labelFormat == SVMLightLabelFormat.CLASSIFICATION && data.numClasses() != 2) {
            throw new IllegalArgumentException(
                    "SVM-light classification label format requires that the data contain only two classes.");
        }

        StringBuilder sb = new StringBuilder();
        appendSVMLightLabelAndFeatures(sb, data, labelFormat);
        // append the (empty) instance info string
        sb.append("#").append(System.getProperty("line.separator"));
        return sb.toString();
    }

    /** Appends the svm-light label, a space, and each non-class {@code index:value } pair of
     * {@code inst} to {@code sb}. Missing feature values are written as {@code NaN}. */
    private static void appendSVMLightLabelAndFeatures(StringBuilder sb, Instance inst,
            SVMLightLabelFormat labelFormat) {
        sb.append(svmLightLabel(inst, labelFormat)).append(' ');
        int numAttributes = inst.numAttributes();
        int classAttIdx = inst.classIndex();
        for (int attIdx = 0; attIdx < numAttributes; attIdx++) {
            // skip the class attribute
            if (attIdx == classAttIdx)
                continue;
            sb.append(attIdx + 1).append(':').append(inst.value(attIdx)).append(' ');
        }
    }

    /** Formats the class value of {@code inst} as an svm-light target. */
    private static String svmLightLabel(Instance inst, SVMLightLabelFormat labelFormat) {
        if (labelFormat == SVMLightLabelFormat.CLASSIFICATION) {
            // svm-light uses target 0 for unlabeled examples; a missing (NaN) class would
            // otherwise fail the == 0 test and be written as a positive example
            if (inst.classIsMissing())
                return "0";
            return (inst.classValue() == 0) ? "-1" : "1";
        }
        return String.valueOf(inst.classValue());
    }
}