ca.uottawa.balie.WekaLearner.java Source code

Java tutorial

Introduction

Here is the source code for ca.uottawa.balie.WekaLearner.java

Source

/*
 * Balie - BAseLine Information Extraction
 * Copyright (C) 2004-2007  David Nadeau
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

/*
 * Created on Apr 10, 2004
 */
package ca.uottawa.balie;

import java.io.Serializable;
import java.util.Random;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.clusterers.Clusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Methods to create, train and test a classification algorithm, or to build
 * and query a clustering algorithm, on top of the Weka library.
 * <p>
 * A learner owns a list of Weka attributes whose last element is the nominal
 * class attribute named "Class", plus a training set and a testing set that
 * share that attribute list.
 * <p>
 * Most failures raised by Weka while training or classifying are reported on
 * standard output and do not propagate to the caller; see each method for the
 * value returned in that case.
 * 
 * @author David Nadeau (pythonner@gmail.com)
 */
public class WekaLearner implements Serializable {

    private static final long serialVersionUID = 1L;

    // Local Constant
    private static final boolean DEBUG = false;
    private static final String CLASS_LABEL = "Class";

    // Instance state. Names and types are part of the serialized form: do not rename.
    private String[] m_AttributeLabels;
    private String[] m_ClassAttributes;
    private FastVector m_WekaAttributes;
    private Instances m_TrainingSet;
    private Instances m_TestingSet;
    private Classifier m_Scheme;
    private Clusterer m_ClusterScheme;
    private double[][] m_ConfusionMatrix;
    private boolean m_bDoubleOnly;

    /**
     * Creates a new classification algorithm.
     * 
     * @param pi_Attributes         Array of attributes
     * @param pi_ClassAttributes   Class attribute (the list of possible class labels)
     * @throws Error if one of the attributes is labeled "Class", which is reserved
     *               for the class attribute appended by this constructor
     */
    public WekaLearner(WekaAttribute[] pi_Attributes, String[] pi_ClassAttributes) {
        m_ClassAttributes = pi_ClassAttributes;
        m_AttributeLabels = new String[pi_Attributes.length];

        // Create the Weka Attributes (one extra slot for the class attribute)
        m_WekaAttributes = new FastVector(pi_Attributes.length + 1);
        for (int i = 0; i != pi_Attributes.length; ++i) {
            if (pi_Attributes[i].Label().equals(CLASS_LABEL)) {
                throw new Error("Attribute cannot be named \"Class\" (reserved)");
            }
            m_AttributeLabels[i] = pi_Attributes[i].Label();

            Attribute anAttribute;
            if (pi_Attributes[i].IsNumeric()) {
                anAttribute = new Attribute(pi_Attributes[i].Label());
            } else {
                anAttribute = new Attribute(pi_Attributes[i].Label(), pi_Attributes[i].Values());
            }
            m_WekaAttributes.addElement(anAttribute);
        }

        // Create the class attribute, always stored last
        FastVector fvClass = new FastVector(pi_ClassAttributes.length);
        for (int i = 0; i != pi_ClassAttributes.length; ++i) {
            fvClass.addElement(pi_ClassAttributes[i]);
        }
        Attribute aClass = new Attribute(CLASS_LABEL, fvClass);
        m_WekaAttributes.addElement(aClass);

        m_TrainingSet = new Instances("", m_WekaAttributes, 0);
        m_TrainingSet.setClassIndex(aClass.index());

        m_TestingSet = new Instances("", m_WekaAttributes, 0);
        m_TestingSet.setClassIndex(aClass.index());

        // create an empty confusion matrix.. will be populated at evaluation time
        m_ConfusionMatrix = new double[0][0];
        m_bDoubleOnly = false;
    }

    /**
     * Creates a learner from already built Weka structures.
     * The given training set is used as is; an empty testing set is created
     * whose class is the last attribute of the given attribute list.
     * 
     * @param attrsMerged     Weka attributes (the last one is used as the class attribute)
     * @param attrlblMerged   Labels of the attributes
     * @param classList       List of possible class labels
     * @param trainMerged     Training set
     */
    public WekaLearner(FastVector attrsMerged, String[] attrlblMerged, String[] classList, Instances trainMerged) {
        m_AttributeLabels = attrlblMerged;
        m_ClassAttributes = classList;
        m_WekaAttributes = attrsMerged;
        m_TrainingSet = trainMerged;
        m_TestingSet = new Instances("", m_WekaAttributes, 0);
        m_TestingSet.setClassIndex(m_WekaAttributes.size() - 1);

        // Same initial state as the other constructor, so that ConfusionMatrix()
        // never returns null before TestModel() is called.
        m_ConfusionMatrix = new double[0][0];
        m_bDoubleOnly = false;
    }

