org.pentaho.di.scoring.WekaScoringData.java Source code

Java tutorial

Introduction

Here is the source code for org.pentaho.di.scoring.WekaScoringData.java

Source

/*
* This program is free software; you can redistribute it and/or modify it under the
* terms of the GNU General Public License, version 2 as published by the Free Software
* Foundation.
*
* You should have received a copy of the GNU General Public License along with this
* program; if not, you can obtain a copy at http://www.gnu.org/licenses/gpl-2.0.html
* or from the Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*
*
* Copyright 2006 - 2013 Pentaho Corporation.  All rights reserved.
*/

package org.pentaho.di.scoring;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.vfs.FileObject;
import org.pentaho.di.core.logging.LogChannelInterface;
import org.pentaho.di.core.row.RowDataUtil;
import org.pentaho.di.core.row.RowMetaInterface;
import org.pentaho.di.core.row.ValueMetaInterface;
import org.pentaho.di.core.variables.VariableSpace;
import org.pentaho.di.core.vfs.KettleVFS;
import org.pentaho.di.i18n.BaseMessages;
import org.pentaho.di.trans.step.BaseStepData;
import org.pentaho.di.trans.step.StepDataInterface;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.pmml.PMMLFactory;
import weka.core.pmml.PMMLModel;
import weka.core.xml.XStream;

/**
 * Holds temporary data and has routines for loading serialized models.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
 */
public class WekaScoringData extends BaseStepData implements StepDataInterface {

    /**
     * some constants for various input field - attribute match/type problems
     */
    public static final int NO_MATCH = -1;
    public static final int TYPE_MISMATCH = -2;

    /** the output data format */
    protected RowMetaInterface m_outputRowMeta;

    /** holds values for instances constructed for prediction */
    private double[] m_vals = null;

    /**
     * Holds the actual Weka model (classifier, clusterer or PMML) used by this
     * copy of the step
     */
    protected WekaScoringModel m_model;

    /**
     * Holds a default model - only used when model files are sourced from a field
     * in the incoming data rows. In this case, it is the fallback model if there
     * is no model file specified in the incoming row.
     */
    protected WekaScoringModel m_defaultModel;

    /** used to map attribute indices to incoming field indices */
    private int[] m_mappingIndexes;

    /** whether to update the model (if incremental) */
    protected boolean m_updateIncrementalModel = false;

    public WekaScoringData() {
        super();
    }

    /**
     * Set the model for this copy of the step to use
     * 
     * @param model the model to use
     */
    public void setModel(WekaScoringModel model) {
        m_model = model;
    }

    /**
     * Get the model that this copy of the step is using
     * 
     * @return the model that this copy of the step is using
     */
    public WekaScoringModel getModel() {
        return m_model;
    }

    /**
     * Set the default model for this copy of the step to use. This gets used if
     * we are getting model file paths from a field in the incoming row structure
     * and a given row has null for the model path.
     * 
     * @param model the model to use as fallback
     */
    public void setDefaultModel(WekaScoringModel model) {
        m_defaultModel = model;
    }

    /**
     * Get the default model for this copy of the step to use. This gets used if
     * we are getting model file paths from a field in the incoming row structure
     * and a given row has null for the model path.
     * 
     * @return the model to use as fallback
     */
    public WekaScoringModel getDefaultModel() {
        return m_defaultModel;
    }

    /**
     * Get the meta data for the output format
     * 
     * @return a <code>RowMetaInterface</code> value
     */
    public RowMetaInterface getOutputRowMeta() {
        return m_outputRowMeta;
    }

    /**
     * Set the meta data for the output format
     * 
     * @param rmi a <code>RowMetaInterface</code> value
     */
    public void setOutputRowMeta(RowMetaInterface rmi) {
        m_outputRowMeta = rmi;
    }

