com.rapidminer.tools.WekaInstancesAdaptor.java Source code

Java tutorial

Introduction

Here is the source code for com.rapidminer.tools.WekaInstancesAdaptor.java

Source

/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2008 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.tools;

import java.util.Enumeration;
import java.util.Iterator;

import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.example.Statistics;
import com.rapidminer.operator.OperatorException;

/**
 * This class extends the Weka class Instances and overrides all methods needed
 * to directly use a RapidMiner {@link ExampleSet} as source for Weka instead of
 * copying the complete data.
 * 
 * @author Ingo Mierswa
 * @version $Id: WekaInstancesAdaptor.java,v 2.13 2006/04/06 15:23:38
 *          ingomierswa Exp $
 */
public class WekaInstancesAdaptor extends Instances {

    private static final long serialVersionUID = 99943154106235423L;

    /** Task type: the label attribute is appended as the Weka class attribute. */
    public static final int LEARNING = 0;

    /**
     * Task type: the predicted label attribute is appended as the Weka class
     * attribute and its value is reported as missing (NaN) for every instance.
     */
    public static final int PREDICTING = 1;

    /** Task type: no class attribute is used (class index -1). */
    public static final int CLUSTERING = 2;

    /**
     * Task type: no class attribute is used and the most frequent value of each
     * nominal attribute is reported as missing (NaN).
     */
    public static final int ASSOCIATION_RULE_MINING = 3;

    /** Task type: the label attribute is appended as class attribute only if one exists. */
    public static final int WEIGHTING = 4;

    /**
     * This enumeration implementation uses an example iterator to enumerate
     * the instances. Each example is converted to a Weka instance on demand.
     */
    private class InstanceEnumeration implements Enumeration<Instance> {

        private final Iterator<Example> reader;

        public InstanceEnumeration(Iterator<Example> reader) {
            this.reader = reader;
        }

        /** Converts the next example delivered by the iterator into a Weka instance. */
        public Instance nextElement() {
            return toWekaInstance(reader.next());
        }

        public boolean hasMoreElements() {
            return reader.hasNext();
        }
    }

    /** The example set which backs up the Instances object. */
    private ExampleSet exampleSet;

    /**
     * This transformation might help to speed up the creation of sparse examples.
     * It is transient and may be null after deserialization if it could not be
     * re-created; in that case values are copied attribute by attribute.
     */
    private transient FastExample2SparseTransform exampleTransform;

    /**
     * The most frequent nominal values (only used for association rule mining, null otherwise).
     * -1 if attribute is numerical.
     */
    private int[] mostFrequent = null;

    /**
     * The task type for which this instances object is used. Must be one out of
     * LEARNING, PREDICTING, CLUSTERING, ASSOCIATION_RULE_MINING, or WEIGHTING.
     * For CLUSTERING and ASSOCIATION_RULE_MINING the label attribute is omitted;
     * for WEIGHTING it is used only if the example set has one.
     */
    private int taskType = LEARNING;

    /** The label attribute or null if not desired (depending on task). */
    private Attribute labelAttribute = null;

    /** The weight attribute or null if not available. */
    private Attribute weightAttribute = null;

    /**
     * Creates a new Instances object based on the given example set.
     *
     * @param name       the relation name handed to the Weka super class
     * @param exampleSet the example set providing attributes and data
     * @param taskType   one of LEARNING, PREDICTING, CLUSTERING,
     *                   ASSOCIATION_RULE_MINING, or WEIGHTING
     * @throws OperatorException if the sparse transformation cannot be created
     */
    public WekaInstancesAdaptor(String name, ExampleSet exampleSet, int taskType) throws OperatorException {
        super(name, getAttributeVector(exampleSet, taskType), 0);
        this.exampleSet = exampleSet;
        this.taskType = taskType;
        this.weightAttribute = exampleSet.getAttributes().getWeight();
        this.exampleTransform = new FastExample2SparseTransform(exampleSet);

        // the class attribute (if any) is always appended after the regular attributes
        int appendedClassIndex = exampleSet.getAttributes().size();
        switch (taskType) {
        case LEARNING:
            labelAttribute = exampleSet.getAttributes().getLabel();
            setClassIndex(appendedClassIndex);
            break;
        case PREDICTING:
            labelAttribute = exampleSet.getAttributes().getPredictedLabel();
            setClassIndex(appendedClassIndex);
            break;
        case CLUSTERING:
            labelAttribute = null;
            setClassIndex(-1);
            break;
        case ASSOCIATION_RULE_MINING:
            // in case of association learning the most frequent attribute
            // value is needed in order to set it to "unknown"
            this.mostFrequent = computeMostFrequentValues(exampleSet);
            labelAttribute = null;
            setClassIndex(-1);
            break;
        case WEIGHTING:
            labelAttribute = exampleSet.getAttributes().getLabel();
            if (labelAttribute != null) {
                setClassIndex(appendedClassIndex);
            }
            break;
        }
    }

