net.sf.bddbddb.OrderClassifier.java Source code

Java tutorial

Introduction

Here is the source code for net.sf.bddbddb.OrderClassifier.java

Source

// OrderClassifier.java, created Jul 27, 2004 8:57:36 PM 2004 by mcarbin
// Copyright (C) 2004 Michael Carbin <mcarbin@stanford.edu>
// Licensed under the terms of the GNU LGPL; see COPYING for details.
package net.sf.bddbddb;

import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.StringTokenizer;
import java.util.TreeSet;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import jwutil.collections.GenericMultiMap;
import jwutil.collections.MultiMap;
import jwutil.collections.Pair;
import jwutil.util.Assert;
import net.sf.javabdd.BDDDomain;
import weka.classifiers.Classifier;
import weka.classifiers.trees.Id3;
import weka.clusterers.Clusterer;
import weka.clusterers.EM;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.PKIDiscretize;

/**
 * OrderClassifier
 * 
 * @author mcarbin
 * @version $Id: OrderClassifier.java 531 2005-04-29 06:39:10Z joewhaley $
 */
public class OrderClassifier {
    private Instances data = null;
    private Classifier classifier;
    private BDDInferenceRule rule;
    private BDDRelation bottomRelation;
    private List relationAttrPairs;
    private MultiMap domainToAttrPairs;
    private int numAttrs;
    private FastVector attributes;
    private FastVector attrOptions;
    private Filter filter;
    private SortedSet orders;
    private List domainSet;
    private int numClusters;
    public static Set goodClusters;
    private long max_wait_time;
    static {
        goodClusters = new HashSet();
        goodClusters.add("0");
    }

    public OrderClassifier(BDDInferenceRule rule, List domainSet, long max_wait_time) {
        this.rule = rule;
        this.bottomRelation = (BDDRelation) rule.bottom.getRelation();
        this.domainSet = domainSet;
        this.max_wait_time = max_wait_time;
        orders = new TreeSet();
        classifier = new Id3();
        relationAttrPairs = new LinkedList();
        domainToAttrPairs = new GenericMultiMap();
        attributes = computeAttributes();
        attributes.addElement(new Attribute("class"));
        /*      FastVector values = new FastVector(2);
         values.addElement("good");
         values.addElement("bad");
             
         attributes.addElement(new Attribute("Time", values));
         */

        // data = new Instances("Ordering Constraints", attributes, 30);

    }

    public double importance(weka.core.Attribute attribute, String attrValue) {//, String classValue){

        int count = 0;
        int goodCount = 0, badCount = 0;
        List newInstances = new LinkedList();

        for (Iterator it = orders.iterator(); it.hasNext();) {
            Instance instance = (Instance) it.next();
            if (//!instance.stringValue(instance.classIndex()).equals(classValue) ||
            !instance.stringValue(attribute).equals(attrValue))
                continue;

            if (goodClusters.contains(instance.stringValue(instance.classIndex())))
                ++goodCount;
            else
                ++badCount;

            Instance newInstance = new Instance(instance);
            newInstance.setDataset(instance.dataset());
            newInstances.add(newInstance);
        }
        goodCount *= attrOptions.size() - 1;
        badCount *= attrOptions.size() - 1;
        for (Iterator it = newInstances.iterator(); it.hasNext();) {
            Instance instance = (Instance) it.next();
            /*      if(//!instance.stringValue(instance.classIndex()).equals(classValue) || 
             !instance.stringValue(attribute).equals(attrValue)) continue;
             */

            String classValue = instance.stringValue(instance.classIndex());
            FastVector newOptions = new FastVector();
            newOptions.appendElements(attrOptions);
            newOptions.removeElementAt(newOptions.indexOf(instance.stringValue(attribute)));
            //int index = Math.abs(LearnedOrder.randomNumGen.nextInt()) % newOptions.size();
            int index = 0;
            while (index < newOptions.size()) {
                instance.setValue(attribute, attrOptions.indexOf(newOptions.elementAt(index)));
                String value = classify(instance);
                if (goodClusters.contains(classValue)) {
                    if (goodClusters.contains(value))
                        --goodCount;
                } else if (!goodClusters.contains(classValue)) {
                    if (!goodClusters.contains(value))
                        --badCount;
                }
                ++index;
            }
            //if(value.equals(classValue)) --count;
        }

        count = goodCount - badCount;
        count /= attrOptions.size() - 1;

        double importance = ((double) count) / newInstances.size();
        if (Double.isNaN(importance))
            return 0;
        return importance;
    }