    /**
     * Finds a mapping between the attributes that a Weka model has been trained
     * with and the incoming Kettle row format. Returns an array of indices, where
     * the element at index 0 of the array is the index of the Kettle field that
     * corresponds to the first attribute in the Instances structure, the element
     * at index 1 is the index of the Kettle fields that corresponds to the second
     * attribute, ...
     * 
     * @param header the Instances header
     * @param inputRowMeta the meta data for the incoming rows
     * @param updateIncrementalModel true if the model is incremental and should
     *          be updated on the incoming instances
     * @param log the log to use
     */
    public void mapIncomingRowMetaData(Instances header, RowMetaInterface inputRowMeta,
            boolean updateIncrementalModel, LogChannelInterface log) {
        m_mappingIndexes = WekaScoringData.findMappings(header, inputRowMeta);
        m_updateIncrementalModel = updateIncrementalModel;

        // If updating of incremental models has been selected, then
        // check on the ability to do this
        if (m_updateIncrementalModel && m_model.isSupervisedLearningModel()) {
            if (m_model.isUpdateableModel()) {
                // Do we have the class mapped successfully to an incoming
                // Kettle field
                if (m_mappingIndexes[header.classIndex()] == WekaScoringData.NO_MATCH
                        || m_mappingIndexes[header.classIndex()] == WekaScoringData.TYPE_MISMATCH) {
                    m_updateIncrementalModel = false;
                    log.logError(
                            BaseMessages.getString(WekaScoringMeta.PKG, "WekaScoringMeta.Log.NoMatchForClass")); //$NON-NLS-1$
                }
            } else {
                m_updateIncrementalModel = false;
                log.logError(BaseMessages.getString(WekaScoringMeta.PKG, "WekaScoringMeta.Log.ModelNotUpdateable")); //$NON-NLS-1$
            }
        }
    }

    public static boolean modelFileExists(String modelFile, VariableSpace space) throws Exception {

        modelFile = space.environmentSubstitute(modelFile);
        FileObject modelF = KettleVFS.getFileObject(modelFile);

        return modelF.exists();
    }

    /**
     * Loads a serialized model. Models can either be binary serialized Java
     * objects, objects deep-serialized to xml, or PMML.
     * 
     * @param modelFile a <code>File</code> value
     * @return the model
     * @throws Exception if there is a problem laoding the model.
     */
    public static WekaScoringModel loadSerializedModel(String modelFile, LogChannelInterface log,
            VariableSpace space) throws Exception {

        Object model = null;
        Instances header = null;
        int[] ignoredAttsForClustering = null;

        modelFile = space.environmentSubstitute(modelFile);
        FileObject modelF = KettleVFS.getFileObject(modelFile);
        if (!modelF.exists()) {
            throw new Exception(BaseMessages.getString(WekaScoringMeta.PKG,
                    "WekaScoring.Error.NonExistentModelFile", space.environmentSubstitute(modelFile))); //$NON-NLS-1$
        }

        InputStream is = KettleVFS.getInputStream(modelF);
        BufferedInputStream buff = new BufferedInputStream(is);

        if (modelFile.toLowerCase().endsWith(".xml")) { //$NON-NLS-1$
            // assume it is PMML
            model = PMMLFactory.getPMMLModel(buff, null);

            // we will use the mining schema as the instance structure
            header = ((PMMLModel) model).getMiningSchema().getMiningSchemaAsInstances();

            buff.close();
        } else if (modelFile.toLowerCase().endsWith(".xstreammodel")) { //$NON-NLS-1$
            log.logBasic(BaseMessages.getString(WekaScoringMeta.PKG, "WekaScoringData.Log.LoadXMLModel")); //$NON-NLS-1$

            if (XStream.isPresent()) {
                Vector v = (Vector) XStream.read(buff);

                model = v.elementAt(0);
                if (v.size() == 2) {
                    // try and grab the header
                    header = (Instances) v.elementAt(1);
                }
                buff.close();
            } else {
                buff.close();
                throw new Exception(
                        BaseMessages.getString(WekaScoringMeta.PKG, "WekaScoringData.Error.CantLoadXMLModel")); //$NON-NLS-1$
            }
        } else {
            InputStream stream = buff;
            if (modelFile.toLowerCase().endsWith(".gz")) { //$NON-NLS-1$
                stream = new GZIPInputStream(buff);
            }
            ObjectInputStream oi = new ObjectInputStream(stream);

            model = oi.readObject();

            // try and grab the header
            header = (Instances) oi.readObject();

            if (model instanceof weka.clusterers.Clusterer) {
                // try and grab any attributes to be ignored during clustering
                try {
                    ignoredAttsForClustering = (int[]) oi.readObject();
                } catch (Exception ex) {
                    // Don't moan if there aren't any :-)
                }
            }
            oi.close();
        }

        WekaScoringModel wsm = WekaScoringModel.createScorer(model);
        wsm.setHeader(header);
        if (wsm instanceof WekaScoringClusterer && ignoredAttsForClustering != null) {
            ((WekaScoringClusterer) wsm).setAttributesToIgnore(ignoredAttsForClustering);
        }

        wsm.setLog(log);
        return wsm;
    }

