net.sf.bddbddb.order.WekaInterface.java Source code

Java tutorial

Introduction

Here is the source code for net.sf.bddbddb.order.WekaInterface.java

Source

// WekaInterface.java, created Oct 31, 2004 1:17:46 AM by joewhaley
// Copyright (C) 2004 John Whaley <jwhaley@alum.mit.edu>
// Licensed under the terms of the GNU LGPL; see COPYING for details.
package net.sf.bddbddb.order;

import java.util.Collection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;

import jwutil.collections.Pair;
import jwutil.util.Assert;
import net.sf.bddbddb.Attribute;
import net.sf.bddbddb.FindBestDomainOrder;
import net.sf.bddbddb.InferenceRule;
import net.sf.bddbddb.Variable;
import net.sf.bddbddb.order.OrderConstraint.AfterConstraint;
import net.sf.bddbddb.order.OrderConstraint.BeforeConstraint;
import net.sf.bddbddb.order.OrderConstraint.InterleaveConstraint;
import weka.classifiers.Classifier;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/**
 * WekaInterface
 * 
 * @author jwhaley
 * @version $Id: WekaInterface.java 531 2005-04-29 06:39:10Z joewhaley $
 */
public abstract class WekaInterface {

    public static OrderAttribute makeOrderAttribute(OrderConstraint c) {
        return new OrderAttribute(c.a, c.b);
    }

    public static OrderAttribute makeOrderAttribute(Object a, Object b) {
        if (OrderConstraint.compare(a, b)) {
            return new OrderAttribute(a, b);
        } else {
            return new OrderAttribute(b, a);
        }
    }

    public static int INTERLEAVE = 1;

    public static int getType(OrderConstraint oc) {
        if (oc instanceof BeforeConstraint)
            return 0;
        else if (oc instanceof InterleaveConstraint)
            return INTERLEAVE;
        else if (oc instanceof AfterConstraint)
            return 2;
        else
            return -1;
    }

    public static class OrderAttribute extends weka.core.Attribute {
        /**
         * Version ID for serialization.
         */
        private static final long serialVersionUID = 3257291339690555447L;

        Object a, b;

        static FastVector my_nominal_values = new FastVector(3);
        static {
            my_nominal_values.addElement("<");
            my_nominal_values.addElement("~");
            my_nominal_values.addElement(">");
        }

        private OrderAttribute(Object a, Object b) {
            super(a + "," + b, my_nominal_values);
            this.a = a;
            this.b = b;
        }

        public OrderConstraint getConstraint(int k) {
            switch (k) {
            case 0:
                return OrderConstraint.makePrecedenceConstraint(a, b);
            case 1:
                return OrderConstraint.makeInterleaveConstraint(a, b);
            case 2:
                return OrderConstraint.makePrecedenceConstraint(b, a);
            default:
                return null;
            }
        }

        public OrderConstraint getConstraint(weka.core.Instance i) {
            int k = (int) i.value(this);
            return getConstraint(k);
        }

    }

    public static void addAllPairs(FastVector v, Collection c) {
        Collection pairs = new HashSet();
        for (Iterator i = c.iterator(); i.hasNext();) {
            Object a = i.next();
            Iterator j = c.iterator();
            while (j.hasNext() && j.next() != a)
                ;
            while (j.hasNext()) {
                Object b = j.next();
                UnorderedPair pair = new UnorderedPair(a, b);
                if (pairs.contains(pair))
                    continue;
                OrderAttribute oa = makeOrderAttribute(a, b);
                v.addElement(oa);
                pairs.add(pair);
            }
        }
        //     System.out.println(new HashSet(c) + " Size: " + v.size());

    }

    public static Collection /*UnorderedPair*/ generateAllPairs(Collection c) {
        Collection pairs = new HashSet();
        for (Iterator i = c.iterator(); i.hasNext();) {
            Object a = i.next();
            Iterator j = c.iterator();
            while (j.hasNext() && j.next() != a)
                ;
            while (j.hasNext()) {
                Object b = j.next();
                Pair pair = new UnorderedPair(a, b);
                pairs.add(pair);
            }
        }
        return pairs;
    }

    public static FastVector constructVarAttributes(Collection vars) {
        FastVector v = new FastVector();
        addAllPairs(v, vars);
        return v;
    }

    public static FastVector constructAttribAttributes(InferenceRule ir, Collection vars) {
        Collection attribs = new LinkedList();
        for (Iterator i = vars.iterator(); i.hasNext();) {
            Variable v = (Variable) i.next();
            Attribute a = ir.getAttribute(v);
            if (a != null)
                attribs.add(a);
        }
        FastVector v = new FastVector();
        addAllPairs(v, attribs);
        return v;
    }

    public static FastVector constructDomainAttributes(InferenceRule ir, Collection vars) {
        Collection domains = new LinkedList();
        for (Iterator i = vars.iterator(); i.hasNext();) {
            Variable v = (Variable) i.next();
            Attribute a = ir.getAttribute(v);
            if (a != null)
                domains.add(a.getDomain());
        }
        FastVector v = new FastVector();
        addAllPairs(v, domains);
        return v;
    }

    public static weka.core.Attribute makeBucketAttribute(int numClusters) {
        FastVector clusterValues = new FastVector(numClusters);
        for (int i = 0; i < numClusters; ++i)
            clusterValues.addElement(Integer.toString(i));
        return new weka.core.Attribute("costBucket", clusterValues);
    }

