Example usage for weka.classifiers AbstractClassifier makeCopy

List of usage examples for weka.classifiers AbstractClassifier makeCopy

Introduction

On this page you can find example usages of weka.classifiers.AbstractClassifier.makeCopy.

Prototype

public static Classifier makeCopy(Classifier model) throws Exception 

Source Link

Document

Creates a deep copy of the given classifier using serialization.

Usage

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*  w  ww  .  j  a  v a2 s .  c o  m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");
    //paralelize!!:--------------------------------------------------------------
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];

        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);
        // build and evaluate classifier
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }

    cls.buildClassifier(data);
    //until here!-----------------------------------------------------------------

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput/*from ww w.ja v  a2 s  .  co  m*/
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    //paralelize!!:--------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.CrossValidation.java

/**
 * Multi-threaded cross-validation on an already loaded data set. Failures are logged
 * rather than thrown; the method always returns a report string.
 *
 * @param data data set with its class attribute already set
 * @param cls classifier prototype; a copy is created per fold, and {@code cls} itself is
 *            trained on {@code data}
 * @param seed seed of the random generator used to shuffle the data
 * @param folds number of cross-validation folds
 * @param modelOutputFile file the final model is written to; nothing is written when this is
 *            {@code null} or empty
 * @return a textual report with the setup and the cross-validation summary, or an error
 *         message if the {@code Evaluation} could not be created
 */
static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
        String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
        } catch (Exception ex) {
            // best effort: a fold whose classifier copy fails is logged and left out
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        //TODO: use Config.getNumThreads() for limiting these::
        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    // parallel section -------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: run the fold worker on the current thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    try {
        // final model: trained on the complete data set while the workers are running
        cls.buildClassifier(data);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
    }

    // Wait until every worker has really finished: the summary below reads 'eval', which
    // the workers receive, so giving up on a join after an interrupt could report partial
    // results. The interrupt is remembered and re-asserted once all joins are done.
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
    if (interrupted) {
        Thread.currentThread().interrupt();
    }

    // end of parallel section ------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, cls);
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.NLPSystem.java

/**
 * Cross-validates the configured {@code classifier} on {@code trainingSet} using
 * {@code CrossValidationFoldThread} workers, stores the resulting correlation coefficient
 * in {@code crossValidationPearsonsCorrelation} and optionally serializes the classifier.
 * Failures are logged rather than thrown.
 *
 * @param seed seed of the random generator used to shuffle the data
 * @param folds number of cross-validation folds
 * @param modelOutputFile file the classifier is written to; nothing is written when this is
 *            {@code null} or empty
 * @return a textual report with the setup and the cross-validation summary, or an error
 *         message if the {@code Evaluation} could not be created
 */
