weka.core.pmml.PMMLFactory.java Source code

Java tutorial

Introduction

Here is the source code for weka.core.pmml.PMMLFactory.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    PMMLFactory.java
 *    Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.core.pmml;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.pmml.consumer.GeneralRegression;
import weka.classifiers.pmml.consumer.NeuralNetwork;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.classifiers.pmml.consumer.Regression;
import weka.classifiers.pmml.consumer.RuleSetModel;
import weka.classifiers.pmml.consumer.SupportVectorMachineModel;
import weka.classifiers.pmml.consumer.TreeModel;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.Logger;

/**
 * This class is a factory class for reading/writing PMML models
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision$
 */
public class PMMLFactory {

    /**
     * The PMML model types this factory can construct. Each constant carries a
     * short display name (returned by {@link #toString()}) and the name of the
     * PMML element that holds a model of that type; {@link #UNKNOWN_MODEL} has no
     * element name. The declaration order of the remaining constants is the
     * order in which {@link PMMLFactory#getModelType(Document)} probes a
     * document, so it decides precedence when several model elements are
     * present.
     */
    protected enum ModelType {
        UNKNOWN_MODEL("unknown", null),
        REGRESSION_MODEL("Regression", "RegressionModel"),
        GENERAL_REGRESSION_MODEL("GeneralRegression", "GeneralRegressionModel"),
        NEURAL_NETWORK_MODEL("NeuralNetwork", "NeuralNetwork"),
        TREE_MODEL("TreeModel", "TreeModel"),
        RULESET_MODEL("RuleSetModel", "RuleSetModel"),
        SVM_MODEL("SupportVectorMachineModel", "SupportVectorMachineModel");

        /** Short display name returned by toString(). */
        private final String m_stringVal;

        /** Name of the PMML element holding this kind of model (null if unknown). */
        private final String m_elementName;

        ModelType(String name, String elementName) {
            m_stringVal = name;
            m_elementName = elementName;
        }

        /**
         * @return the PMML element name for this model type, or null for
         *         {@link #UNKNOWN_MODEL}
         */
        String getElementName() {
            return m_elementName;
        }

        @Override
        public String toString() {
            return m_stringVal;
        }
    }

    /**
     * Read and return a PMML model.
     * 
     * @param filename the name of the file to read from
     * @return a PMML model
     * @throws Exception if there is a problem while reading the file
     */
    public static PMMLModel getPMMLModel(String filename) throws Exception {
        return getPMMLModel(filename, null);
    }

    /**
     * Read and return a PMML model.
     * 
     * @param file a <code>File</code> to read from
     * @return a PMML model
     * @throws Exception if there is a problem while reading the file
     */
    public static PMMLModel getPMMLModel(File file) throws Exception {
        return getPMMLModel(file, null);
    }

    /**
     * Read and return a PMML model. The stream is closed once it has been parsed.
     * 
     * @param stream the <code>InputStream</code> to read from
     * @return a PMML model
     * @throws Exception if there is a problem while reading from the stream
     */
    public static PMMLModel getPMMLModel(InputStream stream) throws Exception {
        return getPMMLModel(stream, null);
    }

    /**
     * Read and return a PMML model.
     * 
     * @param filename the name of the file to read from
     * @param log the logging object to use (or null if none is to be used)
     * @return a PMML model
     * @throws Exception if there is a problem while reading the file
     */
    public static PMMLModel getPMMLModel(String filename, Logger log) throws Exception {
        return getPMMLModel(new File(filename), log);
    }

    /**
     * Read and return a PMML model.
     * 
     * @param file a <code>File</code> to read from
     * @param log the logging object to use (or null if none is to be used)
     * @return a PMML model
     * @throws Exception if there is a problem while reading the file
     */
    public static PMMLModel getPMMLModel(File file, Logger log) throws Exception {
        // the stream opened here is closed by getPMMLModel(InputStream, Logger)
        return getPMMLModel(new BufferedInputStream(new FileInputStream(file)), log);
    }

    /**
     * Check whether the parsed document contains a PMML element.
     * 
     * @param doc the parsed document
     * @return true if at least one element named "PMML" is present
     */
    private static boolean isPMML(Document doc) {
        return doc.getElementsByTagName("PMML").getLength() > 0;
    }