    public double vote(weka.core.Attribute attribute, String attrValue, String classValue) {
        int count = 0;
        int numOrders = 0;
        for (Iterator it = orders.iterator(); it.hasNext();) {
            Instance instance = (Instance) it.next();
            if (!instance.stringValue(instance.classIndex()).equals(classValue))
                continue;
            ++numOrders;

            if (instance.stringValue(attribute).equals(attrValue))
                ++count;

        }

        return ((double) count) / numOrders;
    }

    public String vote(weka.core.Attribute attribute, String classValue) {

        System.out.print(attribute + " votes: ");
        double before = vote(attribute, "<", classValue);
        System.out.print(" < , " + before);
        double inter = vote(attribute, "=", classValue);
        System.out.print("  = , " + inter);
        double after = vote(attribute, ">", classValue);
        System.out.println("   >, " + after);

        return "";

    }

    public FastVector computeAttributes() {
        for (Iterator it = bottomRelation.getAttributes().iterator(); it.hasNext();) {
            net.sf.bddbddb.Attribute a = (net.sf.bddbddb.Attribute) it.next();
            Pair p = new Pair(bottomRelation, a);
            relationAttrPairs.add(p);
            domainToAttrPairs.add(bottomRelation.getBDDDomain(a), p);
        }

        for (Iterator it = rule.top.iterator(); it.hasNext();) {
            BDDRelation r = (BDDRelation) ((RuleTerm) it.next()).getRelation();
            for (Iterator jt = r.getAttributes().iterator(); jt.hasNext();) {
                net.sf.bddbddb.Attribute a = (net.sf.bddbddb.Attribute) jt.next();
                Pair p = new Pair(r, a);
                relationAttrPairs.add(p);
                domainToAttrPairs.add(r.getBDDDomain(a), p);
            }
        }

        FastVector constraints = new FastVector();
        attrOptions = new FastVector(3);
        attrOptions.addElement("<");
        attrOptions.addElement("=");
        attrOptions.addElement(">");
        for (int i = 0; i < relationAttrPairs.size() - 1; i++) {
            Pair p1 = (Pair) relationAttrPairs.get(i);
            for (int j = i + 1; j < relationAttrPairs.size(); j++) {
                Pair p2 = (Pair) relationAttrPairs.get(j);
                BDDDomain d1 = ((BDDRelation) p1.left).getBDDDomain((net.sf.bddbddb.Attribute) p1.right);
                BDDDomain d2 = ((BDDRelation) p2.left).getBDDDomain((net.sf.bddbddb.Attribute) p2.right);
                if (d1.equals(d2))
                    continue;
                Pair p = new Pair(p1, p2);
                constraints.addElement(new MyAttribute(p, attrOptions));
                //constraints.addElement(new MyAttribute(p,attrOptions, constraints.size()));
            }
        }
        return constraints;
    }

    public void load(BufferedReader in) {
        try {
            String strClusts = in.readLine();
            setNumClasses(Integer.parseInt(strClusts));

            for (;;) {
                String line = in.readLine();
                if (line == null)
                    break;
                Instance instance = new Instance(attributes.size());
                StringTokenizer st = new StringTokenizer(line, ",");
                int i = 0;
                while (st.hasMoreTokens()) {
                    String str = st.nextToken();
                    //   System.out.println("attr " + attributes.elementAt(i) + " val: " + str);
                    instance.setValue((weka.core.Attribute) attributes.elementAt(i), str);
                    ++i;
                }
                line = in.readLine();
                instance = new MyInstance(instance, Long.parseLong(line));
                orders.add(instance);
            }
            buildClassifier();
        } catch (IOException e) {
            e.printStackTrace();
            Assert.UNREACHABLE("Could not load instances");
        } catch (Exception e) {
            e.printStackTrace();
            Assert.UNREACHABLE("Could not build classifier");
        }
    }

    public String toString() {
        return classifier.toString();
    }

    public void save(BufferedWriter out) {
        try {
            out.write(numClusters + "\n");
            for (Iterator it = orders.iterator(); it.hasNext();) {
                MyInstance instance = (MyInstance) it.next();
                out.write(instance.toString() + "\n");
                out.write(instance.time + "\n");
            }
        } catch (IOException e) {
            e.printStackTrace();
            Assert.UNREACHABLE("Couldn't save instances");
        }
    }