private String crossValidate(int seed, int folds, String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation");
    PerformanceCounters.startTimer("cross-validation init");

    AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;
    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(trainingSet);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(abstractClassifier)));
        } catch (Exception ex) {
            // best effort: a fold whose classifier copy fails is logged and left out
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init");
    PerformanceCounters.startTimer("cross-validation folds+train");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        // single-threaded configuration: run the fold worker on the current thread
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    // Keep waiting until each worker has finished, but remember an interrupt so that it
    // can be re-asserted for the caller instead of being silently dropped.
    boolean interrupted = false;
    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                interrupted = true;
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
    if (interrupted) {
        Thread.currentThread().interrupt();
    }

    PerformanceCounters.stopTimer("cross-validation folds+train");
    PerformanceCounters.startTimer("cross-validation post");
    // evaluation for output:
    String out = String.format(
            "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
            abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
            trainingSet.relationName(), folds, seed,
            eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

    try {
        crossValidationPearsonsCorrelation = eval.correlationCoefficient();
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
    }
    if (modelOutputFile != null && !modelOutputFile.isEmpty()) {
        try {
            SerializationHelper.write(modelOutputFile, abstractClassifier);
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    classifierBuiltWithCrossValidation = true;
    PerformanceCounters.stopTimer("cross-validation post");
    PerformanceCounters.stopTimer("cross-validation");
    return out;
}

From source file:com.actelion.research.orbit.imageAnalysis.models.ClassifierWrapper.java

License:Open Source License

/**
 * Copies a wrapper together with the model it holds.
 *
 * @param classifierWrapper wrapper to copy, may be {@code null}
 * @return a new wrapper around a copy of the wrapped classifier (for
 *         {@code WRAPPERTYPE_CLASSIFIER}) or clusterer (any other type), carrying over the
 *         binary-classification and build flags; {@code null} if the argument is {@code null}
 * @throws Exception if copying the wrapped model fails
 */
public static ClassifierWrapper makeCopy(ClassifierWrapper classifierWrapper) throws Exception {
    if (classifierWrapper == null) {
        return null;
    }

    final ClassifierWrapper copy;
    if (classifierWrapper.getWrapperType() == WRAPPERTYPE_CLASSIFIER) {
        // AbstractClassifier.makeCopy replaces Classifier.makeCopy as of weka 3.7.1
        copy = new ClassifierWrapper(AbstractClassifier.makeCopy(classifierWrapper.getClassifier()));
    } else {
        copy = new ClassifierWrapper(AbstractClusterer.makeCopy(classifierWrapper.getClusterer()));
    }

    // both wrapper kinds carry over the same flags
    copy.setBinaryClassification(classifierWrapper.getBinaryClassification());
    copy.setBuild(classifierWrapper.isBuild);
    return copy;
}

From source file:com.deafgoat.ml.prognosticator.AppClassifier.java

License:Apache License

/**
 * Perform cross-validation on data set/builds model
 * /*from   ww w  .j  a  va  2  s.  c  o  m*/
 * @throws Exception
 */
public void crossValidate() throws Exception {
    // stratify nominal target class
    if (_trainInstances.classAttribute().isNominal()) {
        _trainInstances.stratify(_folds);
    }
    _eval = new Evaluation(_trainInstances);
    for (int n = 0; n < _folds; n++) {

        if (_logger.isDebugEnabled()) {
            _logger.debug("Cross validation fold: " + (n + 1));
        }
        _train = _trainInstances.trainCV(_folds, n);
        _test = _trainInstances.testCV(_folds, n);
        _clsCopy = AbstractClassifier.makeCopy(_cls);
        try {
            _clsCopy.buildClassifier(_train);
        } catch (Exception e) {
            _logger.debug(_config._classifier + " can not handle " + getAttributeType(_test.classAttribute())
                    + " class attributes");
        }

        try {
            _eval.evaluateModel(_clsCopy, _test);
        } catch (Exception e) {
            _logger.debug("Can not evaluate model");
        }
    }

    if (_config._writeToMongoDB) {
        _logger.info("Writing model to mongoDB");
        // save the trained model
        saveModel();
        // save CV performance of trained model
        writeToMongoDB(_eval);
    }

    if (_config._writeToFile) {
        _logger.info("Writing model to file");
        SerializationHelper.write(_config._modelFile, _clsCopy);
    }
}

From source file:de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.java

License:Open Source License

/**
 * Runs a 10-fold cross-validation of the given classifier on a dataset and writes the
 * predicted class of every instance, ordered by instance ID, to {@code output.csv}.
 *
 * @param wekaClassifier classifier selection, resolved via {@code getClassifier}
 * @param dataset dataset whose ARFF file is read from {@code MODELS_DIR}
 * @throws Exception if reading, filtering, training, evaluating or writing fails
 */
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    final int folds = 10;
    final Classifier baseClassifier = getClassifier(wekaClassifier);

    // Time-seeded random number generator
    final Random random = new Random(new Date().getTime());

    // Tag every instance with an ID (first attribute) so predictions can be re-ordered later
    final String arffBase = MODELS_DIR + "/" + dataset.toString();
    final String arffWithIds = arffBase + "-plusIDs.arff";
    AddID.main(new String[] { "-i", arffBase + ".arff", "-o", arffWithIds });
    final Instances data = DataSource.read(arffWithIds);
    data.setClassIndex(data.numAttributes() - 1);

    // Filter that hides the ID attribute from the classifier
    final Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    data.randomize(random);

    // Cross-validation
    final Evaluation eval = new Evaluation(data);
    Instances predictedData = null;

    for (int fold = 0; fold < folds; fold++) {
        final Instances train = data.trainCV(folds, fold, random);
        final Instances test = data.testCV(folds, fold);

        // (A log-transform of train/test used to be applied at this point; it is disabled.)

        // Fresh classifier copy, wrapped so that the ID attribute is removed on the fly
        final FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(AbstractClassifier.makeCopy(baseClassifier));

        filteredClassifier.buildClassifier(train);
        eval.evaluateModel(filteredClassifier, test);

        // Append classification and error flag to the test instances
        final AddClassification addPredictions = new AddClassification();
        addPredictions.setClassifier(filteredClassifier);
        addPredictions.setOutputClassification(true);
        addPredictions.setOutputDistribution(false);
        addPredictions.setOutputErrorFlag(true);
        addPredictions.setInputFormat(train);
        Filter.useFilter(train, addPredictions); // trains the classifier

        final Instances pred = Filter.useFilter(test, addPredictions); // predicts the test set
        if (predictedData == null) {
            predictedData = new Instances(pred, 0);
        }
        for (Instance predicted : pred) {
            predictedData.add(predicted);
        }
    }

    // Collect the predicted class per instance, indexed by (ID - 1); the classification is
    // the second-to-last attribute, the error flag the last one
    final String[] scores = new String[predictedData.numInstances()];
    final int valueIdx = predictedData.numAttributes() - 2;

    for (Instance predInst : predictedData) {
        final int id = (int) predInst.value(predInst.attribute(0)) - 1;
        scores[id] = predInst.stringValue(predInst.attribute(valueIdx));
    }

    // Output
    final StringBuilder sb = new StringBuilder();
    for (String score : scores) {
        sb.append(score.toString()).append(LF);
    }

    final File outputFile = new File(
            OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv");
    FileUtils.writeStringToFile(outputFile, sb.toString());
}

From source file:dkpro.similarity.experiments.rte.util.Evaluator.java

License:Open Source License

/**
 * Trains the given classifier on {@code trainDataset}, evaluates it on {@code testDataset}
 * and writes the classifications, probabilities, raw prediction listing and meta
 * information to files below {@code OUTPUT_DIR}.
 *
 * @param wekaClassifier classifier selection, resolved via {@code ClassifierSimilarityMeasure.getClassifier}
 * @param trainDataset dataset used for training, read from {@code MODELS_DIR}
 * @param testDataset dataset used for testing, read from {@code MODELS_DIR}
 * @throws Exception if reading, training, evaluating, parsing the predictions or writing fails
 */
public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset)
        throws Exception {
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the train instances and get the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" });
    Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff");
    train.setClassIndex(train.numAttributes() - 1);

    // Add IDs to the test instances and get the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + testDataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" });
    Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff");
    test.setClassIndex(test.numAttributes() - 1);

    // Instantiate the Remove filter (hides the ID attribute from the classifier)
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    test.randomize(random);

    // (A log-transform of train/test used to be applied at this point; it is disabled.)

    // Copy the classifier
    Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

    // Instantiate the FilteredClassifier
    FilteredClassifier filteredClassifier = new FilteredClassifier();
    filteredClassifier.setFilter(removeIDFilter);
    filteredClassifier.setClassifier(classifier);

    // Build the classifier
    filteredClassifier.buildClassifier(train);

    // Prepare the output buffer; the first attribute (the ID) is printed with each prediction
    AbstractOutput output = new PlainText();
    output.setBuffer(new StringBuffer());
    output.setHeader(test);
    output.setAttributes("first");

    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(filteredClassifier, test, output);

    // Convert predictions to CSV
    // Expected line format: inst#, actual, predicted, [error flag "+"], probability, (ID)
    int numEvaluated = (int) eval.numInstances();
    String[] scores = new String[numEvaluated];
    double[] probabilities = new double[numEvaluated];
    for (String line : output.getBuffer().toString().split("\n")) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            // tolerate blank lines instead of failing with an index error
            continue;
        }
        // tokens: [0] inst#, [1] actual, [2] predicted, then optional "+", probability, (ID)
        String[] tokens = trimmed.split("\\s+");

        // The error flag is its own token; checking the token (rather than searching the
        // whole line for '+') keeps labels that contain '+' from being misread.
        int shift = (tokens.length > 3 && "+".equals(tokens[3])) ? 1 : 0;

        String idToken = tokens[4 + shift];
        int id = Integer.parseInt(idToken.substring(1, idToken.length() - 1)); // strip "(" and ")"

        // predicted is printed as "<classIndex>:<label>"; cut after the colon so that
        // class indices with more than one digit are handled as well
        String predictedToken = tokens[2];
        String classification = predictedToken.substring(predictedToken.indexOf(':') + 1);
        double probability = Double.parseDouble(tokens[3 + shift]);

        scores[id - 1] = classification;
        probabilities[id - 1] = probability;
    }

    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    String outputBase = OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/"
            + testDataset.toString();

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores) {
        sb.append(score.toString()).append(LF);
    }
    FileUtils.writeStringToFile(new File(outputBase + ".csv"), sb.toString());

    // Output probabilities
    sb = new StringBuilder();
    for (double probability : probabilities) {
        sb.append(probability).append(LF);
    }
    FileUtils.writeStringToFile(new File(outputBase + ".probabilities.csv"), sb.toString());

    // Output predictions
    FileUtils.writeStringToFile(new File(outputBase + ".predictions.txt"), output.getBuffer().toString());

    // Output meta information
    sb = new StringBuilder();
    sb.append(classifier.toString()).append(LF);
    sb.append(eval.toSummaryString()).append(LF);
    sb.append(eval.toMatrixString()).append(LF);

    FileUtils.writeStringToFile(new File(outputBase + ".meta.txt"), sb.toString());
}

