probcog.J48Reader.java Source code

Java tutorial

Introduction

Here is the source code for probcog.J48Reader.java

Source

/*******************************************************************************
 * Copyright (C) 2009-2012 Stefan Waldherr, Dominik Jain.
 * 
 * This file is part of ProbCog.
 * 
 * ProbCog is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * ProbCog is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package probcog;

import weka.classifiers.trees.J48;

import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

import java.io.IOException;
import java.io.FileNotFoundException;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.HashMap;
import java.util.Map.Entry;

import probcog.srldb.Database;
import probcog.srldb.Object;
import probcog.srldb.datadict.DDAttribute;
import probcog.srldb.datadict.DDException;
import probcog.srldb.datadict.domain.Domain;

public class J48Reader {

    public static void main(String[] args) {
        try {
            String path = "./";
            //String dbdir = path + "zoli4uli4";
            String dbdir = path + args[1];
            J48 j48 = readJ48(dbdir);
            Instances instances = readDB(args[0]);
            //System.out.println(j48.toString());
            for (int i = 0; i < instances.numInstances(); i++) {
                Instance inst = instances.instance(i);
                //System.out.println(instances.toString());
                /*
                double d = j48.classifyInstance(inst);
                System.out.println(inst.attribute(inst.classIndex()).toString());
                inst.setValue(inst.attribute(inst.classIndex()),d);
                System.out.println("Object is classified as "+ inst.toString(inst.attribute(inst.classIndex())));
                */
                //System.out.println(inst);
                double dist[] = j48.distributionForInstance(inst);
                int j = 0;
                for (double d : dist) {
                    Attribute att = instances.attribute(instances.classIndex());
                    String classification = att.value(j);
                    System.out.println(d + "  " + classification);
                    j++;
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static J48 readJ48(String dbdir) throws IOException, ClassNotFoundException {
        String path = dbdir + "/pcc.j48";
        //System.out.println("reading tree " + path);
        ObjectInputStream objstream = new ObjectInputStream(new FileInputStream(path));
        J48 j48 = (J48) objstream.readObject();
        objstream.close();
        return j48;
    }

    public static Instances readDB(String dbname)
            throws IOException, ClassNotFoundException, DDException, FileNotFoundException, Exception {
        Database db = Database.fromFile(new FileInputStream(dbname));
        probcog.srldb.datadict.DataDictionary dd = db.getDataDictionary();
        //the vector of attributes
        FastVector fvAttribs = new FastVector();
        HashMap<String, Attribute> mapAttrs = new HashMap<String, Attribute>();
        for (DDAttribute attribute : dd.getObject("object").getAttributes().values()) {
            if (attribute.isDiscarded() && !attribute.getName().equals("objectT")) {
                continue;
            }
            FastVector attValues = new FastVector();
            Domain dom = attribute.getDomain();
            for (String s : dom.getValues())
                attValues.addElement(s);
            Attribute attr = new Attribute(attribute.getName(), attValues);
            fvAttribs.addElement(attr);
            mapAttrs.put(attribute.getName(), attr);
        }

        Instances instances = new Instances("name", fvAttribs, 10000);
        instances.setClass(mapAttrs.get("objectT"));
        //for each object add an instance
        for (Object o : db.getObjects()) {
            if (o.hasAttribute("objectT")) {
                Instance instance = new Instance(fvAttribs.size());
                for (Entry<String, String> e : o.getAttributes().entrySet()) {
                    if (!dd.getAttribute(e.getKey()).isDiscarded()) {
                        instance.setValue(mapAttrs.get(e.getKey()), e.getValue());
                    }
                }
                instances.add(instance);
            }
        }
        return instances;
    }
}