ml.dataprocess.CorrelationAttributeEval.java Source code

Java tutorial

Introduction

Here is the source code for ml.dataprocess.CorrelationAttributeEval.java

Source

package ml.dataprocess;
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    CorrelationAttributeEval.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;

import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.AttributeEvaluator;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 <!-- globalinfo-start -->
 * CorrelationAttributeEval :<br/>
 * <br/>
 * Evaluates the worth of an attribute by measuring the correlation (Pearson's) between it and the class.<br/>
 * <br/>
 * Nominal attributes are considered on a value by value basis by treating each value as an indicator. An overall correlation for a nominal attribute is arrived at via a weighted average.<br/>
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -D
 *  Output detailed info for nominal attributes</pre>
 * 
 <!-- options-end -->
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 9049 $
 */
public class CorrelationAttributeEval extends ASEvaluation implements AttributeEvaluator, OptionHandler {

    /** For serialization */
    private static final long serialVersionUID = -4931946995055872438L;

    /** The correlation for each attribute */
    protected double[] m_correlations;

    /** Whether to output detailed (per value) correlation for nominal attributes */
    protected boolean m_detailedOutput = false;

    /** Holds the detailed output info */
    protected StringBuffer m_detailedOutputBuff;

    /**
     * Returns a string describing this attribute evaluator
     * 
     * @return a description of the evaluator suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String globalInfo() {
        return "CorrelationAttributeEval :\n\nEvaluates the worth of an attribute "
                + "by measuring the correlation (Pearson's) between it and the class.\n\n"
                + "Nominal attributes are considered on a value by "
                + "value basis by treating each value as an indicator. An overall "
                + "correlation for a nominal attribute is arrived at via a weighted average.\n";
    }

    /**
     * Returns the capabilities of this evaluator.
     * 
     * @return the capabilities of this evaluator
     * @see Capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.DATE_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.NOMINAL_CLASS);
        result.enable(Capability.MISSING_CLASS_VALUES);

        return result;
    }

    /**
     * Returns an enumeration describing the available options.
     * 
     * @return an enumeration of all the available options.
     **/
    @Override
    public Enumeration listOptions() {
        // TODO Auto-generated method stub
        Vector<Option> newVector = new Vector<Option>();

        newVector.addElement(new Option("\tOutput detailed info for nominal attributes", "D", 0, "-D"));

        return newVector.elements();
    }

    /**
     * Parses a given list of options. <p/>
     * 
     <!-- options-start -->
     * Valid options are: <p/>
     * 
     * <pre> -D
     *  Output detailed info for nominal attributes</pre>
     * 
     <!-- options-end -->
     * 
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] options) throws Exception {

        setOutputDetailedInfo(Utils.getFlag('D', options));
    }

    /**
     * Gets the current settings of WrapperSubsetEval.
     * 
     * @return an array of strings suitable for passing to setOptions()
     */
    @Override
    public String[] getOptions() {
        String[] options = new String[1];

        if (getOutputDetailedInfo()) {
            options[0] = "-D";
        } else {
            options[0] = "";
        }

        return options;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String outputDetailedInfoTipText() {
        return "Output per value correlation for nominal attributes";
    }

    /**
     * Set whether to output per-value correlation for nominal attributes
     * 
     * @param d true if detailed (per-value) correlation is to be output for
     *          nominal attributes
     */
    public void setOutputDetailedInfo(boolean d) {
        m_detailedOutput = d;
    }

    /**
     * Get whether to output per-value correlation for nominal attributes
     * 
     * @return true if detailed (per-value) correlation is to be output for
     *         nominal attributes
     */
    public boolean getOutputDetailedInfo() {
        return m_detailedOutput;
    }

    /**
     * Evaluates an individual attribute by measuring the correlation (Pearson's)
     * between it and the class. Nominal attributes are considered on a value by
     * value basis by treating each value as an indicator. An overall correlation
     * for a nominal attribute is arrived at via a weighted average.
     * 
     * @param attribute the index of the attribute to be evaluated
     * @return the correlation
     * @throws Exception if the attribute could not be evaluated
     */
    @Override
    public double evaluateAttribute(int attribute) throws Exception {
        return m_correlations[attribute];
    }

    /**
     * Describe the attribute evaluator
     * 
     * @return a description of the attribute evaluator as a String
     */
    @Override
    public String toString() {
        StringBuffer buff = new StringBuffer();

        if (m_correlations == null) {
            buff.append("Correlation attribute evaluator has not been built yet.");
        } else {
            buff.append("\tCorrelation Ranking Filter");

            if (m_detailedOutput && m_detailedOutputBuff.length() > 0) {
                buff.append("\n\tDetailed output for nominal attributes");
                buff.append(m_detailedOutputBuff);
            }
        }

        return buff.toString();
    }

    /**
     * Initializes an information gain attribute evaluator. Replaces missing
     * values with means/modes; Deletes instances with missing class values.
     * 
     * @param data set of instances serving as training data
     * @throws Exception if the evaluator has not been generated successfully
     */
    @Override
    public void buildEvaluator(Instances data) throws Exception {
        data = new Instances(data);
        data.deleteWithMissingClass();

        ReplaceMissingValues rmv = new ReplaceMissingValues();
        rmv.setInputFormat(data);
        data = Filter.useFilter(data, rmv);

        int numClasses = data.classAttribute().numValues();
        int classIndex = data.classIndex();
        int numInstances = data.numInstances();
        m_correlations = new double[data.numAttributes()];
        /*
         * boolean hasNominals = false; boolean hasNumerics = false;
         */
        List<Integer> numericIndexes = new ArrayList<Integer>();
        List<Integer> nominalIndexes = new ArrayList<Integer>();
        if (m_detailedOutput) {
            m_detailedOutputBuff = new StringBuffer();
        }

        // TODO for instance weights (folded into computing weighted correlations)
        // add another dimension just before the last [2] (0 for 0/1 binary vector
        // and
        // 1 for corresponding instance weights for the 1's)
        double[][][] nomAtts = new double[data.numAttributes()][][];
        for (int i = 0; i < data.numAttributes(); i++) {
            if (data.attribute(i).isNominal() && i != classIndex) {
                nomAtts[i] = new double[data.attribute(i).numValues()][data.numInstances()];
                Arrays.fill(nomAtts[i][0], 1.0); // set zero index for this att to all
                                                 // 1's
                nominalIndexes.add(i);
            } else if (data.attribute(i).isNumeric() && i != classIndex) {
                numericIndexes.add(i);
            }
        }

        // do the nominal attributes
        if (nominalIndexes.size() > 0) {
            for (int i = 0; i < data.numInstances(); i++) {
                Instance current = data.instance(i);
                for (int j = 0; j < current.numValues(); j++) {
                    if (current.attribute(current.index(j)).isNominal() && current.index(j) != classIndex) {
                        // Will need to check for zero in case this isn't a sparse
                        // instance (unless we add 1 and subtract 1)
                        nomAtts[current.index(j)][(int) current.valueSparse(j)][i] += 1;
                        nomAtts[current.index(j)][0][i] -= 1;
                    }
                }
            }
        }

        if (data.classAttribute().isNumeric()) {
            double[] classVals = data.attributeToDoubleArray(classIndex);

            // do the numeric attributes
            for (Integer i : numericIndexes) {
                double[] numAttVals = data.attributeToDoubleArray(i);
                m_correlations[i] = Utils.correlation(numAttVals, classVals, numAttVals.length);

                if (m_correlations[i] == 1.0) {
                    // check for zero variance (useless numeric attribute)
                    if (Utils.variance(numAttVals) == 0) {
                        m_correlations[i] = 0;
                    }
                }
            }

            // do the nominal attributes
            if (nominalIndexes.size() > 0) {

                // now compute the correlations for the binarized nominal attributes
                for (Integer i : nominalIndexes) {
                    double sum = 0;
                    double corr = 0;
                    double sumCorr = 0;
                    double sumForValue = 0;

                    if (m_detailedOutput) {
                        m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
                    }

                    for (int j = 0; j < data.attribute(i).numValues(); j++) {
                        sumForValue = Utils.sum(nomAtts[i][j]);
                        corr = Utils.correlation(nomAtts[i][j], classVals, classVals.length);

                        // useless attribute - all instances have the same value
                        if (sumForValue == numInstances || sumForValue == 0) {
                            corr = 0;
                        }
                        if (corr < 0.0) {
                            corr = -corr;
                        }
                        sumCorr += sumForValue * corr;
                        sum += sumForValue;

                        if (m_detailedOutput) {
                            m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
                            m_detailedOutputBuff.append(Utils.doubleToString(corr, 6));
                        }
                    }
                    m_correlations[i] = (sum > 0) ? sumCorr / sum : 0;
                }
            }
        } else {
            // class is nominal
            // TODO extra dimension for storing instance weights too
            double[][] binarizedClasses = new double[data.classAttribute().numValues()][data.numInstances()];

            // this is equal to the number of instances for all inst weights = 1
            double[] classValCounts = new double[data.classAttribute().numValues()];

            for (int i = 0; i < data.numInstances(); i++) {
                Instance current = data.instance(i);
                binarizedClasses[(int) current.classValue()][i] = 1;
            }
            for (int i = 0; i < data.classAttribute().numValues(); i++) {
                classValCounts[i] = Utils.sum(binarizedClasses[i]);
            }

            double sumClass = Utils.sum(classValCounts);

            // do numeric attributes first
            if (numericIndexes.size() > 0) {
                for (Integer i : numericIndexes) {
                    double[] numAttVals = data.attributeToDoubleArray(i);
                    double corr = 0;
                    double sumCorr = 0;

                    for (int j = 0; j < data.classAttribute().numValues(); j++) {
                        corr = Utils.correlation(numAttVals, binarizedClasses[j], numAttVals.length);
                        if (corr < 0.0) {
                            corr = -corr;
                        }

                        if (corr == 1.0) {
                            // check for zero variance (useless numeric attribute)
                            if (Utils.variance(numAttVals) == 0) {
                                corr = 0;
                            }
                        }

                        sumCorr += classValCounts[j] * corr;
                    }
                    m_correlations[i] = sumCorr / sumClass;
                }
            }

            if (nominalIndexes.size() > 0) {
                for (Integer i : nominalIndexes) {
                    if (m_detailedOutput) {
                        m_detailedOutputBuff.append("\n\n").append(data.attribute(i).name());
                    }

                    double sumForAtt = 0;
                    double corrForAtt = 0;
                    for (int j = 0; j < data.attribute(i).numValues(); j++) {
                        double sumForValue = Utils.sum(nomAtts[i][j]);
                        double corr = 0;
                        double sumCorr = 0;
                        double avgCorrForValue = 0;

                        sumForAtt += sumForValue;
                        for (int k = 0; k < numClasses; k++) {

                            // corr between value j and class k
                            corr = Utils.correlation(nomAtts[i][j], binarizedClasses[k],
                                    binarizedClasses[k].length);

                            // useless attribute - all instances have the same value
                            if (sumForValue == numInstances || sumForValue == 0) {
                                corr = 0;
                            }
                            if (corr < 0.0) {
                                corr = -corr;
                            }
                            sumCorr += classValCounts[k] * corr;
                        }
                        avgCorrForValue = sumCorr / sumClass;
                        corrForAtt += sumForValue * avgCorrForValue;

                        if (m_detailedOutput) {
                            m_detailedOutputBuff.append("\n\t").append(data.attribute(i).value(j)).append(": ");
                            m_detailedOutputBuff.append(Utils.doubleToString(avgCorrForValue, 6));
                        }
                    }

                    // the weighted average corr for att i as
                    // a whole (wighted by value frequencies)
                    m_correlations[i] = (sumForAtt > 0) ? corrForAtt / sumForAtt : 0;
                }
            }
        }

        if (m_detailedOutputBuff != null && m_detailedOutputBuff.length() > 0) {
            m_detailedOutputBuff.append("\n");
        }
    }

    /**
     * Returns the revision string.
     * 
     * @return            the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9049 $");
    }

    /**
     * Main method for testing this class.
     * 
     * @param args the options
     */
    public static void main(String[] args) {
        runEvaluator(new CorrelationAttributeEval(), args);
    }
}