gate.plugin.learningframework.engines.EngineWeka.java Source code

Java tutorial

Introduction

Here is the source code for gate.plugin.learningframework.engines.EngineWeka.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */

package gate.plugin.learningframework.engines;

import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gate.Annotation;
import gate.AnnotationSet;
import gate.plugin.learningframework.GateClassification;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.data.CorpusRepresentationWeka;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.core.Instances;

/**
 *
 * @author Johann Petrak
 */
public class EngineWeka extends Engine {

    @Override
    public Object evaluateHoldout(InstanceList instances, double portion, int repeats, String parms) {
        // Get the parameters 
        // -s/-seed <int> : seed, default 0
        // -S/-nostratify : switch off stratification if we evaluate classification
        Parms opts = new Parms(parms, "s:seed:i", "S:nostratify:b");
        int seed = (int) opts.getValueOrElse("seed", 0);
        if (algorithm instanceof AlgorithmRegression) {
            throw new UnsupportedOperationException("Weka holdout eval for regression not supported yet.");
        } else {
            // must be classification algorithm then!
            weka.core.Instances all = new CorpusRepresentationWeka(corpusRepresentationMallet)
                    .getRepresentationWeka();
            boolean noStratify = (boolean) opts.getValueOrElse("nostratify", 0);
            Random rand = new Random(seed);
            all.randomize(rand);
            boolean stratified = !noStratify;
            // TODO: not sure if/how we can do stratification for holdout evaluation
            // TODO: there must be a better way to do the splitting too!
            // TODO: if there is no better way to split, maybe do out outside for
            // TODO: how to implement repeats?
            if (repeats != 1) {
                throw new GateRuntimeException("Only repeats == 1 supported yet");
            }
            // both regression and classification?
            int trainSize = (int) Math.round(all.numInstances() * portion);
            int testSize = all.numInstances() - trainSize;
            Instances train = new Instances(all, 0, trainSize);
            Instances test = new Instances(all, trainSize, testSize);
            Classifier classifier = (Classifier) trainer;
            try {
                classifier.buildClassifier(train);
            } catch (Exception ex) {
                throw new GateRuntimeException("Error during training of Weka classifier", ex);
            }
            Evaluation eval = null;
            try {
                eval = new Evaluation(train);
            } catch (Exception ex) {
                throw new GateRuntimeException("Could not create Evaluation object", ex);
            }
            try {
                eval.evaluateModel(classifier, test);
            } catch (Exception ex) {
                throw new GateRuntimeException("Error evaluating the classifier", ex);
            }
            System.out.println("Evaluation result:\n" + eval);
            return eval;
        }
    }

    @Override
    public Object evaluateXVal(InstanceList instances, int k, String parms) {
        Parms opts = new Parms(parms, "s:seed:i", "S:nostratify:b");
        int seed = (int) opts.getValueOrElse("seed", 0);
        boolean noStratify = (boolean) opts.getValueOrElse("nostratify", 0);
        Random rand = new Random(seed);
        Instances all = new CorpusRepresentationWeka(corpusRepresentationMallet).getRepresentationWeka();
        Evaluation eval = null;
        try {
            eval = new Evaluation(all);
        } catch (Exception ex) {
            throw new GateRuntimeException("Could not create evaluation object", ex);
        }
        Classifier classifier = (Classifier) trainer;
        try {
            eval.crossValidateModel(classifier, all, k, rand);
        } catch (Exception ex) {
            throw new GateRuntimeException("Error running cross validation", ex);
        }
        System.out.println("Crossvaliation evaluation result:\n" + eval);
        return eval;
    }

    @Override
    protected void loadModel(File directory, String parms) {
        // when this is called, info should already be set
        // we create the instance of the training algorithm from the infor and the instance of
        // the actual trained classifier from de-serialization of the file
        File modelFile = new File(directory, FILENAME_MODEL);
        try {
            ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
            model = (Classifier) ois.readObject();
            System.out.println("Loaded Weka model " + model.getClass());
            ois.close();
        } catch (Exception e) {
            throw new GateRuntimeException("Could not load Weka model file " + modelFile, e);
        }
        try {
            trainer = Class.forName(info.trainerClass).newInstance();
        } catch (Exception ex) {
            throw new GateRuntimeException("Could not create Weka trainer instance for " + info.trainerClass, ex);
        }
    }

    @Override
    public void trainModel(String parms) {
        // TODO: process parameters!
        if (trainer == null) {
            throw new GateRuntimeException("Cannot train Weka model, not trainer initialized");
        }
        Classifier alg = (Classifier) trainer;
        // convert the Mallet representation to Weka instances
        CorpusRepresentationWeka crw = new CorpusRepresentationWeka(corpusRepresentationMallet);

        try {
            alg.buildClassifier(crw.getRepresentationWeka());
            // set the trained model to the algorithm, in Weka they are identical
            model = alg;
        } catch (Exception ex) {
            throw new GateRuntimeException("Error during training of Weka algorithm " + alg.getClass(), ex);
        }
        updateInfo();
    }

    CorpusRepresentationWeka crWeka;