    /**
     * Recalculates the attribute statistics of the example set and returns, per
     * regular attribute, the mode for nominal attributes and -1 otherwise.
     */
    private static int[] computeMostFrequentValues(ExampleSet exampleSet) {
        exampleSet.recalculateAllAttributeStatistics();
        int[] modes = new int[exampleSet.getAttributes().size()];
        int index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (attribute.isNominal()) {
                modes[index] = (int) exampleSet.getStatistics(attribute, Statistics.MODE);
            } else {
                modes[index] = -1;
            }
            index++;
        }
        return modes;
    }

    /**
     * Re-creates the transient sparse transformation after deserialization.
     * This is best effort: if the transformation cannot be created the field
     * stays null and instances are built by copying every attribute value.
     */
    protected Object readResolve() {
        try {
            this.exampleTransform = new FastExample2SparseTransform(this.exampleSet);
        } catch (OperatorException e) {
            // deliberately not propagated: toWekaInstance(...) handles a missing
            // transformation by falling back to a dense copy of the values
            this.exampleTransform = null;
        }
        return this;
    }

    // ================================================================================
    // Overriding some Weka methods
    // ================================================================================

    /** Returns an instance enumeration based on the iterator of the example set. */
    public Enumeration enumerateInstances() {
        return new InstanceEnumeration(exampleSet.iterator());
    }

    /** Returns the i-th instance, created on the fly from the i-th example. */
    public Instance instance(int i) {
        return toWekaInstance(exampleSet.getExample(i));
    }

    /** Returns the number of instances, i.e. the size of the example set. */
    public int numInstances() {
        return exampleSet.size();
    }

    // ================================================================================
    // Transforming examples into Weka instances
    // ================================================================================

    /**
     * Gets an example and creates a Weka instance. The value array holds the
     * regular attribute values followed by the class value (if a label
     * attribute is used for the current task).
     */
    private Instance toWekaInstance(Example example) {
        int numberOfRegularValues = example.getAttributes().size();
        int numberOfValues = numberOfRegularValues + (labelAttribute != null ? 1 : 0);
        double[] values = new double[numberOfValues];

        // set regular attribute values
        if (taskType == ASSOCIATION_RULE_MINING) {
            // sets the most frequent value to missing for association learning
            copyRegularValues(example, values, true);
        } else if (exampleTransform != null) {
            // only the non-default entries are written, all others stay 0.0
            int[] nonDefaultIndices = exampleTransform.getNonDefaultAttributeIndices(example);
            double[] nonDefaultValues = exampleTransform.getNonDefaultAttributeValues(example, nonDefaultIndices);
            for (int a = 0; a < nonDefaultIndices.length; a++) {
                values[nonDefaultIndices[a]] = nonDefaultValues[a];
            }
        } else {
            // no sparse transformation available (see readResolve()): copy all values
            copyRegularValues(example, values, false);
        }

        // set label value if necessary
        switch (taskType) {
        case LEARNING:
            values[values.length - 1] = example.getValue(labelAttribute);
            break;
        case PREDICTING:
            values[values.length - 1] = Double.NaN;
            break;
        case WEIGHTING:
            if (labelAttribute != null) {
                values[values.length - 1] = example.getValue(labelAttribute);
            }
            break;
        default:
            break;
        }

        // get instance weight
        double weight = 1.0d;
        if (this.weightAttribute != null) {
            weight = example.getValue(this.weightAttribute);
        }

        // create new instance
        Instance instance = new Instance(weight, values);
        instance.setDataset(this);
        return instance;
    }

    /**
     * Copies the values of all regular attributes of the example into the
     * leading positions of the given array.
     *
     * @param maskMostFrequent if true, a nominal value equal to the stored most
     *                         frequent value of its attribute is replaced by NaN
     */
    private void copyRegularValues(Example example, double[] values, boolean maskMostFrequent) {
        int index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            double value = example.getValue(attribute);
            if (maskMostFrequent && attribute.isNominal() && value == mostFrequent[index]) {
                value = Double.NaN;
            }
            values[index] = value;
            index++;
        }
    }

    // ================================================================================

    /**
     * Builds the Weka attribute vector for the given example set: all regular
     * attributes followed by the task dependent label attribute (if any).
     */
    private static FastVector getAttributeVector(ExampleSet exampleSet, int taskType) {
        // determine label
        Attribute label = null;
        switch (taskType) {
        case LEARNING:
        case WEIGHTING:
            label = exampleSet.getAttributes().getLabel();
            break;
        case PREDICTING:
            label = exampleSet.getAttributes().getPredictedLabel();
            break;
        default:
            break;
        }

        // add regular attributes
        int capacity = exampleSet.getAttributes().size() + (label != null ? 1 : 0);
        FastVector attributeVector = new FastVector(capacity);
        for (Attribute attribute : exampleSet.getAttributes()) {
            attributeVector.addElement(toWekaAttribute(attribute));
        }

        // add label
        if (label != null) {
            attributeVector.addElement(toWekaAttribute(label));
        }

        return attributeVector;
    }

    /**
     * Converts an {@link Attribute} to a Weka attribute. Nominal attributes are
     * created with their mapped values in index order, all other attributes
     * are created by name only. Returns null for a null argument.
     */
    private static weka.core.Attribute toWekaAttribute(Attribute attribute) {
        if (attribute == null) {
            return null;
        }
        if (!Ontology.ATTRIBUTE_VALUE_TYPE.isA(attribute.getValueType(), Ontology.NOMINAL)) {
            return new weka.core.Attribute(attribute.getName());
        }
        // the number of mapped values is loop invariant, so it is determined only once
        int numberOfNominalValues = attribute.getMapping().getValues().size();
        FastVector nominalValues = new FastVector(numberOfNominalValues);
        for (int i = 0; i < numberOfNominalValues; i++) {
            nominalValues.addElement(attribute.getMapping().mapIndex(i));
        }
        return new weka.core.Attribute(attribute.getName(), nominalValues);
    }
}