    public static void saveSerializedModel(WekaScoringModel wsm, File saveTo) throws Exception {

        Object model = wsm.getModel();
        Instances header = wsm.getHeader();
        OutputStream os = new FileOutputStream(saveTo);

        if (saveTo.getName().toLowerCase().endsWith(".gz")) { //$NON-NLS-1$
            os = new GZIPOutputStream(os);
        }
        ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(os));

        oos.writeObject(model);
        oos.writeObject(header);
        oos.close();
    }

    /**
     * Finds a mapping between the attributes that a Weka model has been trained
     * with and the incoming Kettle row format. Returns an array of indices, where
     * the element at index 0 of the array is the index of the Kettle field that
     * corresponds to the first attribute in the Instances structure, the element
     * at index 1 is the index of the Kettle fields that corresponds to the second
     * attribute, ...
     * 
     * @param header the Instances header
     * @param inputRowMeta the meta data for the incoming rows
     * @return the mapping as an array of integer indices
     */
    public static int[] findMappings(Instances header, RowMetaInterface inputRowMeta) {
        // Instances header = m_model.getHeader();
        int[] mappingIndexes = new int[header.numAttributes()];

        HashMap<String, Integer> inputFieldLookup = new HashMap<String, Integer>();
        for (int i = 0; i < inputRowMeta.size(); i++) {
            ValueMetaInterface inField = inputRowMeta.getValueMeta(i);
            inputFieldLookup.put(inField.getName(), Integer.valueOf(i));
        }

        // check each attribute in the header against what is incoming
        for (int i = 0; i < header.numAttributes(); i++) {
            Attribute temp = header.attribute(i);
            String attName = temp.name();

            // look for a matching name
            Integer matchIndex = inputFieldLookup.get(attName);
            boolean ok = false;
            int status = NO_MATCH;
            if (matchIndex != null) {
                // check for type compatibility
                ValueMetaInterface tempField = inputRowMeta.getValueMeta(matchIndex.intValue());
                if (tempField.isNumeric() || tempField.isBoolean()) {
                    if (temp.isNumeric()) {
                        ok = true;
                        status = 0;
                    } else {
                        status = TYPE_MISMATCH;
                    }
                } else if (tempField.isString()) {
                    if (temp.isNominal() || temp.isString()) {
                        ok = true;
                        status = 0;
                        // All we can assume is that this input field is ok.
                        // Since we wont know what the possible values are
                        // until the data is pumping throug, we will defer
                        // the matching of legal values until then
                    } else {
                        status = TYPE_MISMATCH;
                    }
                } else {
                    // any other type is a mismatch (might be able to do
                    // something with dates at some stage)
                    status = TYPE_MISMATCH;
                }
            }
            if (ok) {
                mappingIndexes[i] = matchIndex.intValue();
            } else {
                // mark this attribute as missing or type mismatch
                mappingIndexes[i] = status;
            }
        }
        return mappingIndexes;
    }

    /**
     * Generates a batch of predictions (more specifically, an array of output
     * rows containing all input Kettle fields plus new fields that hold the
     * prediction(s)) for each incoming Kettle row given a Weka model.
     * 
     * @param inputMeta the meta data for the incoming rows
     * @param outputMeta the meta data for the output rows
     * @param inputRow the values of the incoming row
     * @param meta meta data for this step
     * @return a Kettle row containing all incoming fields along with new ones
     *         that hold the prediction(s)
     * @exception Exception if an error occurs
     */
    public Object[][] generatePredictions(RowMetaInterface inputMeta, RowMetaInterface outputMeta,
            List<Object[]> inputRows, WekaScoringMeta meta) throws Exception {

        int[] mappingIndexes = m_mappingIndexes;
        WekaScoringModel model = getModel(); // copy of the model for this copy of
                                             // the step
        boolean outputProbs = meta.getOutputProbabilities();
        boolean supervised = model.isSupervisedLearningModel();

        Attribute classAtt = null;
        if (supervised) {
            classAtt = model.getHeader().classAttribute();
        }

        Instances batch = new Instances(model.getHeader(), inputRows.size());
        for (Object[] r : inputRows) {
            Instance inst = constructInstance(inputMeta, r, mappingIndexes, model, true);
            batch.add(inst);
        }

        double[][] preds = model.distributionsForInstances(batch);

        Object[][] result = new Object[preds.length][];
        for (int i = 0; i < preds.length; i++) {
            // First copy the input data to the new result...
            Object[] resultRow = RowDataUtil.resizeArray(inputRows.get(i), outputMeta.size());
            int index = inputMeta.size();

            double[] prediction = preds[i];

            if (prediction.length == 1 || !outputProbs) {
                if (supervised) {
                    if (classAtt.isNumeric()) {
                        Double newVal = new Double(prediction[0]);
                        resultRow[index++] = newVal;
                    } else {
                        int maxProb = Utils.maxIndex(prediction);
                        if (prediction[maxProb] > 0) {
                            String newVal = classAtt.value(maxProb);
                            resultRow[index++] = newVal;
                        } else {
                            String newVal = BaseMessages.getString(WekaScoringMeta.PKG,
                                    "WekaScoringData.Message.UnableToPredict"); //$NON-NLS-1$
                            resultRow[index++] = newVal;
                        }
                    }
                } else {
                    int maxProb = Utils.maxIndex(prediction);
                    if (prediction[maxProb] > 0) {
                        Double newVal = new Double(maxProb);
                        resultRow[index++] = newVal;
                    } else {
                        String newVal = BaseMessages.getString(WekaScoringMeta.PKG,
                                "WekaScoringData.Message.UnableToPredictCluster"); //$NON-NLS-1$
                        resultRow[index++] = newVal;
                    }
                }
            } else {
                // output probability distribution
                for (int j = 0; j < prediction.length; j++) {
                    Double newVal = new Double(prediction[j]);
                    resultRow[index++] = newVal;
                }
            }

            result[i] = resultRow;
        }

        return result;
    }

    /**
     * Generates a prediction (more specifically, an output row containing all
     * input Kettle fields plus new fields that hold the prediction(s)) for an
     * incoming Kettle row given a Weka model.
     * 
     * @param inputMeta the meta data for the incoming rows
     * @param outputMeta the meta data for the output rows
     * @param inputRow the values of the incoming row
     * @param meta meta data for this step
     * @return a Kettle row containing all incoming fields along with new ones
     *         that hold the prediction(s)
     * @exception Exception if an error occurs
     */
    public Object[] generatePrediction(RowMetaInterface inputMeta, RowMetaInterface outputMeta, Object[] inputRow,
            WekaScoringMeta meta) throws Exception {

        int[] mappingIndexes = m_mappingIndexes;
        WekaScoringModel model = getModel();
        boolean outputProbs = meta.getOutputProbabilities();
        boolean supervised = model.isSupervisedLearningModel();

        Attribute classAtt = null;
        if (supervised) {
            classAtt = model.getHeader().classAttribute();
        }

        // need to construct an Instance to represent this
        // input row
        Instance toScore = constructInstance(inputMeta, inputRow, mappingIndexes, model, false);
        double[] prediction = model.distributionForInstance(toScore);

        // Update the model??
        if (meta.getUpdateIncrementalModel() && model.isUpdateableModel()
                && !toScore.isMissing(toScore.classIndex())) {
            model.update(toScore);
        }
        // First copy the input data to the new result...
        Object[] resultRow = RowDataUtil.resizeArray(inputRow, outputMeta.size());
        int index = inputMeta.size();

        // output for numeric class or discrete class value
        if (prediction.length == 1 || !outputProbs) {
            if (supervised) {
                if (classAtt.isNumeric()) {
                    Double newVal = new Double(prediction[0]);
                    resultRow[index++] = newVal;
                } else {
                    int maxProb = Utils.maxIndex(prediction);
                    if (prediction[maxProb] > 0) {
                        String newVal = classAtt.value(maxProb);
                        resultRow[index++] = newVal;
                    } else {
                        String newVal = BaseMessages.getString(WekaScoringMeta.PKG,
                                "WekaScoringData.Message.UnableToPredict"); //$NON-NLS-1$
                        resultRow[index++] = newVal;
                    }
                }
            } else {
                int maxProb = Utils.maxIndex(prediction);
                if (prediction[maxProb] > 0) {
                    Double newVal = new Double(maxProb);
                    resultRow[index++] = newVal;
                } else {
                    String newVal = BaseMessages.getString(WekaScoringMeta.PKG,
                            "WekaScoringData.Message.UnableToPredictCluster"); //$NON-NLS-1$
                    resultRow[index++] = newVal;
                }
            }
        } else {
            // output probability distribution
            for (int i = 0; i < prediction.length; i++) {
                Double newVal = new Double(prediction[i]);
                resultRow[index++] = newVal;
            }
        }

        return resultRow;
    }

    /**
     * Helper method that constructs an Instance to input to the Weka model based
     * on incoming Kettle fields and pre-constructed attribute-to-field mapping
     * data.
     * 
     * @param inputMeta a <code>RowMetaInterface</code> value
     * @param inputRow an <code>Object</code> value
     * @param mappingIndexes an <code>int</code> value
     * @param model a <code>WekaScoringModel</code> value
     * @return an <code>Instance</code> value
     */
    private Instance constructInstance(RowMetaInterface inputMeta, Object[] inputRow, int[] mappingIndexes,
            WekaScoringModel model, boolean freshVector) {

        Instances header = model.getHeader();

        // Re-use this array (unless told otherwise) to avoid an object creation
        if (m_vals == null || freshVector) {
            m_vals = new double[header.numAttributes()];
        }

        for (int i = 0; i < header.numAttributes(); i++) {

            if (mappingIndexes[i] >= 0) {
                try {
                    Object inputVal = inputRow[mappingIndexes[i]];

                    Attribute temp = header.attribute(i);
                    ValueMetaInterface tempField = inputMeta.getValueMeta(mappingIndexes[i]);
                    int fieldType = tempField.getType();

                    // Check for missing value (null or empty string)
                    if (tempField.isNull(inputVal)) {
                        m_vals[i] = Utils.missingValue();
                        continue;
                    }

                    switch (temp.type()) {
                    case Attribute.NUMERIC: {
                        if (fieldType == ValueMetaInterface.TYPE_BOOLEAN) {
                            Boolean b = tempField.getBoolean(inputVal);
                            if (b.booleanValue()) {
                                m_vals[i] = 1.0;
                            } else {
                                m_vals[i] = 0.0;
                            }
                        } else if (fieldType == ValueMetaInterface.TYPE_INTEGER) {
                            Long t = tempField.getInteger(inputVal);
                            m_vals[i] = t.longValue();
                        } else {
                            Double n = tempField.getNumber(inputVal);
                            m_vals[i] = n.doubleValue();
                        }
                    }
                        break;
                    case Attribute.NOMINAL: {
                        String s = tempField.getString(inputVal);
                        // now need to look for this value in the attribute
                        // in order to get the correct index
                        int index = temp.indexOfValue(s);
                        if (index < 0) {
                            // set to missing value
                            m_vals[i] = Utils.missingValue();
                        } else {
                            m_vals[i] = index;
                        }
                    }
                        break;
                    case Attribute.STRING: {
                        String s = tempField.getString(inputVal);
                        // Set the attribute in the header to contain just this string value
                        temp.setStringValue(s);
                        m_vals[i] = 0.0;
                        break;
                    }
                    default:
                        m_vals[i] = Utils.missingValue();
                    }
                } catch (Exception e) {
                    m_vals[i] = Utils.missingValue();
                }
            } else {
                // set to missing value
                m_vals[i] = Utils.missingValue();
            }
        }

        Instance newInst = new DenseInstance(1.0, m_vals);
        newInst.setDataset(header);
        return newInst;
    }
}