    /**
     * Return the Weka dataset stored with this engine.
     * This exists mainly so that the engine nows the features, the dataset
     * itself is empty.
     * @return 
     */
    public CorpusRepresentationWeka getCorpusRepresentationWeka() {
        return crWeka;
    }

    @Override
    protected void loadMalletCorpusRepresentation(File directory) {
        corpusRepresentationMallet = CorpusRepresentationMalletTarget.load(directory);
        crWeka = new CorpusRepresentationWeka(corpusRepresentationMallet);
    }

    @Override
    public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet inputAS,
            AnnotationSet sequenceAS, String parms) {

        Instances instances = crWeka.getRepresentationWeka();
        CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget) corpusRepresentationMallet;
        data.stopGrowth();
        List<GateClassification> gcs = new ArrayList<GateClassification>();
        LFPipe pipe = (LFPipe) data.getRepresentationMallet().getPipe();
        Classifier wekaClassifier = (Classifier) model;
        // iterate over the instance annotations and create mallet instances 
        for (Annotation instAnn : instanceAS.inDocumentOrder()) {
            Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
            inst = pipe.instanceFrom(inst);
            // Convert to weka Instance
            weka.core.Instance wekaInstance = CorpusRepresentationWeka.wekaInstanceFromMalletInstance(instances,
                    inst);
            // classify with the weka classifier or predict the numeric value: if the mallet pipe does have
            // a target alphabet we assume classification, otherwise we assume regression
            GateClassification gc = null;
            if (pipe.getTargetAlphabet() == null) {
                // regression
                double result = Double.NaN;
                try {
                    result = wekaClassifier.classifyInstance(wekaInstance);
                } catch (Exception ex) {
                    // Hmm, for now we just log the error and continue, not sure if we should stop here!
                    ex.printStackTrace(System.err);
                    Logger.getLogger(EngineWeka.class.getName()).log(Level.SEVERE, null, ex);
                }
                //gc = new GateClassification(instAnn, (result==Double.NaN ? null : String.valueOf(result)), 1.0);
                gc = new GateClassification(instAnn, result);
            } else {
                // classification

                // Weka AbstractClassifier already handles the situation correctly when 
                // distributionForInstance is not implemented by the classifier: in that case
                // is calls classifyInstance and returns an array of size numClasses where
                // the entry of the target class is set to 1.0 except when the classification is a missing
                // value, then all class probabilities will be 0.0
                // If distributionForInstance is implemented for the algorithm, we should get
                // the probabilities or all zeros for missing class from the algorithm.
                double[] predictionDistribution = new double[0];
                try {
                    //System.err.println("classifying instance "+wekaInstance.toString());
                    predictionDistribution = wekaClassifier.distributionForInstance(wekaInstance);
                } catch (Exception ex) {
                    throw new RuntimeException(
                            "Weka classifier error in document " + instanceAS.getDocument().getName(), ex);
                }
                // This is classification, we should always get a distribution list > 1
                if (predictionDistribution.length < 2) {
                    throw new RuntimeException("Classifier returned less than 2 probabilities: "
                            + predictionDistribution.length + "for instance" + wekaInstance);
                }
                double bestprob = 0.0;
                int bestlabel = 0;
                /*
                System.err.print("DEBUG: got classes from pipe: ");
                  Object[] cls = pipe.getTargetAlphabet().toArray();
                boolean first = true;
                for(Object cl : cls) {
                  if(first) { first = false; } else { System.err.print(", "); }
                  System.err.print(">"+cl+"<");
                }
                System.err.println();
                 */
                List<String> classList = new ArrayList<String>();
                List<Double> confidenceList = new ArrayList<Double>();
                for (int i = 0; i < predictionDistribution.length; i++) {
                    int thislabel = i;
                    double thisprob = predictionDistribution[i];
                    String labelstr = (String) pipe.getTargetAlphabet().lookupObject(thislabel);
                    classList.add(labelstr);
                    confidenceList.add(thisprob);
                    if (thisprob > bestprob) {
                        bestlabel = thislabel;
                        bestprob = thisprob;
                    }
                } // end for i < predictionDistribution.length

                String cl = (String) pipe.getTargetAlphabet().lookupObject(bestlabel);

                gc = new GateClassification(instAnn, cl, bestprob, classList, confidenceList);
            }
            gcs.add(gc);
        }
        data.startGrowth();
        return gcs;
    }

    @Override
    public void initializeAlgorithm(Algorithm algorithm, String parms) {
        Class trainerClass = algorithm.getTrainerClass();
        System.err.println("LF DEBUG: trying to initialize trainer class " + trainerClass);
        try {
            trainer = trainerClass.newInstance();
        } catch (Exception ex) {
            throw new GateRuntimeException("Could not create Weka trining algorithm for class " + trainerClass);
        }
    }

    @Override
    public void saveModel(File directory) {
        File outFile = new File(directory, FILENAME_MODEL);
        ObjectOutputStream oos = null;
        try {
            oos = new ObjectOutputStream(new FileOutputStream(outFile));
            oos.writeObject(model);
        } catch (Exception ex) {
            throw new GateRuntimeException("Could not save Weka model to " + outFile, ex);
        } finally {
            if (oos != null)
                try {
                    oos.close();
                } catch (IOException ex) {
                    // ignore
                }
        }
    }

}