From source file:dkpro.similarity.experiments.rte.util.Evaluator.java

License:Open Source License

/**
 * Runs a 10-fold cross-validation of the given classifier on a dataset, prints the
 * evaluation, and writes the per-instance classifications (ordered by instance ID), the
 * predicted ARFF data and meta information to files below {@code OUTPUT_DIR}.
 *
 * @param wekaClassifier classifier selection, resolved via {@code ClassifierSimilarityMeasure.getClassifier}
 * @param dataset dataset whose ARFF file is read from {@code MODELS_DIR}
 * @throws Exception if reading, filtering, training, evaluating or writing fails
 */
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);

    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);

    // Add IDs to the instances
    AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o",
            MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // Instantiate the Remove filter (hides the ID attribute from the classifier)
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");

    // Randomize the data
    data.randomize(random);

    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int n = 0; n < folds; n++) {
        Instances train = data.trainCV(folds, n, random);
        Instances test = data.testCV(folds, n);

        // (A log-transform of train/test used to be applied at this point; it is disabled.)

        // Copy the classifier
        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

        // Instantiate the FilteredClassifier
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(classifier);

        // Build the classifier
        filteredClassifier.buildClassifier(train);

        // Evaluate
        eval.evaluateModel(filteredClassifier, test);

        // Add predictions. The filtered classifier is used here as well: train and test
        // still contain the ID attribute, so passing the bare classifier would retrain it
        // with the ID as a feature and yield predictions that differ from the evaluated model.
        AddClassification filter = new AddClassification();
        filter.setClassifier(filteredClassifier);
        filter.setOutputClassification(true);
        filter.setOutputDistribution(false);
        filter.setOutputErrorFlag(true);
        filter.setInputFormat(train);
        Filter.useFilter(train, filter); // trains the classifier

        Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
        if (predictedData == null) {
            predictedData = new Instances(pred, 0);
        }
        for (int j = 0; j < pred.numInstances(); j++) {
            predictedData.add(pred.instance(j));
        }
    }

    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());

    // Prepare output scores, indexed by (ID - 1); the classification is the
    // second-to-last attribute, the error flag the last one
    String[] scores = new String[predictedData.numInstances()];
    int valueIdx = predictedData.numAttributes() - 2;

    for (Instance predInst : predictedData) {
        int id = (int) predInst.value(predInst.attribute(0)) - 1;
        scores[id] = predInst.stringValue(predInst.attribute(valueIdx));
    }

    // Output classifications
    StringBuilder sb = new StringBuilder();
    for (String score : scores) {
        sb.append(score.toString()).append(LF);
    }

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString()
            + "/" + dataset.toString() + ".csv"), sb.toString());

    // Output prediction arff
    DataSink.write(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/"
            + dataset.toString() + ".predicted.arff", predictedData);

    // Output meta information
    sb = new StringBuilder();
    sb.append(baseClassifier.toString()).append(LF);
    sb.append(eval.toSummaryString()).append(LF);
    sb.append(eval.toMatrixString()).append(LF);

    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString()
            + "/" + dataset.toString() + ".meta.txt"), sb.toString());
}