    public void setNumClasses(int numClusters) {
        this.numClusters = numClusters;
        attributes.removeElementAt(attributes.size() - 1);
        FastVector clusterValues = new FastVector(numClusters);
        for (int i = 0; i < numClusters; ++i)
            clusterValues.addElement(Integer.toString(i));

        //attributes.addElement(new Attribute("class", clusterValues, attributes.size()));
        attributes.addElement(new Attribute("class", clusterValues));
    }

    public Instances makeInstances() {
        Instances instances = new Instances("Ordering Constraints", attributes, 30);
        instances.setClassIndex(attributes.size() - 1);
        System.out.println(attributes.elementAt(attributes.size() - 1));
        return instances;
    }

    public void buildClassifier() throws Exception {
        System.out.println("Discritizing...");
        Discretizer discretizer = new EqualFreqDiscretizer();//new ClusterDiscretizer();

        data = discretizer.discretize(orders);
        System.out.println("Finished Discritization");
        classifier.buildClassifier(data);
        System.out.println("Finished building classifier");
    }

    public abstract class Discretizer {
        public abstract Instances discretize(SortedSet orders);

        public Instances timeInstances(SortedSet orders) {
            FastVector attributes = new FastVector(1);
            attributes.addElement(new Attribute("time"));

            Instances cInstances = new Instances("wiki", attributes, orders.size());

            for (Iterator it = orders.iterator(); it.hasNext();) {
                // System.out.println("class: " + data.instance(i).getClass());
                MyInstance instance = (MyInstance) it.next();
                double[] values = new double[1]; //instance.toDoubleArray();
                values[values.length - 1] = instance.getTime();
                Instance newInstance = new Instance(1, values);
                newInstance = new MyInstance(newInstance, instance.getTime());
                cInstances.add(newInstance);
            }

            return cInstances;
        }

    }

    public class EqualFreqDiscretizer extends Discretizer {

        /* (non-Javadoc)
         * @see net.sf.bddbddb.OrderClassifier.Discretizer#discretize(java.util.SortedSet)
         */
        public Instances discretize(SortedSet orders) {
            Instances cInstances = timeInstances(orders);
            PKIDiscretize d = new PKIDiscretize();
            int[] arr = new int[1];
            arr[0] = 0;
            d.setAttributeIndicesArray(arr);
            try {
                d.setInputFormat(cInstances);
                cInstances = Filter.useFilter(cInstances, d);
            } catch (Exception e1) {
                e1.printStackTrace();
            }
            System.out.println("num bins: " + d.getBins());
            setNumClasses((int) Math.sqrt(cInstances.numInstances()));

            Instances newInstances = makeInstances();
            Enumeration e = cInstances.enumerateInstances();
            for (Iterator it = orders.iterator(); it.hasNext();) {
                MyInstance instance = (MyInstance) it.next();
                Instance cInstance = (Instance) e.nextElement();

                instance.setDataset(newInstances);
                instance.setClassValue(cInstance.value(0));
                //  System.out.println("instance: " + instance + " time: " + instance.getTime());
                newInstances.add(instance);

            }
            System.out.println("num instances: " + newInstances.numInstances());
            return newInstances;
        }

    }
    /*
     public class MDLDiscretizer extends Discretizer{
     public Instances discretize(SortedSet orders){
         
     }
     }*/