    public static Classifier buildClassifier(String cClassName, Instances data) {
        // Build the classifier.
        Classifier classifier = null;
        try {
            long time = System.currentTimeMillis();
            classifier = (Classifier) Class.forName(cClassName).newInstance();
            classifier.buildClassifier(data);
            if (FindBestDomainOrder.TRACE > 1)
                System.out.println("Classifier " + cClassName + " took " + (System.currentTimeMillis() - time)
                        + " ms to build.");
            if (FindBestDomainOrder.TRACE > 2)
                System.out.println(classifier);
        } catch (Exception x) {
            FindBestDomainOrder.out.println(cClassName + ": " + x.getLocalizedMessage());
            return null;
        }
        return classifier;
    }

    public static double leaveOneOutCV(Instances data, String cClassName) {
        return WekaInterface.cvError(data.numInstances(), data, cClassName);
    }

    public static double cvError(int numFolds, Instances data0, String cClassName) {
        if (data0.numInstances() < numFolds)
            return Double.NaN; //more folds than elements
        if (numFolds == 0)
            return Double.NaN; // no folds
        if (data0.numInstances() == 0)
            return 0; //no instances

        Instances data = new Instances(data0);
        //data.randomize(new Random(System.currentTimeMillis()));
        data.stratify(numFolds);
        Assert._assert(data.classAttribute() != null);
        double[] estimates = new double[numFolds];
        for (int i = 0; i < numFolds; ++i) {
            Instances trainData = data.trainCV(numFolds, i);
            Assert._assert(trainData.classAttribute() != null);
            Assert._assert(trainData.numInstances() != 0, "Cannot train classifier on 0 instances.");

            Instances testData = data.testCV(numFolds, i);
            Assert._assert(testData.classAttribute() != null);
            Assert._assert(testData.numInstances() != 0, "Cannot test classifier on 0 instances.");

            int temp = FindBestDomainOrder.TRACE;
            FindBestDomainOrder.TRACE = 0;
            Classifier classifier = buildClassifier(cClassName, trainData);
            FindBestDomainOrder.TRACE = temp;
            int count = testData.numInstances();
            double loss = 0;
            double sum = 0;
            for (Enumeration e = testData.enumerateInstances(); e.hasMoreElements();) {
                Instance instance = (Instance) e.nextElement();
                Assert._assert(instance != null);
                Assert._assert(instance.classAttribute() != null
                        && instance.classAttribute() == trainData.classAttribute());
                try {
                    double testClass = classifier.classifyInstance(instance);
                    double weight = instance.weight();
                    if (testClass != instance.classValue())
                        loss += weight;
                    sum += weight;
                } catch (Exception ex) {
                    FindBestDomainOrder.out.println("Exception while classifying: " + instance + "\n" + ex);
                }
            }
            estimates[i] = 1 - loss / sum;
        }
        double average = 0;
        for (int i = 0; i < numFolds; ++i)
            average += estimates[i];

        return average / numFolds;
    }

    public static TrialInstances binarize(double classValue, TrialInstances data) {
        TrialInstances newInstances = data.infoClone();
        weka.core.Attribute newAttr = makeBucketAttribute(2);
        TrialInstances.setIndex(newAttr, newInstances.classIndex());
        newInstances.setClass(newAttr);
        newInstances.setClassIndex(data.classIndex());
        for (Enumeration e = data.enumerateInstances(); e.hasMoreElements();) {
            TrialInstance instance = (TrialInstance) e.nextElement();
            TrialInstance newInstance = TrialInstance.cloneInstance(instance);
            newInstance.setDataset(newInstances);
            if (instance.classValue() <= classValue) {
                newInstance.setClassValue(0);
            } else {
                newInstance.setClassValue(1);
            }
            newInstances.add(newInstance);
        }
        return newInstances;
    }

    public static class OrderInstance extends Instance {

        /**
         * Version ID for serialization.
         */
        private static final long serialVersionUID = 3258412811553093939L;

        public static OrderInstance construct(Order o, Instances dataSet) {
            return construct(o, dataSet, 1);
        }

        public static OrderInstance construct(Order o, Instances dataSet, double weight) {
            double[] d = new double[dataSet.numAttributes()];
            for (int i = 0; i < d.length; ++i) {
                d[i] = Instance.missingValue();
            }
            for (Iterator i = o.getConstraints().iterator(); i.hasNext();) {
                OrderConstraint oc = (OrderConstraint) i.next();
                // TODO: use a map from Pair to int instead of building String and doing linear search.

                String cName = oc.getFirst() + "," + oc.getSecond();
                OrderAttribute oa = (OrderAttribute) dataSet.attribute(cName);
                if (oa != null) {

                    if (oc.getFirst().equals(oc.getSecond()) && d[oa.index()] == INTERLEAVE)
                        continue;
                    /* TODO should only one type of constraint for
                     * when first == second and they are not interleaved
                     */
                    d[oa.index()] = getType(oc);
                } else {

                    System.out.println("Warning: while building OrderInstance for " + o
                            + " couldn't find constraint " + oc + " in data set");
                    System.out.println("dataset\n: " + dataSet);
                    Assert.UNREACHABLE();
                }
            }
            return new OrderInstance(weight, d, o);
        }

        protected Order o;

        protected OrderInstance(double w, double[] d, Order o) {
            super(w, d);
            this.o = o;
        }

        protected OrderInstance(OrderInstance that) {
            super(that);
            this.o = that.o;
        }

        public Object copy() {
            return new OrderInstance(this);
        }

        public Order getOrder() {
            return o;
        }

    }

}