    /**
     * Adds an instance in the train set.
     * 
     * @param pi_Instance   The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @param pi_Class      The class of this instance
     * @see WekaAttribute
     */
    public void AddTrainInstance(Object[] pi_Instance, String pi_Class) {
        m_TrainingSet.add(CreateInstance(pi_Instance, pi_Class));
    }

    /**
     * Adds an instance whose class is missing in the train set.
     * 
     * @param pi_Instance   The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @see #CreateUnlabeledInstance(Object[])
     */
    public void AddUnlabeledTrainInstance(Object[] pi_Instance) {
        m_TrainingSet.add(CreateUnlabeledInstance(pi_Instance));
    }

    /**
     * Adds an instance in the test set.
     * 
     * @param pi_Instance    The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @param pi_Class       The class of this instance
     * @see WekaAttribute
     */
    public void AddTestInstance(Object[] pi_Instance, String pi_Class) {
        m_TestingSet.add(CreateInstance(pi_Instance, pi_Class));
    }

    /**
     * Declares that labeled instances only contain {@link Double} values, which
     * lets them be built directly from an array of doubles (mega-optimization).
     * When the flag is set, adding a labeled instance holding anything else than
     * Double values fails with a ClassCastException.
     * 
     * @param pi_Flag   true to enable the double-only fast path
     */
    public void SetDoubleOnly(boolean pi_Flag) {
        m_bDoubleOnly = pi_Flag;
    }

    /**
     * Builds a labeled instance, using the fast path when the learner is in double-only mode.
     */
    private Instance CreateInstance(Object[] pi_Instance, String pi_Class) {
        if (m_bDoubleOnly) {
            return CreateDoubleInstance(pi_Instance, pi_Class);
        }
        return CreateMixedInstance(pi_Instance, pi_Class);
    }

    /**
     * Builds a labeled instance from values that can mix Double and String objects.
     */
    private Instance CreateMixedInstance(Object[] pi_Instance, String pi_Class) {
        Instance inst = CreateInstance(pi_Instance);
        inst.setValue((Attribute) m_WekaAttributes.lastElement(), pi_Class);
        return inst;
    }

    /**
     * Builds a labeled instance from values that are all Double objects.
     */
    private Instance CreateDoubleInstance(Object[] pi_Instance, String pi_Class) {
        double[] instval = new double[pi_Instance.length];
        for (int i = 0; i != pi_Instance.length; ++i) {
            instval[i] = ((Double) pi_Instance[i]).doubleValue();
        }
        Instance inst = new Instance(1, instval);
        // Append one slot after the feature values to hold the class
        inst.insertAttributeAt(inst.numAttributes());
        inst.setValue((Attribute) m_WekaAttributes.lastElement(), pi_Class);
        return inst;
    }

    /**
     * Builds an instance whose class value is explicitly set to missing.
     * 
     * @param pi_Instance   The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @return the Weka instance (not attached to any dataset)
     */
    public Instance CreateUnlabeledInstance(Object[] pi_Instance) {
        Instance inst = CreateInstance(pi_Instance);
        inst.setMissing((Attribute) m_WekaAttributes.lastElement());
        return inst;
    }

    /**
     * Builds an instance with one slot per given value plus one slot for the
     * class, and fills the value slots. The class slot is not assigned here.
     */
    private Instance CreateInstance(Object[] pi_Instance) {
        Instance inst = new Instance(pi_Instance.length + 1);
        FillAttributeValues(inst, pi_Instance);
        return inst;
    }

    /**
     * Copies the given values in the instance, value i going to attribute i.
     * Double objects are stored as numbers and String objects as nominal values.
     * A String that Weka refuses for its attribute is stored as a missing value.
     * Objects of any other type (including null) are skipped.
     */
    private void FillAttributeValues(Instance pi_Target, Object[] pi_Values) {
        for (int i = 0; i != pi_Values.length; ++i) {
            Object value = pi_Values[i];
            if (value instanceof Double) {
                pi_Target.setValue((Attribute) m_WekaAttributes.elementAt(i), ((Double) value).doubleValue());
            } else if (value instanceof String) {
                Attribute attribute = (Attribute) m_WekaAttributes.elementAt(i);
                try {
                    pi_Target.setValue(attribute, (String) value);
                } catch (Exception e) {
                    if (DEBUG) {
                        DebugInfo.Out("Unknow attribute for learner: " + (String) value);
                    }
                    pi_Target.setMissing(attribute);
                }
            }
        }
    }

