Example usage for the weka.classifiers.Evaluation constructor Evaluation(Instances)

List of usage examples for the weka.classifiers.Evaluation constructor Evaluation(Instances)

Introduction

On this page you can find example usages of the weka.classifiers.Evaluation constructor, Evaluation(Instances).

Prototype

public Evaluation(Instances data) throws Exception 

Source Link

Usage

From source file:boostingPL.boosting.SAMME.java

License:Open Source License

/**
 * Trains a SAMME boosting ensemble on one ARFF file and evaluates it on another.
 *
 * <p>Usage: {@code SAMME <training.arff> <testing.arff>}. In both files the last attribute is
 * used as the class. The ensemble is boosted for a fixed 100 iterations, then the summary,
 * per-class details and confusion matrix are printed to standard output.
 *
 * @param args {@code args[0]} is the training ARFF file, {@code args[1]} the testing ARFF file
 * @throws IllegalArgumentException if fewer than two arguments are supplied
 * @throws Exception if loading, training or evaluation fails
 */
public static void main(String[] args) throws Exception {
    if (args.length < 2) {
        throw new IllegalArgumentException("Usage: SAMME <training.arff> <testing.arff>");
    }

    Instances training = loadArffWithLastClass(args[0]);

    int iterationNum = 100;
    SAMME samme = new SAMME(training, iterationNum);
    for (int t = 0; t < iterationNum; t++) {
        samme.run(t);
    }

    Instances testing = loadArffWithLastClass(args[1]);

    // The Evaluation is initialised from the testing data; predictions are then recorded
    // one instance at a time.
    Evaluation eval = new Evaluation(testing);
    for (Instance inst : testing) {
        eval.evaluateModelOnceAndRecordPrediction(samme, inst);
    }
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toMatrixString());
}

/** Loads an ARFF file and marks its last attribute as the class attribute. */
private static Instances loadArffWithLastClass(String fileName) throws java.io.IOException {
    ArffLoader loader = new ArffLoader();
    loader.setFile(new java.io.File(fileName));
    Instances data = loader.getDataSet();
    data.setClassIndex(data.numAttributes() - 1);
    return data;
}

From source file:boostingPL.MR.AdaBoostPLTestMapper.java

License:Open Source License

/**
 * Prepares this mapper for testing a trained BoostingPL model.
 *
 * <p>Loads the boosting model from {@code <BoostingPL.modelPath>/part-r-00000}, builds the
 * testing-set header from the HDFS file named by {@code BoostingPL.metadata}, creates the
 * {@link Evaluation} used to accumulate results, and looks up the instance counter.
 *
 * @param context the task context supplying configuration and counters
 * @throws IOException if the model or metadata file cannot be read
 * @throws InterruptedException if the task is interrupted
 */
protected void setup(Context context) throws IOException, InterruptedException {
    // classifier file
    Path path = new Path(context.getConfiguration().get("BoostingPL.modelPath") + "/part-r-00000");
    String boostingName = context.getConfiguration().get("BoostingPL.boostingName");
    boostingPL = BoostingPLFactory.createBoostingPL(boostingName, context.getConfiguration(), path);

    // testing dataset metadata; try-with-resources closes the HDFS stream even if parsing fails
    String pathSrc = context.getConfiguration().get("BoostingPL.metadata");
    FileSystem hdfs = FileSystem.get(context.getConfiguration());
    try (FSDataInputStream dis = hdfs.open(new Path(pathSrc))) {
        LineReader in = new LineReader(dis);
        insts = InstancesHelper.createInstancesFromMetadata(in);
        in.close();
    }

    try {
        eval = new Evaluation(insts);
    } catch (Exception e) {
        // Best-effort: the task continues with eval left unchanged, but the cause is logged.
        LOG.error("[BoostingPL-Test]: Evaluation init error!", e);
    }
    instanceCounter = context.getCounter("BoostingPL", "Number of instances");
}

From source file:boostingPL.MR.AdaBoostPLTestReducer.java

License:Open Source License

/**
 * Prepares this reducer for testing a trained BoostingPL model.
 *
 * <p>Loads the boosting model from {@code <BoostingPL.modelPath>/part-r-00000}, builds the
 * testing-set header from the HDFS file named by {@code BoostingPL.metadata}, and creates the
 * {@link Evaluation} used to accumulate results.
 *
 * @param context the task context supplying configuration
 * @throws IOException if the model or metadata file cannot be read
 * @throws InterruptedException if the task is interrupted
 */
