OAT.trading.classification.Weka.java Source code

Java tutorial

Introduction

Here is the source code for OAT.trading.classification.Weka.java

Source

/*
Open Auto Trading : A fully automatic equities trading platform with machine learning capabilities
Copyright (C) 2015 AnyObject Ltd.
    
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
    
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.
    
You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package OAT.trading.classification;

import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import OAT.util.GeneralUtil;
import OAT.util.TextUtil;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.*;

/**
 * Weka 3.6.4.
 *
 * <p> Mark Hall, Eibe Frank, Geoffrey Holmes, Bernhard Pfahringer, Pe- ter
 * Reutemann, Ian H. Witten (2009); The WEKA Data Mining Software: An Update;
 * SIGKDD Explorations, Volume 11, Issue 1.
 *
 * @author Antonio Yip
 */
public abstract class Weka extends AbstractPredictor {

    private Classifier classifier;
    private FastVector attributes;
    private FastVector classes;

    /**
     * Returns the class name of the classifier.
     *
     * @return
     */
    protected abstract String getClassifierType();

    /**
     * Returns a list of options.
     *
     * @return an array of String
     */
    protected abstract String[] getOptions();

    private void initClassifier() {
        try {
            initClassifier((Classifier) GeneralUtil.forName(Classifier.class, getClassifierType()));
        } catch (Exception ex) {
            log(Level.SEVERE, null, ex);
        }
    }

    private void initClassifier(Classifier classifier) {
        if (classifier == null) {
            return;
        }

        this.classifier = classifier;

        try {
            classifier.setOptions(getOptions());
        } catch (Exception ex) {
            log(Level.SEVERE, null, ex);
        }

        if (!isCrossValidating()) {
            log(Level.INFO, "Classifier: {0} {1}", new Object[] { classifier.getClass().getSimpleName(),
                    TextUtil.toString(classifier.getOptions()) });
        }

        attributes = new FastVector();
        for (String string : getAttributeNames()) {
            attributes.addElement(new Attribute(string));
        }

        classes = new FastVector();
        classes.addElement("0");
        classes.addElement("1");

        attributes.addElement(new Attribute(DEFAULT_CLASS_COLUMN, classes));
    }

    @Override
    public void loadModel(String file) {
        try {
            classifier = (Classifier) SerializationHelper.read(file);
        } catch (Exception ex) {
            log(Level.SEVERE, null, ex);
        }
    }

    @Override
    public void saveModel(String file) {
        try {
            SerializationHelper.write(file, classifier);
        } catch (Exception ex) {
            log(Level.SEVERE, null, ex);
        }
    }

    @Override
    public Prediction predict(InputSample input) {
        if (classifier == null) {
            log(Level.WARNING, "null classifier");
            return null;
        }

        Instances data = getInstances(input);

        if (data == null) {
            log(Level.WARNING, "null data");
            return null;
        }

        if (!isCrossValidating()) {
            if (isLoggable(Level.FINER)) {
                log(Level.FINER, data.toString());
            }
        }

        try {
            double output = new Evaluation(data).evaluateModelOnce(classifier, data.firstInstance());

            return Prediction.valueOf(output < 0.5 ? -1 : 1);
        } catch (Exception ex) {
            log(Level.SEVERE, null, ex);
        }

        return null;
    }

    @Override
    public void train(List<TrainingSample> trainingSet) {
        initClassifier();

        if (classifier == null) {
            log(Level.WARNING, "null classifier");
            return;
        }

        Instances data = getInstances(trainingSet);

        if (data == null) {
            log(Level.WARNING, "null data");
            return;
        }

        if (!isCrossValidating()) {
            log(Level.FINE, "Training set size: {0}", data.numInstances());

            if (isLoggable(Level.FINER)) {
                log(Level.FINER, data.toString());
            }
        }

        try {
            classifier.buildClassifier(data);

        } catch (UnsupportedAttributeTypeException ex) {
            log(Level.WARNING, "{1}\nCapabilities: {0}",
                    new Object[] { ex.getMessage(), classifier.getCapabilities() });

        } catch (Exception ex) {
            log(Level.SEVERE, null, ex);
        }
    }

    private Instances getInstances(List<TrainingSample> trainingSet) {
        Instances data = new Instances("trainingSet", attributes, 0);

        for (TrainingSample trainingSample : trainingSet) {
            double[] vars = Arrays.copyOf(trainingSample.getInputVector(), attributes.size());

            int classIndex = attributes.size() - 1;
            vars[classIndex] = (Double) trainingSample.getDesiredOutput() < 0.5 ? classes.indexOf("0")
                    : classes.indexOf("1");

            data.add(new Instance(1.0, vars));
        }

        data.setClassIndex(attributes.size() - 1);

        return data;
    }

    private Instances getInstances(InputSample input) {
        Instances data = new Instances("inputSet", attributes, 0);

        double[] vars = Arrays.copyOf(input.getInputVector(), attributes.size());

        int classIndex = attributes.size() - 1;
        vars[classIndex] = classes.indexOf("0");

        data.add(new Instance(1.0, vars));

        data.setClassIndex(attributes.size() - 1);

        return data;
    }
}