From source file:dkpro.similarity.experiments.sts2013.util.Evaluator.java

License:Open Source License

/**
 * For each dataset, runs a 10-fold cross-validation of a linear regression on
 * log-filtered data and writes the predicted scores, ordered by instance ID and limited
 * to the interval [0, 5], to a CSV file below {@code OUTPUT_DIR}.
 *
 * @param mode mode whose lower-cased name selects the sub-directory of the input and output files
 * @param datasets datasets to process
 * @throws Exception if reading, filtering, training, evaluating or writing fails
 */
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Set up the random number generator
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Add IDs to the instances
        AddID.main(new String[] { "-i",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString()
                        + "-plusIDs.arff" });
        Instances data = DataSource.read(
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff");
        data.setClassIndex(data.numAttributes() - 1);

        // Instantiate the Remove filter (hides the ID attribute from the classifier)
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate through the filtered classifier: the test instances still contain
            // the ID attribute, which the inner classifier was not trained with.
            eval.evaluateModel(filteredClassifier, test);

            // Add predictions, again via the filtered classifier so that the ID attribute
            // is never used as a regression feature
            AddClassification filter = new AddClassification();
            filter.setClassifier(filteredClassifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores, indexed by (ID - 1); the prediction is the
        // second-to-last attribute, the error value the last one
        double[] scores = new double[predictedData.numInstances()];
        int valueIdx = predictedData.numAttributes() - 2;

        for (Instance predInst : predictedData) {
            int id = (int) predInst.value(predInst.attribute(0)) - 1;
            double value = predInst.value(predInst.attribute(valueIdx));

            // Limit to interval [0;5]
            scores[id] = Math.max(0.0, Math.min(5.0, value));
        }

        // Output
        StringBuilder sb = new StringBuilder();
        for (double score : scores) {
            sb.append(score).append(LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}