protected void setup(Context context) throws IOException, InterruptedException {
    // classifier file
    Path path = new Path(context.getConfiguration().get("BoostingPL.modelPath") + "/part-r-00000");
    String boostingName = context.getConfiguration().get("BoostingPL.boostingName");
    boostingPL = BoostingPLFactory.createBoostingPL(boostingName, context.getConfiguration(), path);

    // testing dataset metadata; try-with-resources closes the HDFS stream even if parsing fails
    String pathSrc = context.getConfiguration().get("BoostingPL.metadata");
    FileSystem hdfs = FileSystem.get(context.getConfiguration());
    try (FSDataInputStream dis = hdfs.open(new Path(pathSrc))) {
        LineReader in = new LineReader(dis);
        insts = InstancesHelper.createInstancesFromMetadata(in);
        in.close();
    }

    try {
        eval = new Evaluation(insts);
    } catch (Exception e) {
        // Best-effort: the task continues with eval left unchanged, but the cause is logged.
        LOG.error("[BoostingPL-Test]: Evaluation init error!", e);
    }
}

From source file:br.unicamp.ic.recod.gpsi.gp.gpsiJGAPRoiFitnessFunction.java

/**
 * Fitness of a GP individual: mean 5-fold cross-validated accuracy of a
 * {@code SimpleLogistic} classifier on the ROI descriptor dataset.
 *
 * <p>The whole dataset is loaded through {@code this.descriptor}, turned into a Weka
 * {@link Instances} set (numeric attributes {@code f0..f(d-1)} followed by a nominal
 * {@code class} attribute whose values are the labels), shuffled, and cross-validated.
 * Load and training errors are logged rather than propagated.
 *
 * @param igpp the GP program being evaluated
 * @return the sum of per-fold {@code pctCorrect()} values divided by {@code folds * 100},
 *         i.e. mean accuracy as a fraction; folds not reached because of an exception add 0
 */