    /**
     * Creates the classification model by learning from the training set.
     * A failure raised by Weka is reported on standard output and not propagated.
     * 
     * @param pi_Classifier   The classification algorithm (from the Weka library, ex.: NaiveBayes, J48, ...)
     */
    public void CreateModel(Classifier pi_Classifier) {
        if (DEBUG) {
            DebugInfo.Out("Training on " + m_TrainingSet.numInstances() + " instances");
        }
        try {
            m_Scheme = pi_Classifier;
            m_Scheme.buildClassifier(m_TrainingSet);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * Creates the cluster model by learning from the training set.
     * A failure raised by Weka is reported on standard output and not propagated.
     * 
     * @param pi_Clusterer   The clusterization algorithm (from the Weka library, ex.: k-Mean, ...)
     */
    public void CreateModel(Clusterer pi_Clusterer) {
        if (DEBUG) {
            DebugInfo.Out("Clustering " + m_TrainingSet.numInstances() + " instances");
        }
        try {
            m_ClusterScheme = pi_Clusterer;
            m_ClusterScheme.buildClusterer(m_TrainingSet);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * Approximate training set error, using a 10-fold cross-validation of the
     * classifier on the training set. Folds are drawn with an unseeded random
     * generator, so results can vary between calls.
     * 
     * @return evaluation module from which many types of errors are exposed (e.g.: mean absolute error);
     *         null if the evaluation module could not be created
     */
    public Evaluation EstimateConfidence() {
        Evaluation evaluation = null;
        try {
            evaluation = new Evaluation(m_TrainingSet);
            evaluation.crossValidateModel(m_Scheme, m_TrainingSet, 10, new Random());
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
        // which error is the best? depends on the application.
        return evaluation;
    }

    /**
     * Test the learned model on the testing set and refresh the confusion matrix.
     * A failure while evaluating the model is reported on standard output and the
     * summary is built from whatever was evaluated so far.
     * 
     * @return A summary string of the performance of the classifier
     * @throws IllegalStateException if the evaluation module cannot be created from the training set
     */
    public String TestModel() {
        if (DEBUG) {
            DebugInfo.Out("Testing on " + m_TestingSet.numInstances() + " instances");
        }

        Evaluation evaluation;
        try {
            evaluation = new Evaluation(m_TrainingSet);
        } catch (Exception e) {
            // Without an evaluation module there is nothing to summarize
            throw new IllegalStateException("Cannot create the evaluation module from the training set", e);
        }

        try {
            evaluation.evaluateModel(m_Scheme, m_TestingSet);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }

        StringBuilder sbSummary = new StringBuilder(evaluation.toSummaryString());
        sbSummary.append("\n\nConfusion Matrix: \n\n");
        m_ConfusionMatrix = evaluation.confusionMatrix();
        for (int i = 0; i != m_ConfusionMatrix.length; ++i) {
            for (int j = 0; j != m_ConfusionMatrix[i].length; ++j) {
                sbSummary.append(String.valueOf(m_ConfusionMatrix[i][j])).append("\t");
            }
            sbSummary.append("\n");
        }
        return sbSummary.toString();
    }

    /**
     * Gets the confusion matrix that plots precision and recall for each class.
     * The matrix is empty until {@link #TestModel()} is called.
     * 
     * @return 2-dimension array (X=Ground truth, Y=Guesses)
     */
    public double[][] ConfusionMatrix() {
        return m_ConfusionMatrix;
    }

    /**
     * Classify an unseen instance using the learned classifier.
     * Output the most probable class.
     * 
     * @param pi_Instance The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @return The index of the most probable class; 0 when the classifier fails
     *         (the failure is reported on standard output)
     * @see WekaAttribute
     */
    public double Classify(Object[] pi_Instance) {
        Instance inst = CreateInstance(pi_Instance);
        inst.setDataset(m_TrainingSet);
        double out = 0;
        try {
            out = m_Scheme.classifyInstance(inst);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
        return out;
    }

    /**
     * Cluster an unseen instance using the learned clusterizer.
     * Output the most probable cluster.
     * 
     * @param pi_Instance The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @return The index of the most probable cluster; 0 when the clusterer fails
     *         (the failure is reported on standard output)
     * @see WekaAttribute
     */
    public double Cluster(Object[] pi_Instance) {
        Instance inst = CreateInstance(pi_Instance);
        inst.setDataset(m_TrainingSet);
        double out = 0;
        try {
            out = m_ClusterScheme.clusterInstance(inst);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
        return out;
    }

    /**
     * Classify an unseen instance using the learned classifier.
     * Output the likelihood of each class.
     * 
     * @param pi_Instance The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @return An array of probabilities, parallel to the possible classes;
     *         null when the classifier fails (the failure is reported on standard output)
     * @see WekaAttribute
     */
    public double[] GetDistribution(Object[] pi_Instance) {
        Instance inst = CreateInstance(pi_Instance);
        inst.setDataset(m_TrainingSet);
        double fDistribution[] = null;
        try {
            fDistribution = m_Scheme.distributionForInstance(inst);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
        return fDistribution;
    }

    /**
     * Classify an unseen instance using the learned classifier.
     * Get the likelihood of a given class, normalized by the sum of the
     * likelihoods of all classes.
     * 
     * @param pi_Instance         The instance, an array of objects (can mix numeric and nominal attributes - see {@link WekaAttribute})
     * @param pi_PositiveClassIndex   The index of a class
     * @return the likelihood of the given class; 0.0 when no distribution could be
     *         computed or when the distribution sums to zero
     * @see WekaAttribute
     */
    public double Likelihood(Object[] pi_Instance, int pi_PositiveClassIndex) {
        double fDistribution[] = GetDistribution(pi_Instance);
        if (fDistribution == null) {
            // Classification failed (already reported by GetDistribution)
            return 0.0;
        }

        double fSum = 0.0;
        for (int i = 0; i != fDistribution.length; ++i) {
            fSum += fDistribution[i];
        }
        if (fSum == 0.0) {
            // No probability mass on any class: avoid returning NaN (0/0)
            return 0.0;
        }

        return fDistribution[pi_PositiveClassIndex] / fSum;
    }

    /**
     * Gets the list of attributes.
     * 
     * @return String array of attribute labels (the class attribute is not included
     *         when the learner was built from an array of {@link WekaAttribute})
     */
    public String[] GetAttributeList() {
        return m_AttributeLabels;
    }

    /**
     * Gets the Weka attributes, the class attribute being the last element.
     * 
     * @return the internal attribute vector (not a copy)
     */
    public FastVector GetAttribute() {
        return m_WekaAttributes;
    }

    /**
     * Gets the list of classes.
     * 
     * @return String array
     */
    public String[] GetClassList() {
        return m_ClassAttributes;
    }

    /**
     * Get the training instances
     * @return training instances (the internal set, not a copy)
     */
    protected Instances GetTrainInstances() {
        return m_TrainingSet;
    }

    /**
     * Get the testing instances
     * @return testing instances (the internal set, not a copy)
     */
    protected Instances GetTestInstances() {
        return m_TestingSet;
    }

    /**
     * Get the training instances
     * @return training instances (the internal set, not a copy)
     */
    public Instances GetTrainingSet() {
        return m_TrainingSet;
    }

    /**
     * Get the testing instances
     * @return testing instances (the internal set, not a copy)
     */
    public Instances GetTestingSet() {
        return m_TestingSet;
    }

    /**
     * Get the number of training instances
     * @return training set size
     */
    public int GetTrainingSetSize() {
        return m_TrainingSet.numInstances();
    }

    /**
     * Gets the centroids of the learned clusters.
     * 
     * @return the cluster centroids
     * @throws Error if the current clusterer is not a SimpleKMeans (including when no clusterer was set)
     */
    public Instances GetClusterCentroid() {
        if (m_ClusterScheme instanceof SimpleKMeans) {
            return ((SimpleKMeans) m_ClusterScheme).getClusterCentroids();
        } else {
            throw new Error("Only available for SimpleKMeans clusterer");
        }
    }

    /**
     * Reduces the size of a classifier by deleting the corpora.
     * Only keep the learned model and information on classes and attributes.
     */
    public void Shrink() {
        m_TrainingSet.delete();
        m_TestingSet.delete();
    }

    /**
     * Test routine
     * 
     * @param args
     */
    public static void main(String[] args) {
    }

}