    public class ClusterDiscretizer extends Discretizer {
        public Instances discretize(SortedSet orders) {
            Instances newInstances = null;
            //  return discretize((MyInstance[]) orders.toArray(new MyInstance[orders.size()]), 0, orders.size() -1);
            // double[] times = new double[instances.numInstances()];
            Clusterer clusterer = new EM(); //new Cobweb();
            try {

                /*FastVector attributes = new FastVector(classifier.getAttributes().size());
                 attributes.appendElements(classifier.getAttributes());
                 attributes.removeElementAt(classifier.getAttributes().size() - 1);
                 */
                Instances cInstances = timeInstances(orders);
                //data = new Instances(data);
                /*      int index = data.classIndex();
                 data.setClassIndex(-1);
                 data.deleteAttributeAt(index);
                 data.
                 */
                ((EM) clusterer).setMinStdDev(.1);
                clusterer.buildClusterer(cInstances);
                setNumClasses(clusterer.numberOfClusters());
                double[][][] clusterInfo = ((EM) clusterer).getClusterModelsNumericAtts();
                int[] clusterMap = new int[clusterInfo.length];
                Pair[] clusts = new Pair[clusterInfo.length];
                for (int i = 0; i < clusterInfo.length; ++i) {
                    clusts[i] = new Pair(new Integer(i), new Double(clusterInfo[i][0][0]));
                }

                Arrays.sort(clusts, new Comparator() {

                    public int compare(Object o1, Object o2) {
                        Pair p1 = (Pair) o1;
                        Pair p2 = (Pair) o2;

                        return Double.compare(((Double) p1.right).doubleValue(), ((Double) p2.right).doubleValue());
                    }
                });

                for (int i = 0; i < clusts.length; ++i) {
                    Pair p = clusts[i];
                    clusterMap[((Integer) p.left).intValue()] = i;
                }

                // System.out.println(clusterer.toString());
                System.out.println("number of clusters: " + clusterer.numberOfClusters());
                Pair[] ins = new Pair[orders.size()];

                int j = 0;
                // for(Iterator it = classifier.getOrders().iterator(); it.hasNext();){

                newInstances = makeInstances();
                MyInstance[] orderInstances = (MyInstance[]) orders.toArray(new MyInstance[orders.size()]);
                Assert._assert(cInstances.numInstances() == orderInstances.length);
                Assert._assert(newInstances.numInstances() == 0);
                for (int i = 0; i < cInstances.numInstances(); ++i) {
                    Instance cInstance = cInstances.instance(i);
                    MyInstance instance = orderInstances[i];
                    Assert._assert((long) cInstance.value(0) == instance.getTime());
                    int cnum = -1;
                    try {
                        cnum = clusterer.clusterInstance(cInstance);
                        Assert._assert(cnum >= 0 && cnum < clusterer.numberOfClusters());
                        cnum = clusterMap[cnum];
                        Assert._assert(cnum >= 0 && cnum < clusterer.numberOfClusters());

                        instance.setDataset(newInstances);
                        double val = instance.classValue();
                        instance.setClassValue(Integer.toString(cnum));
                        newInstances.add(instance);
                        //  System.out.println("cnum: " + cnum + " instance: " + instance + " old value: "  + val + " new value: " + instance.classValue());
                        Assert._assert(instance.classValue() < clusterer.numberOfClusters(), instance.toString());
                    } catch (Exception ex) {
                        ex.printStackTrace();
                        Assert.UNREACHABLE("");
                    }
                    ins[j++] = new Pair(cInstance, new Integer(cnum));
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
            /*   for(Enumeration e = newInstances.enumerateInstances(); e.hasMoreElements(); ){
            System.out.println(e.nextElement());
               }*/
            return newInstances;
        }
    }

    public class MedianDiscretizer extends Discretizer {

        /* (non-Javadoc)
         * @see net.sf.bddbddb.OrderClassifier.Discretizer#discretize(java.util.SortedSet)
         */
        public Instances discretize(SortedSet orders) {

            setNumClasses(2);
            MyInstance[] instances = (MyInstance[]) orders.toArray(new MyInstance[orders.size()]);
            Instances newInstances = makeInstances();

            int right = orders.size() - 1;
            int left = 0;
            int stride = right - left;
            int median = stride / 2 + left;

            while (instances[median].time == max_wait_time && median > 0)
                --median;

            for (int i = left; i < median; i++) {
                MyInstance instance = instances[i];
                newInstances.add(instance);
                instance.setDataset(newInstances);

                instance.setClassValue("0");

            }

            for (int i = median; i <= right; i++) {
                MyInstance instance = instances[i];
                newInstances.add(instance);
                instance.setDataset(newInstances);
                instance.setClassValue("1");

            }

            return newInstances;
        }
    }

    public long medianTime() {
        if (orders.size() == 0)
            return max_wait_time;

        Object[] instances = orders.toArray();
        return ((MyInstance) instances[instances.length / 2]).time;
    }

    public String classify(Instance instance) {
        try {
            double predicted = classifier.classifyInstance(instance);

            return data.classAttribute().value((int) predicted);
        } catch (Exception e) {
            e.printStackTrace();
            Assert.UNREACHABLE("Could not classify instance");
        }
        return "";
    }

    public String classify(String varOrder) {
        Instance instance = makeInstance(varOrder);
        instance.setDataset(data);
        return classify(instance);
    }

    public void addOrder(String varOrder, long time) {
        Instance core = makeInstance(varOrder);
        MyInstance instance = new MyInstance(core, time);
        orders.add(instance);

    }

    public FastVector getAttributes() {
        return attributes;
    }

    static class MyAttribute extends weka.core.Attribute {
        /**
         * Version ID for serialization.
         */
        private static final long serialVersionUID = 3546927987002716980L;

        Pair pair;
        /*
        public MyAttribute(Pair pair, FastVector options, int index){
        super(pair.toString(),options, index);
        this.pair = pair;
        }
        */

        public MyAttribute(Pair pair, FastVector options) {
            super(pair.toString(), options);
            this.pair = pair;
        }
    }

    class MyInstance extends Instance implements Comparable {
        /**
         * Version ID for serialization.
         */
        private static final long serialVersionUID = 3257006536231237427L;

        private long time;
        private String order;

        public MyInstance(int numAttributes) {
            super(numAttributes);
            this.time = max_wait_time;
        }

        public MyInstance(MyInstance instance) {
            super(instance);
            this.time = instance.getTime();
        }

        public MyInstance(Instance instance, long time) {
            super(instance);
            this.time = time;
        }

        public Object copy() {
            Instance instance = (Instance) super.copy();
            MyInstance result = new MyInstance(instance, time);
            return result;
        }

        /* (non-Javadoc)
         * @see java.lang.Comparable#compareTo(java.lang.Object)
         */
        public int compareTo(Object o) {
            MyInstance that = (MyInstance) o;
            if (this.time == that.time)
                return this.hashCode() - that.hashCode();
            return (int) ((this.time - that.time) / Math.abs(this.time - that.time));
        }

        /**
         * @return  time for this instance
         */
        public long getTime() {
            return time;
        }

    }

    public Instance makeInstance(String varOrder) {
        Instance instance = new MyInstance(attributes.size());
        for (int i = 0; i < attributes.size() - 1; i++) {
            instance.setValue((weka.core.Attribute) attributes.elementAt(i), ">");
        }
        StringTokenizer st = new StringTokenizer(varOrder, "_");
        List lefts = new LinkedList();
        while (st.hasMoreTokens()) {
            String pdomain = st.nextToken();
            // System.out.println("pdomain: " + pdomain);
            StringTokenizer st2 = new StringTokenizer(pdomain, "x");
            if (st2.countTokens() > 1) {
                List interLefts = new LinkedList();
                while (st2.hasMoreTokens()) {
                    String idomain = st2.nextToken();
                    //   System.out.println("idomain: " + idomain);

                    BDDDomain d = rule.solver.getBDDDomain(idomain);
                    if (!domainSet.contains(d))
                        continue;
                    Collection pairs = domainToAttrPairs.getValues(d);
                    //  System.out.println("interlefts: " + interLefts);
                    //  System.out.println("pairs: " + pairs);

                    pair(instance, interLefts, pairs, "=", true);
                    interLefts.addAll(pairs);

                    pair(instance, lefts, pairs, "<", false);
                    lefts.addAll(pairs);
                }

                continue;
            }
            BDDDomain d = rule.solver.getBDDDomain(pdomain);
            if (!domainSet.contains(d))
                continue;
            Collection pairs = domainToAttrPairs.getValues(d);
            pair(instance, lefts, pairs, "<", false);
            lefts.addAll(pairs);

        }
        // System.out.print("attributes:");
        /*   for(int i = 0; i < attributes.size(); i++){
         System.out.print(" " + attributes.elementAt(i));
         }
         */
        //  System.out.println();
        // System.out.println("instance: " + instance);
        return instance;
    }

    private void pair(Instance instance, List lefts, Collection pairs, String value, boolean reverse) {
        for (Iterator it = lefts.iterator(); it.hasNext();) {
            Pair left = (Pair) it.next();
            for (Iterator jt = pairs.iterator(); jt.hasNext();) {
                Pair right = (Pair) jt.next();
                if (right.equals(left))
                    continue;
                Pair p = new Pair(left, right);
                weka.core.Attribute a = new MyAttribute(p, attrOptions);
                int index = attributes.indexOf(a);
                if (index == -1) {
                    if (reverse) {
                        Pair r = new Pair(right, left);
                        a = new MyAttribute(r, attrOptions);
                        index = attributes.indexOf(a);
                    } else
                        continue;
                }

                Assert._assert(index != -1, " no attribute for " + p + " found. attributes: " + attributes);
                //  System.out.println("Setting " + a.toString() + " to " + value);
                instance.setValue((weka.core.Attribute) attributes.elementAt(index), value);
            }
        }
    }

    /**
     * @return  number of attributes
     */
    public int getNumAttributes() {
        return attributes.size() - 1;
    }

    /**
     * @return  data
     */
    public Instances getData() {
        return data;
    }

    /**
     * @return  sorted orders
     */
    public SortedSet getOrders() {
        return orders;
    }
}