@Override
protected double evaluate(IGPProgram igpp) {

    double mean_accuracy = 0.0;

    // NOTE(review): the combiner is constructed but not used yet (see TODO below).
    gpsiRoiBandCombiner roiBandCombinator = new gpsiRoiBandCombiner(new gpsiJGAPVoxelCombiner(super.b, igpp));
    // TODO: The ROI descriptors must combine the images first
    //roiBandCombinator.combineEntity(this.dataset.getTrainingEntities());

    gpsiMLDataset mlDataset = new gpsiMLDataset(this.descriptor);
    try {
        mlDataset.loadWholeDataset(this.dataset, true);
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    int dimensionality = mlDataset.getDimensionality();
    int n_classes = mlDataset.getTrainingEntities().keySet().size();
    int n_entities = mlDataset.getNumberOfTrainingEntities();
    ArrayList<Byte> listOfClasses = new ArrayList<>(mlDataset.getTrainingEntities().keySet());

    // Nominal class attribute: value index k is the string form of listOfClasses.get(k).
    FastVector fvClassVal = new FastVector(n_classes);
    for (byte label : listOfClasses) {
        fvClassVal.addElement(Integer.toString(label));
    }
    Attribute classes = new Attribute("class", fvClassVal);

    // Feature attributes first, class attribute last (at index == dimensionality).
    FastVector fvWekaAttributes = new FastVector(dimensionality + 1);
    for (int a = 0; a < dimensionality; a++) {
        fvWekaAttributes.addElement(new Attribute("f" + Integer.toString(a)));
    }
    fvWekaAttributes.addElement(classes);

    Instances instances = new Instances("Rel", fvWekaAttributes, n_entities);
    instances.setClassIndex(dimensionality);

    for (byte label : mlDataset.getTrainingEntities().keySet()) {
        // Nominal values are stored by index, so translate the raw label to its position
        // among the declared class values rather than using the label itself.
        double classValueIndex = listOfClasses.indexOf(label);
        for (double[] featureVector : mlDataset.getTrainingEntities().get(label)) {
            Instance iExample = new Instance(dimensionality + 1);
            for (int j = 0; j < dimensionality; j++) {
                iExample.setValue(j, featureVector[j]);
            }
            iExample.setValue(dimensionality, classValueIndex);
            instances.add(iExample);
        }
    }

    int folds = 5;
    // NOTE(review): unseeded Random makes the fitness of the same program vary between calls.
    Random rand = new Random();
    Instances randData = new Instances(instances);
    randData.randomize(rand);

    try {
        for (int fold = 0; fold < folds; fold++) {
            Classifier cModel = (Classifier) new SimpleLogistic();
            Instances trainingSet = randData.trainCV(folds, fold);
            Instances testingSet = randData.testCV(folds, fold);

            cModel.buildClassifier(trainingSet);

            Evaluation eTest = new Evaluation(trainingSet);
            eTest.evaluateModel(cModel, testingSet);

            mean_accuracy += eTest.pctCorrect();
        }
    } catch (Exception ex) {
        Logger.getLogger(gpsiJGAPRoiFitnessFunction.class.getName()).log(Level.SEVERE, null, ex);
    }

    // pctCorrect() is a percentage, hence the extra factor of 100.
    mean_accuracy /= (folds * 100);

    return mean_accuracy;
}

From source file:c4.pkg5crossv.Classifier.java

/**
 * Splits {@code ./src/data/irysy.arff} 60/40 at random, trains a J48 tree on the first part,
 * and prints 10-fold cross-validation results computed on the second part.
 *
 * <p>{@code crossValidateModel} trains fresh copies of {@code tree} for each fold, so the tree
 * built on {@code trainData} is not itself the model being evaluated.
 *
 * @throws Exception if loading, option parsing, training or evaluation fails
 */
public static void trainAndTest() throws FileNotFoundException, IOException, Exception {

    Instances data = DataLoad.loadData("./src/data/irysy.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // Random split of the dataset (unseeded, so the split differs between runs).
    data.randomize(new Random());
    double percent = 60.0;
    int trainSize = (int) Math.round(data.numInstances() * percent / 100);
    int testSize = data.numInstances() - trainSize;
    Instances trainData = new Instances(data, 0, trainSize);
    Instances testData = new Instances(data, trainSize, testSize);

    // J48 options: -U = unpruned tree, -M 10 = at least 10 instances per leaf.
    String[] options = Utils.splitOptions("-U -M 10");
    J48 tree = new J48();
    tree.setOptions(options);
    tree.buildClassifier(trainData);

    int folds = 10;
    Evaluation eval2 = new Evaluation(trainData);
    eval2.crossValidateModel(tree, testData, folds, new Random(1)); // 10-fold, seed 1
    // Print the cross-validation summary ("Wyniki" is Polish for "Results").
    System.out.println(eval2.toSummaryString("Wyniki:", false));
}

From source file:ca.uottawa.balie.WekaLearner.java

License:Open Source License

/**
 * Approximates the generalisation error of {@code m_Scheme} by 10-fold cross-validation on
 * the training set, using an unseeded random generator (results vary between calls).
 *
 * <p>Any exception is reported by printing its message to standard output.
 *
 * @return the evaluation holding the cross-validation statistics (e.g. mean absolute error);
 *         it may be only partially filled if cross-validation failed, and is {@code null} if
 *         the evaluation object itself could not be created
 */
public Evaluation EstimateConfidence() {
    final int folds = 10;
    Evaluation result = null;
    try {
        result = new Evaluation(m_TrainingSet);
        Random shuffler = new Random();
        result.crossValidateModel(m_Scheme, m_TrainingSet, folds, shuffler);
    } catch (Exception failure) {
        String reason = failure.getMessage();
        System.out.println(reason);
    }
    // Which error measure to look at depends on the application.
    return result;
}

From source file:ca.uottawa.balie.WekaLearner.java

License:Open Source License

/**
 * Evaluates the learned scheme on the testing set.
 *
 * <p>The {@link Evaluation} is initialised from the training set. As a side effect the
 * confusion matrix is stored in {@code m_ConfusionMatrix}.
 *
 * @return the Weka summary string followed by the confusion matrix, one tab-separated row
 *         per line
 * @throws IllegalStateException if the evaluation cannot be created or run (previously this
 *         surfaced as a NullPointerException or as a summary of a partial evaluation)
 */
public String TestModel() {
    if (DEBUG) {
        DebugInfo.Out("Testing on " + m_TestingSet.numInstances() + " instances");
    }
    Evaluation evaluation;
    try {
        evaluation = new Evaluation(m_TrainingSet);
        evaluation.evaluateModel(m_Scheme, m_TestingSet);
    } catch (Exception e) {
        throw new IllegalStateException("Model evaluation failed: " + e.getMessage(), e);
    }

    StringBuilder summary = new StringBuilder(evaluation.toSummaryString());
    summary.append("\n\nConfusion Matrix: \n\n");
    m_ConfusionMatrix = evaluation.confusionMatrix();
    for (int i = 0; i < m_ConfusionMatrix.length; ++i) {
        for (int j = 0; j < m_ConfusionMatrix[i].length; ++j) {
            summary.append(m_ConfusionMatrix[i][j]).append('\t');
        }
        summary.append('\n');
    }
    return summary.toString();
}

From source file:ca.uqac.florentinth.speakerauthentication.Learning.Learning.java

License:Apache License

/**
 * Trains the selected classifier on an ARFF dataset, optionally cross-validates it, and
 * serialises the trained model.
 *
 * <p>If the dataset has no class attribute set, the last attribute is used. For {@code KNN}
 * the number of neighbours is {@code ceil(sqrt(number of instances))}.
 *
 * @param classifier which algorithm to train
 * @param trainingDataset reader over the ARFF training data (not closed by this method)
 * @param trainingModel stream the trained model is serialised to; closed once written
 * @param crossValidationFoldNumber number of cross-validation folds; {@code null} or a value
 *        {@code <= 0} skips cross-validation. When it runs, {@code kappa}, {@code fMeasure}
 *        and {@code confusionMatrix} are updated (seed 1).
 * @throws IllegalArgumentException if {@code classifier} is not a supported algorithm
 * @throws Exception if reading, training, evaluation or serialisation fails
 */
public void trainClassifier(Classifier classifier, FileReader trainingDataset, FileOutputStream trainingModel,
        Integer crossValidationFoldNumber) throws Exception {
    Instances instances = new Instances(new BufferedReader(trainingDataset));

    switch (classifier) {
    case KNN:
        int k = (int) Math.ceil(Math.sqrt(instances.numInstances()));
        this.classifier = new IBk(k);
        break;
    case NB:
        this.classifier = new NaiveBayes();
        break;
    default:
        // Without this, a previously trained (or null) classifier would silently be reused.
        throw new IllegalArgumentException("Unsupported classifier: " + classifier);
    }

    if (instances.classIndex() == -1) {
        instances.setClassIndex(instances.numAttributes() - 1);
    }

    this.classifier.buildClassifier(instances);

    if (crossValidationFoldNumber != null && crossValidationFoldNumber > 0) {
        Evaluation evaluation = new Evaluation(instances);
        evaluation.crossValidateModel(this.classifier, instances, crossValidationFoldNumber, new Random(1));
        kappa = evaluation.kappa();
        fMeasure = evaluation.weightedFMeasure();
        confusionMatrix = evaluation.toMatrixString("Confusion matrix: ");
    }

    // try-with-resources closes the stream (and trainingModel) even if writing fails.
    try (ObjectOutputStream outputStream = new ObjectOutputStream(trainingModel)) {
        outputStream.writeObject(this.classifier);
        outputStream.flush();
    }
}

From source file:CEP.CEPListener.java

/**
 * Listener callback that either collects training events or classifies incoming events.
 *
 * <p>While {@code training} is set, events are appended to {@code train}. Once it holds at
 * least {@code sampleSize} instances, {@code tree} is built and the listener switches to
 * classification. Otherwise each event is appended to {@code data`}, classified with
 * {@code tree}, and the actual and predicted labels are printed. Exceptions are logged.
 *
 * @param newData newly delivered events; {@code null} or empty arrays are ignored
 * @param oldData previously delivered events (unused)
 */
public void update(EventBean[] newData, EventBean[] oldData) {
    // Check for events before touching newData[0]; an empty batch used to throw here.
    if (newData == null || newData.length == 0) {
        return;
    }

    System.out.println("Event received: " + newData[0].getUnderlying());

    try {
        if (training) {
            // Training phase: buffer events until sampleSize instances are collected.
            if (train == null) {
                train = HeaderManager.GetEmptyStructure();
            }
            for (EventBean bean : newData) {
                Object inst = bean.getUnderlying();
                train.add((Instance) inst);
            }
            if (train.size() >= sampleSize) {
                tree.buildClassifier(train);
                training = false;
            }
        } else {
            // Classification phase.
            if (data == null) {
                data = HeaderManager.GetStructure();
            }

            data = SetDuration(data);
            cumulative += data.size();

            for (EventBean bean : newData) {
                Object inst = bean.getUnderlying();
                data.add((Instance) inst);
            }
            // Classify only the instances appended from this batch.
            for (int i = data.numInstances() - newData.length; i < data.numInstances(); i++) {
                double pred = tree.classifyInstance(data.instance(i));
                System.out.print("ID: " + data.instance(i).value(0));
                System.out.print(
                        ", actual: " + data.classAttribute().value((int) data.instance(i).classValue()));
                System.out.println(", predicted: " + data.classAttribute().value((int) pred));
                // NOTE(review): this Evaluation never records a prediction, so
                // rootMeanSquaredError() is computed over zero predictions (likely NaN, making
                // the `< 0.7` test always false). Also, a low RMSE means a good model, yet it
                // triggers retraining here. Confirm the intended retraining condition.
                Evaluation eval = new Evaluation(data);
                if ((accuracy = eval.rootMeanSquaredError()) < 0.7) {

                    training = true;
                    if (train != null) {
                        train.clear();
                    }
                    train = null;
                }
                System.out.print("Accuracy: " + accuracy);
            }
        }
    } catch (InterruptedException ex) {
        Logger.getLogger(CEPListener.class.getName()).log(Level.SEVERE, null, ex);
        // Restore the interrupt flag so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
    } catch (Exception ex) {
        Logger.getLogger(CEPListener.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:cezeri.evaluater.FactoryEvaluation.java

/**
 * Runs a {@code folds}-fold cross-validation of {@code model} on {@code datax} (seed 1), and
 * optionally plots and prints the pooled results.
 *
 * <p>The data are copied and shuffled with {@code new Random(1)}; data with a nominal class
 * are also stratified. For each fold a copy of {@code model} is trained and evaluated into one
 * shared {@link Evaluation}, so its statistics cover all folds.
 *
 * @param model classifier template; it is copied per fold and is not trained itself
 * @param datax dataset; must have its class index set, because {@code classAttribute()} is
 *        called before any error handling
 * @param folds number of folds
 * @param show_text whether to print the setup and the summary to standard output
 * @param show_plot whether to plot observed vs. predicted values pooled over all folds
 * @param attr figure attributes; its {@code figureCaption} is overwritten when plotting
 * @return the pooled evaluation; {@code null} if it could not be created, possibly partially
 *         filled if a later step failed (errors are logged)
 */
public static Evaluation performCrossValidate(Classifier model, Instances datax, int folds, boolean show_text,
        boolean show_plot, TFigureAttribute attr) {
    Random rand = new Random(1);
    Instances randData = new Instances(datax);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }
    Evaluation eval = null;
    try {
        // perform cross-validation
        eval = new Evaluation(randData);
        // NOTE(review): `simulated` and `observed` are not declared in this method, so they
        // presumably are fields of the enclosing class. If so, they are never reset here and
        // keep accumulating across calls. Confirm this is intended.
        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n, rand);
            Instances validation = randData.testCV(folds, n);
            // build and evaluate classifier
            Classifier clsCopy = Classifier.makeCopy(model);
            clsCopy.buildClassifier(train);

            // Pool predictions and actual class values across folds for the overall plot.
            simulated = FactoryUtils.concatenate(simulated, eval.evaluateModel(clsCopy, validation));
            observed = FactoryUtils.concatenate(observed,
                    validation.attributeToDoubleArray(validation.classIndex()));
        }

        if (show_plot) {
            double[][] d = new double[2][simulated.length];
            d[0] = observed;
            d[1] = simulated;
            CMatrix f1 = CMatrix.getInstance(d);
            attr.figureCaption = "overall performance";
            f1.transpose().plot(attr);
        }
        if (show_text) {
            // output evaluation
            System.out.println();
            System.out.println("=== Setup for Overall Cross Validation===");
            System.out.println(
                    "Classifier: " + model.getClass().getName() + " " + Utils.joinOptions(model.getOptions()));
            System.out.println("Dataset: " + randData.relationName());
            System.out.println("Folds: " + folds);
            System.out.println("Seed: " + 1);
            System.out.println();
            System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
        }
    } catch (Exception ex) {
        Logger.getLogger(FactoryEvaluation.class.getName()).log(Level.SEVERE, null, ex);
    }
    return eval;
}