    /**
     * Create a DocumentBuilderFactory that does not resolve external entities or
     * load external DTDs, so a PMML document cannot make the parser read local
     * files or fetch remote resources (XXE). Features the underlying parser does
     * not recognise are skipped rather than treated as errors.
     * 
     * @return a configured DocumentBuilderFactory
     */
    private static DocumentBuilderFactory newDocumentBuilderFactory() {
        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
        setFeatureIfSupported(dbf, "http://xml.org/sax/features/external-general-entities", false);
        setFeatureIfSupported(dbf, "http://xml.org/sax/features/external-parameter-entities", false);
        setFeatureIfSupported(dbf, "http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
        dbf.setXIncludeAware(false);
        return dbf;
    }

    /**
     * Set a parser feature, ignoring parsers that do not support it.
     * 
     * @param dbf the factory to configure
     * @param feature the feature URI
     * @param value the value to set
     */
    private static void setFeatureIfSupported(DocumentBuilderFactory dbf, String feature, boolean value) {
        try {
            dbf.setFeature(feature, value);
        } catch (ParserConfigurationException ignored) {
            // best-effort hardening: this parser does not know the feature, keep its default
        }
    }

    /**
     * Read and return a PMML model. The stream is always closed, whether or not
     * parsing succeeds.
     * 
     * @param stream the <code>InputStream</code> to read from
     * @param log the logging object to use (or null if none is to be used)
     * @return a PMML model
     * @throws IllegalArgumentException if the document does not contain a PMML
     *           element
     * @throws Exception if there is a problem while reading from the stream or
     *           the model type is not supported
     */
    public static PMMLModel getPMMLModel(InputStream stream, Logger log) throws Exception {
        Document doc;
        try {
            DocumentBuilder db = newDocumentBuilderFactory().newDocumentBuilder();
            doc = db.parse(stream);
        } finally {
            // close even when parsing fails so the underlying file handle is released
            if (stream != null) {
                stream.close();
            }
        }
        doc.getDocumentElement().normalize();
        if (!isPMML(doc)) {
            throw new IllegalArgumentException("[PMMLFactory] Source is not a PMML file!!");
        }

        Instances dataDictionary = getDataDictionaryAsInstances(doc);
        TransformationDictionary transDict = getTransformationDictionary(doc, dataDictionary);

        ModelType modelType = getModelType(doc);
        if (modelType == ModelType.UNKNOWN_MODEL) {
            throw new Exception("Unsupported PMML model type");
        }
        Element model = getModelElement(doc, modelType);

        // Construct mining schema and meta data
        MiningSchema ms = new MiningSchema(model, dataDictionary, transDict);

        PMMLModel theModel = getModelInstance(doc, modelType, model, dataDictionary, ms);
        if (log != null) {
            theModel.setLog(log);
        }
        return theModel;
    }

    /**
     * Get the transformation dictionary (if there is one).
     * 
     * @param doc the Document containing the PMML model
     * @param dataDictionary the data dictionary as an Instances object
     * @return the transformation dictionary or null if there is none defined in
     *         the Document
     * @throws Exception if there is a problem getting the transformation
     *           dictionary
     */
    protected static TransformationDictionary getTransformationDictionary(Document doc, Instances dataDictionary)
            throws Exception {
        TransformationDictionary transDict = null;

        NodeList transL = doc.getElementsByTagName("TransformationDictionary");
        // should be of size 0 or 1; only the first one is used
        if (transL.getLength() > 0) {
            Node transNode = transL.item(0);
            if (transNode.getNodeType() == Node.ELEMENT_NODE) {
                transDict = new TransformationDictionary((Element) transNode, dataDictionary);
            }
        }

        return transDict;
    }

    /**
     * Serialize a <code>PMMLModel</code> object that encapsulates a PMML model
     * 
     * @param model the <code>PMMLModel</code> to serialize
     * @param filename the name of the file to save to
     * @throws Exception if something goes wrong during serialization
     */
    public static void serializePMMLModel(PMMLModel model, String filename) throws Exception {
        serializePMMLModel(model, new File(filename));
    }

    /**
     * Serialize a <code>PMMLModel</code> object that encapsulates a PMML model
     * 
     * @param model the <code>PMMLModel</code> to serialize
     * @param file the <code>File</code> to save to
     * @throws Exception if something goes wrong during serialization
     */
    public static void serializePMMLModel(PMMLModel model, File file) throws Exception {
        // the stream opened here is closed by serializePMMLModel(PMMLModel, OutputStream)
        serializePMMLModel(model, new BufferedOutputStream(new FileOutputStream(file)));
    }

    /**
     * Serialize a <code>PMMLModel</code> object that encapsulates a PMML model.
     * The mining schema header (as an Instances object) is written first,
     * followed by the model itself. The supplied stream is closed afterwards,
     * also when serialization fails.
     * 
     * @param model the <code>PMMLModel</code> to serialize
     * @param stream the <code>OutputStream</code> to serialize to
     * @throws Exception if something goes wrong during serialization
     */
    public static void serializePMMLModel(PMMLModel model, OutputStream stream) throws Exception {
        ObjectOutputStream oo = null;
        try {
            oo = new ObjectOutputStream(stream);
            Instances header = model.getMiningSchema().getFieldsAsInstances();
            oo.writeObject(header);
            oo.writeObject(model);
            oo.flush();
        } finally {
            if (oo != null) {
                oo.close(); // also closes the wrapped stream
            } else if (stream != null) {
                stream.close();
            }
        }
    }

    /**
     * Get an instance of a PMMLModel from the supplied Document
     * 
     * @param doc the Document holding the pmml
     * @param modelType the type of model
     * @param model the Element encapsulating the model part of the Document
     * @param dataDictionary the data dictionary as an Instances object
     * @param miningSchema the mining schema
     * @return a PMMLModel object
     * @throws Exception if there is a problem constructing the model or if the
     *           model type is not supported
     */
    protected static PMMLModel getModelInstance(Document doc, ModelType modelType, Element model,
            Instances dataDictionary, MiningSchema miningSchema) throws Exception {
        PMMLModel pmmlM;
        switch (modelType) {
        case REGRESSION_MODEL:
            pmmlM = new Regression(model, dataDictionary, miningSchema);
            break;
        case GENERAL_REGRESSION_MODEL:
            pmmlM = new GeneralRegression(model, dataDictionary, miningSchema);
            break;
        case NEURAL_NETWORK_MODEL:
            pmmlM = new NeuralNetwork(model, dataDictionary, miningSchema);
            break;
        case TREE_MODEL:
            pmmlM = new TreeModel(model, dataDictionary, miningSchema);
            break;
        case RULESET_MODEL:
            pmmlM = new RuleSetModel(model, dataDictionary, miningSchema);
            break;
        case SVM_MODEL:
            pmmlM = new SupportVectorMachineModel(model, dataDictionary, miningSchema);
            break;
        default:
            throw new Exception("[PMMLFactory] Unknown model type!!");
        }
        pmmlM.setPMMLVersion(doc);
        pmmlM.setCreatorApplication(doc);
        return pmmlM;
    }

    /**
     * Get the type of model. Model element names are probed in the declaration
     * order of {@link ModelType}, and the first one found anywhere in the
     * document wins.
     * 
     * @param doc the Document encapsulating the pmml
     * @return the type of model, or {@link ModelType#UNKNOWN_MODEL} if no
     *         supported model element is present
     */
    protected static ModelType getModelType(Document doc) {
        for (ModelType type : ModelType.values()) {
            String elementName = type.getElementName();
            if (elementName != null && doc.getElementsByTagName(elementName).getLength() > 0) {
                return type;
            }
        }
        return ModelType.UNKNOWN_MODEL;
    }

    /**
     * Get the Element that contains the pmml model
     * 
     * @param doc the Document encapsulating the pmml
     * @param modelType the type of model
     * @return the first element of the given model type, or null if the
     *         document contains none
     * @throws Exception if the model type is unsupported/unknown
     */
    protected static Element getModelElement(Document doc, ModelType modelType) throws Exception {
        String elementName = modelType.getElementName();
        if (elementName == null) {
            throw new Exception("[PMMLFactory] unknown/unsupported model type.");
        }

        NodeList candidates = doc.getElementsByTagName(elementName);
        if (candidates.getLength() > 0) {
            Node modelNode = candidates.item(0);
            if (modelNode.getNodeType() == Node.ELEMENT_NODE) {
                return (Element) modelNode;
            }
        }
        return null;
    }

    /**
     * Get the mining schema as an Instances object
     * 
     * @param model the Element containing the pmml model
     * @param dataDictionary the data dictionary as an Instances object
     * @return the mining schema as an Instances object
     * @throws Exception if something goes wrong during reading the mining schema
     * @deprecated Use the MiningSchema class instead
     */
    @Deprecated
    protected static Instances getMiningSchemaAsInstances(Element model, Instances dataDictionary)
            throws Exception {
        ArrayList<Attribute> attInfo = new ArrayList<Attribute>();
        NodeList fieldList = model.getElementsByTagName("MiningField");
        int classIndex = -1;
        int addedCount = 0;
        for (int i = 0; i < fieldList.getLength(); i++) {
            Node miningField = fieldList.item(i);
            if (miningField.getNodeType() == Node.ELEMENT_NODE) {
                Element miningFieldEl = (Element) miningField;
                String name = miningFieldEl.getAttribute("name");
                String usage = miningFieldEl.getAttribute("usageType");
                // TO-DO: also missing value replacement etc.

                // find this attribute in the dataDictionary
                Attribute miningAtt = dataDictionary.attribute(name);
                if (miningAtt != null) {
                    // an absent usageType is treated like "active"
                    if (usage.length() == 0 || usage.equals("active") || usage.equals("predicted")) {
                        attInfo.add(miningAtt);
                        addedCount++;
                    }
                    if (usage.equals("predicted")) {
                        classIndex = addedCount - 1;
                    }
                } else {
                    throw new Exception("Can't find mining field: " + name + " in the data dictionary.");
                }
            }
        }

        Instances insts = new Instances("miningSchema", attInfo, 0);
        if (classIndex != -1) {
            insts.setClassIndex(classIndex);
        }

        return insts;
    }

    /**
     * Get the data dictionary as an Instances object. Continuous fields become
     * numeric attributes; categorical and ordinal fields become nominal
     * attributes (or string attributes when no values are listed).
     * 
     * @param doc the Document encapsulating the pmml
     * @return the data dictionary as an Instances object
     * @throws Exception if there are fields that are not continuous, ordinal or
     *           categorical in the data dictionary
     */
    protected static Instances getDataDictionaryAsInstances(Document doc) throws Exception {

        // TO-DO: definition of missing values (see below)

        ArrayList<Attribute> attInfo = new ArrayList<Attribute>();
        NodeList dataFields = doc.getElementsByTagName("DataField");
        for (int i = 0; i < dataFields.getLength(); i++) {
            Node dataField = dataFields.item(i);
            if (dataField.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }
            Element dataFieldEl = (Element) dataField;
            // Element.getAttribute yields "" (not null) for an absent attribute,
            // so a missing optype ends up in the error branch below
            String name = dataFieldEl.getAttribute("name");
            String type = dataFieldEl.getAttribute("optype");
            Attribute tempAtt;
            if (type.equals("continuous")) {
                tempAtt = new Attribute(name);
            } else if (type.equals("categorical") || type.equals("ordinal")) {
                tempAtt = createCategoricalAttribute(name, dataFieldEl);
            } else {
                throw new Exception("[PMMLFactory] can't handle " + type + " attributes.");
            }
            attInfo.add(tempAtt);
        }

        // TO-DO: check whether certain values are declared to represent
        // missing or invalid values (applies to both categorical and continuous
        // attributes)

        return new Instances("dataDictionary", attInfo, 0);
    }

    /**
     * Build the attribute for a categorical or ordinal DataField from its
     * "valid" Value children.
     * 
     * @param name the attribute name
     * @param dataFieldEl the DataField element
     * @return a nominal attribute, or a string attribute if no values are listed
     */
    private static Attribute createCategoricalAttribute(String name, Element dataFieldEl) {
        NodeList valueList = dataFieldEl.getElementsByTagName("Value");
        if (valueList == null || valueList.getLength() == 0) {
            // assume that categorical values will be revealed in the actual
            // model. Create a string attribute for now (null value list)
            ArrayList<String> nullV = null;
            return new Attribute(name, nullV);
        }

        ArrayList<String> valueVector = new ArrayList<String>();
        for (int j = 0; j < valueList.getLength(); j++) {
            Node val = valueList.item(j);
            if (val.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }
            // property is optional (default value is "valid")
            String property = ((Element) val).getAttribute("property");
            if (property == null || property.length() == 0 || property.equals("valid")) {
                valueVector.add(((Element) val).getAttribute("value"));
            }
            // Invalid or missing value definitions are ignored for now.
            // TO-DO: implement Value meta data with missing/invalid value defs.
        }
        return new Attribute(name, valueVector);
    }

    /**
     * Apply a PMML classifier to every instance of a data set and report the
     * actual class value next to the predicted distribution, one line per
     * instance. A missing actual class value is reported as "?".
     * 
     * @param model the PMML model; must be a {@link PMMLClassifier}
     * @param test the data to classify (its class index must be set)
     * @return the textual report
     * @throws Exception if the model is not a classifier or prediction fails
     */
    public static String applyClassifier(PMMLModel model, Instances test) throws Exception {
        if (!(model instanceof PMMLClassifier)) {
            throw new Exception("PMML model is not a classifier!");
        }

        PMMLClassifier classifier = (PMMLClassifier) model;
        StringBuilder buff = new StringBuilder();
        for (int i = 0; i < test.numInstances(); i++) {
            Instance temp = test.instance(i);
            buff.append("Actual: ");
            if (temp.classIsMissing()) {
                // casting a missing (NaN) nominal value to int would report the first label
                buff.append("? ");
            } else if (temp.classAttribute().isNumeric()) {
                buff.append(temp.value(temp.classIndex())).append(' ');
            } else {
                buff.append(temp.classAttribute().value((int) temp.value(temp.classIndex()))).append(' ');
            }
            double[] preds = classifier.distributionForInstance(temp);
            buff.append(" Predicted: ");
            for (double pred : preds) {
                buff.append(pred).append(' ');
            }
            buff.append('\n');
        }
        return buff.toString();
    }

    /**
     * Minimal classifier used only to reach the command-line evaluation harness
     * ({@code runClassifier}) of AbstractClassifier from {@link #main(String[])}.
     * Its own training/prediction methods must not be called.
     */
    private static class PMMLClassifierRunner extends AbstractClassifier {

        /** ID added to avoid warning */
        private static final long serialVersionUID = -3742334356788083347L;

        @Override
        public double[] distributionForInstance(Instance test) throws Exception {
            throw new Exception("Don't call this method!!");
        }

        @Override
        public void buildClassifier(Instances instances) throws Exception {
            throw new Exception("Don't call this method!!");
        }

        @Override
        public String getRevision() {
            return weka.core.RevisionUtils.extract("$Revision$");
        }

        /**
         * Run the evaluation harness with the given command-line options.
         * 
         * @param options the command-line options
         */
        public void evaluatePMMLClassifier(String[] options) {
            runClassifier(this, options);
        }
    }

    /**
     * Command-line entry point. Requires a PMML file via {@code -l <file>}; the
     * file is loaded once to check that it is supported, and then all arguments
     * are handed to the evaluation harness.
     * 
     * @param args the command-line arguments
     */
    public static void main(String[] args) {
        try {
            // work on a copy: Utils.getOption presumably consumes the option from
            // the array it is given, and the harness below needs the original args
            String[] optionsTmp = args.clone();
            String pmmlFile = Utils.getOption('l', optionsTmp);
            if (pmmlFile.length() == 0) {
                throw new Exception("[PMMLFactory] must specify a PMML file using the -l option.");
            }
            // see if it is supported before going any further
            getPMMLModel(pmmlFile, null);

            PMMLClassifierRunner pcr = new PMMLClassifierRunner();
            pcr.evaluatePMMLClassifier(args);
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}