Example usage for weka.classifiers.AbstractClassifier.makeCopy

List of usage examples for weka.classifiers.AbstractClassifier.makeCopy

Introduction

On this page you can find example usages of weka.classifiers.AbstractClassifier.makeCopy.

Prototype

public static Classifier makeCopy(Classifier model) throws Exception 

Source Link

Document

Creates a deep copy of the given classifier using serialization.

Usage

From source file:dkpro.similarity.experiments.sts2013baseline.util.Evaluator.java

License:Open Source License

/**
 * Runs a 10-fold cross-validated linear regression over each dataset and writes the
 * per-instance predictions, clamped to [0, 5], to {@code OUTPUT_DIR/<mode>/<dataset>.csv}.
 *
 * <p>An ID attribute is first added with {@code AddID} so that predictions made on the shuffled
 * folds can be written back in the original instance order. The ID is stripped again by a
 * {@code Remove} filter wrapped in a {@code FilteredClassifier}, so it is never used as a
 * predictor, neither for evaluation nor for the predictions added by {@code AddClassification}.
 *
 * @param mode selects the lower-cased sub-directory of {@code MODELS_DIR} / {@code OUTPUT_DIR}
 * @param datasets datasets to process; each is read from {@code <MODELS_DIR>/<mode>/<name>.arff}
 * @throws IOException if the ID-augmented dataset cannot be loaded
 * @throws Exception if any Weka filter or classifier step fails
 */
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Time-seeded RNG: fold assignment differs from run to run.
        long seed = System.currentTimeMillis();
        Random random = new Random(seed);

        String basePath = MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString();
        String location = basePath + "-plusIDs.arff";

        // Add IDs to the instances so predictions can be mapped back after shuffling
        AddID.main(new String[] { "-i", basePath + ".arff", "-o", location });

        Instances data = DataSource.read(location);
        if (data == null) {
            throw new IOException("Could not load data from: " + location);
        }

        data.setClassIndex(data.numAttributes() - 1);

        // Removes the ID attribute (the first one) before data reaches the regression model
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            // NOTE(review): the filter is re-initialised on the test fold; this is only correct
            // if LogFilter derives no statistics from its input format — confirm.
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Fresh deep copy of the base model, wrapped so the ID attribute is removed first
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(AbstractClassifier.makeCopy(baseClassifier));
            filteredClassifier.buildClassifier(train);

            // Evaluate the wrapped model: the test fold still contains the ID attribute, which
            // only the FilteredClassifier knows how to strip.
            eval.evaluateModel(filteredClassifier, test);

            // Add predictions. The classifier handed to AddClassification is trained on the
            // first batch (train); using the FilteredClassifier keeps the ID out of that model.
            AddClassification filter = new AddClassification();
            filter.setClassifier(filteredClassifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores, indexed by the 1-based ID added by AddID
        double[] scores = new double[predictedData.numInstances()];
        // AddClassification appends the classification and the error flag; the prediction is
        // the second-to-last attribute.
        int valueIdx = predictedData.numAttributes() - 2;

        for (Instance predInst : predictedData) {
            int id = (int) predInst.value(predInst.attribute(0)) - 1;
            double value = predInst.value(predInst.attribute(valueIdx));

            // Limit to interval [0;5]
            scores[id] = Math.max(0.0, Math.min(5.0, value));
        }

        // Output
        StringBuilder sb = new StringBuilder();
        for (double score : scores) {
            sb.append(score).append(LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}

From source file:edu.cuny.qc.speech.AuToBI.classifier.WekaClassifier.java

License:Open Source License

/**
 * Creates a new {@code WekaClassifier} wrapping a deep copy of the underlying Weka classifier.
 *
 * <p>The copy is produced with {@code AbstractClassifier.makeCopy}. Any exception raised while
 * copying or wrapping is printed to standard error and {@code null} is returned instead.
 *
 * @return the copied classifier wrapper, or {@code null} if the copy failed
 */
public AuToBIClassifier newInstance() {
    AuToBIClassifier copy = null;
    try {
        copy = new WekaClassifier(AbstractClassifier.makeCopy(weka_classifier));
    } catch (Exception e) {
        e.printStackTrace();
    }
    return copy;
}

From source file:fantail.algorithms.RankingViaRegression.java

License:Open Source License

/**
 * Trains one linear-regression model per ranking target.
 *
 * <p>For each target {@code t}, a numeric training set named {@code data-att-(t+1)} is built. It
 * holds the first {@code numAttributes() - 1} attributes of {@code data} as features, plus a
 * numeric class attribute whose value is {@code Tools.getTargetVector(instance)[t]}. Each model
 * is an independent deep copy of an untrained {@code LinearRegression} prototype.
 *
 * @param data training instances; the target vector is presumably held in the last attribute
 *     and decoded by {@code Tools} — TODO confirm
 * @throws IllegalArgumentException if {@code data} yields no ranking targets
 * @throws Exception if building any of the regression models fails
 */
@Override
public void buildRanker(Instances data) throws Exception {

    Instances workingData = new Instances(data);

    m_NumFeatures = workingData.numAttributes() - 1;
    m_NumTargets = Tools.getNumberTargets(data);
    if (m_NumTargets < 1) {
        // Without at least one target there is nothing to train, and no header to keep.
        throw new IllegalArgumentException("No ranking targets found in the training data");
    }
    m_Classifiers = new AbstractClassifier[m_NumTargets];

    // Untrained template; every target gets its own deep copy so the models stay independent.
    weka.classifiers.functions.LinearRegression prototype =
            new weka.classifiers.functions.LinearRegression();
    for (int i = 0; i < m_NumTargets; i++) {
        m_Classifiers[i] = AbstractClassifier.makeCopy(prototype);
    }

    Instances[] trainingSets = new Instances[m_NumTargets];

    for (int t = 0; t < m_NumTargets; t++) {

        // Fresh Attribute objects per data set: features keep their names but are numeric.
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(m_NumFeatures + 1);
        for (int i = 0; i < m_NumFeatures; i++) {
            attributes.add(new Attribute(workingData.attribute(i).name()));
        }

        String targetName = "att-" + (t + 1);
        attributes.add(new Attribute(targetName));

        trainingSets[t] = new Instances("data-" + targetName, attributes, 0);

        for (int j = 0; j < workingData.numInstances(); j++) {
            Instance metaInst = workingData.instance(j);
            double[] ranking = Tools.getTargetVector(metaInst);
            double[] values = new double[m_NumFeatures + 1];

            for (int m = 0; m < m_NumFeatures; m++) {
                values[m] = metaInst.value(m);
            }
            values[m_NumFeatures] = ranking[t];
            trainingSets[t].add(new DenseInstance(1.0, values));
        }

        trainingSets[t].setClassIndex(m_NumFeatures);
        m_Classifiers[t].buildClassifier(trainingSets[t]);
    }

    // Empty header copied from the first training set (features plus one numeric target).
    m_TempHeader = new Instances(trainingSets[0], 0);
}

From source file:jjj.asap.sas.ensemble.impl.StackedClassifier.java

License:Open Source License

/**
 * Builds a stacked ("strong") learner from the given weak learners.
 *
 * <p>A deep copy of {@code prototype} is trained as the meta classifier on the meta data set
 * derived from {@code learners}. Its predictions on that same data set are scored with kappa
 * against the gold standard. An empty learner list yields the shared
 * {@code StrongLearner.NO_MODEL} entry for the essay set.
 *
 * @param essaySet 1-based essay-set number
 * @param ensembleName ensemble name (not used by this implementation)
 * @param learners weak learners whose outputs form the meta features
 * @return the trained strong learner, or the NO_MODEL placeholder when {@code learners} is empty
 * @throws RuntimeException wrapping any exception raised during training or scoring
 */
@Override
public StrongLearner build(int essaySet, String ensembleName, List<WeakLearner> learners) {

    if (learners.isEmpty()) {
        return StrongLearner.NO_MODEL[essaySet - 1];
    }

    StrongLearner result = new StrongLearner();

    try {
        Instances stackingData = getMetaDataset(essaySet, learners);
        Classifier stacker = AbstractClassifier.makeCopy(prototype);

        Weka.trainClassifier(stackingData, stacker);

        // NOTE(review): kappa is computed on the same data the meta classifier was trained on,
        // so it is an in-sample estimate.
        Map<Double, Double> predictions =
                Model.getPredictions(essaySet, Weka.classifyInstances(stackingData, stacker));
        result.setKappa(Calc.kappa(essaySet, predictions, Contest.getGoldStandard(essaySet)));
        result.setPreds(predictions);
        result.setLearners(new ArrayList<WeakLearner>(learners));
        result.setContext(stacker);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    return result;
}

From source file:jjj.asap.sas.models1.job.BuildBasicMetaCostModels.java

License:Open Source License

/**
 * Submits one model-building job per (dataset, prototype) pair and waits for all of them.
 *
 * <p>Every prototype is deep-copied and wrapped in an InfoGain/Ranker
 * {@code AttributeSelectedClassifier}. Cost-sensitive prototypes additionally receive a cost
 * matrix, {@code cost3.txt} or {@code cost4.txt}, chosen by the number of rubrics of the
 * dataset's essay set. Job failures are logged and do not abort the remaining jobs.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception if a cost matrix cannot be read or a job cannot be set up
 */
@Override
protected void run() throws Exception {

    // validate args
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // create prototype classifiers
    Map<String, Classifier> prototypes = new HashMap<String, Classifier>();

    // Bagged REPTrees

    Bagging baggedTrees = new Bagging();
    baggedTrees.setNumExecutionSlots(1);
    baggedTrees.setNumIterations(100);
    baggedTrees.setClassifier(new REPTree());
    baggedTrees.setCalcOutOfBag(false);

    prototypes.put("Bagged-REPTrees", baggedTrees);

    // Bagged SMO

    Bagging baggedSVM = new Bagging();
    baggedSVM.setNumExecutionSlots(1);
    baggedSVM.setNumIterations(100);
    baggedSVM.setClassifier(new SMO());
    baggedSVM.setCalcOutOfBag(false);

    prototypes.put("Bagged-SMO", baggedSVM);

    // Meta Cost model for Naive Bayes

    Bagging bagging = new Bagging();
    bagging.setNumExecutionSlots(1);
    bagging.setNumIterations(100);
    bagging.setClassifier(new NaiveBayes());

    CostSensitiveClassifier meta = new CostSensitiveClassifier();
    meta.setClassifier(bagging);
    meta.setMinimizeExpectedCost(true);

    // Register the cost-sensitive wrapper, not the bare Bagging model. The cost matrix below is
    // attached only to CostSensitiveClassifier instances.
    prototypes.put("CostSensitive-MinimizeExpectedCost-NaiveBayes", meta);

    // init multi-threading
    Job.startService();
    final Queue<Future<Object>> queue = new LinkedList<Future<Object>>();

    // get the input from the bucket
    List<String> names = Bucket.getBucketItems("datasets", this.inputBucket);
    for (String dsn : names) {

        // for each prototype classifier
        for (Map.Entry<String, Classifier> prototype : prototypes.entrySet()) {

            Classifier alg = AbstractClassifier.makeCopy(prototype.getValue());

            // special logic for meta cost: choose the cost matrix by the number of rubrics
            if (alg instanceof CostSensitiveClassifier) {

                int essaySet = Contest.getEssaySet(dsn);

                String matrix = Contest.getRubrics(essaySet).size() == 3 ? "cost3.txt" : "cost4.txt";

                FileReader costReader = new FileReader("/asap/sas/trunk/" + matrix);
                try {
                    ((CostSensitiveClassifier) alg).setCostMatrix(new CostMatrix(costReader));
                } finally {
                    // release the file handle once the matrix has been parsed
                    costReader.close();
                }
            }

            // use InfoGain to discard useless attributes

            AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();

            classifier.setEvaluator(new InfoGainAttributeEval());

            Ranker ranker = new Ranker();
            ranker.setThreshold(0.0001);
            classifier.setSearch(ranker);

            classifier.setClassifier(alg);

            queue.add(Job.submit(
                    new ModelBuilder(dsn, "InfoGain-" + prototype.getKey(), classifier, this.outputBucket)));
        }
    }

    // wait on complete
    Progress progress = new Progress(queue.size(), this.getClass().getSimpleName());
    while (!queue.isEmpty()) {
        try {
            queue.remove().get();
        } catch (Exception e) {
            Job.log("ERROR", e.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();

}

From source file:jjj.asap.sas.models1.job.BuildBasicModels.java

License:Open Source License

/**
 * Trains every basic prototype (BayesNet, NaiveBayes, RBFNetwork, SMO) on every dataset in the
 * input bucket, each behind an InfoGain/Ranker attribute selector, and waits for all jobs.
 *
 * <p>Failed jobs are logged as {@code ERROR} and do not stop the remaining ones.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception if a job cannot be set up
 */
@Override
protected void run() throws Exception {

    // Both buckets must exist before any work is scheduled.
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // Prototype classifiers, keyed by the name used in the model tag.
    final Map<String, Classifier> templates = new HashMap<String, Classifier>();

    BayesNet bayesNet = new BayesNet();
    bayesNet.setEstimator(new BMAEstimator());
    templates.put("BayesNet", bayesNet);
    templates.put("NaiveBayes", new NaiveBayes());
    templates.put("RBFNetwork", new RBFNetwork());
    templates.put("SMO", new SMO());

    Job.startService();
    final Queue<Future<Object>> pending = new LinkedList<Future<Object>>();

    for (String datasetName : Bucket.getBucketItems("datasets", this.inputBucket)) {
        for (String key : templates.keySet()) {
            // Keep only attributes whose information gain exceeds the threshold.
            Ranker gainRanker = new Ranker();
            gainRanker.setThreshold(0.0001);

            AttributeSelectedClassifier selector = new AttributeSelectedClassifier();
            selector.setEvaluator(new InfoGainAttributeEval());
            selector.setSearch(gainRanker);
            selector.setClassifier(AbstractClassifier.makeCopy(templates.get(key)));

            String modelName = "InfoGain-" + key;
            pending.add(Job.submit(new ModelBuilder(datasetName, modelName, selector, this.outputBucket)));
        }
    }

    // Drain the queue, logging (not propagating) job failures.
    Progress tracker = new Progress(pending.size(), getClass().getSimpleName());
    while (!pending.isEmpty()) {
        Future<Object> task = pending.remove();
        try {
            task.get();
        } catch (Exception failure) {
            Job.log("ERROR", failure.toString());
        }
        tracker.tick();
    }
    tracker.done();
    Job.stopService();
}

From source file:jjj.asap.sas.models1.job.BuildBasicModels2.java

License:Open Source License

/**
 * Trains the NBTree and Logistic prototypes on every dataset in the input bucket, each behind an
 * InfoGain/Ranker attribute selector, and waits for all submitted jobs.
 *
 * <p>Failed jobs are logged as {@code ERROR}; the remaining jobs still run.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception if a job cannot be set up
 */
@Override
protected void run() throws Exception {

    // Refuse to start unless both buckets exist.
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    final Map<String, Classifier> byName = new HashMap<String, Classifier>();
    byName.put("NBTree", new NBTree());
    byName.put("Logistic", new Logistic());

    Job.startService();
    final Queue<Future<Object>> futures = new LinkedList<Future<Object>>();

    List<String> datasetNames = Bucket.getBucketItems("datasets", this.inputBucket);
    for (String datasetName : datasetNames) {
        for (Map.Entry<String, Classifier> entry : byName.entrySet()) {
            AttributeSelectedClassifier wrapped = new AttributeSelectedClassifier();
            wrapped.setEvaluator(new InfoGainAttributeEval());

            // Discard attributes with negligible information gain.
            Ranker search = new Ranker();
            search.setThreshold(0.0001);
            wrapped.setSearch(search);

            Classifier copy = AbstractClassifier.makeCopy(entry.getValue());
            wrapped.setClassifier(copy);

            ModelBuilder builder =
                    new ModelBuilder(datasetName, "InfoGain-" + entry.getKey(), wrapped, this.outputBucket);
            futures.add(Job.submit(builder));
        }
    }

    // Block until every job has finished.
    Progress progress = new Progress(futures.size(), getClass().getSimpleName());
    while (!futures.isEmpty()) {
        Future<Object> next = futures.remove();
        try {
            next.get();
        } catch (Exception err) {
            Job.log("ERROR", err.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();
}

From source file:jjj.asap.sas.models1.job.BuildPLSModels.java

License:Open Source License

/**
 * Builds one PLS regression model per dataset in the input bucket and waits for all jobs.
 *
 * <p>The PLS variant is chosen per dataset:
 * <ul>
 *   <li>essay set 7: 10 components, standardised preprocessing;</li>
 *   <li>otherwise, essay set 10 or a {@code 1grams-thru-3grams} dataset: 5 components, centred;</li>
 *   <li>otherwise: 5 components, no preprocessing.</li>
 * </ul>
 * Failed jobs are logged as {@code ERROR}.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception if a job cannot be set up
 */
@Override
protected void run() throws Exception {

    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // Plain PLS: 5 components, no preprocessing.
    PLSClassifier plainPls = new PLSClassifier();
    PLSFilter plainFilter = (PLSFilter) plainPls.getFilter();
    plainFilter.setNumComponents(5);
    plainFilter.setPreprocessing(NONE);

    // Centred PLS: 5 components.
    PLSClassifier centeredPls = new PLSClassifier();
    PLSFilter centeredFilter = (PLSFilter) centeredPls.getFilter();
    centeredFilter.setNumComponents(5);
    centeredFilter.setPreprocessing(CENTER);

    // Standardised PLS: 10 components.
    PLSClassifier standardizedPls = new PLSClassifier();
    PLSFilter standardizedFilter = (PLSFilter) standardizedPls.getFilter();
    standardizedFilter.setNumComponents(10);
    standardizedFilter.setPreprocessing(STANDARDIZE);

    Job.startService();
    final Queue<Future<Object>> jobs = new LinkedList<Future<Object>>();

    for (String dsn : Bucket.getBucketItems("datasets", this.inputBucket)) {
        int essaySet = Contest.getEssaySet(dsn);

        // Essay set 7 takes precedence over the centred-PLS conditions.
        Classifier chosen;
        if (essaySet == 7) {
            chosen = standardizedPls;
        } else if (essaySet == 10 || dsn.contains("1grams-thru-3grams")) {
            chosen = centeredPls;
        } else {
            chosen = plainPls;
        }

        Classifier copy = AbstractClassifier.makeCopy(chosen);
        jobs.add(Job.submit(new RegressionModelBuilder(dsn, "PLS", copy, this.outputBucket)));
    }

    // Wait for every job, logging failures instead of aborting.
    Progress progress = new Progress(jobs.size(), getClass().getSimpleName());
    while (!jobs.isEmpty()) {
        Future<Object> job = jobs.remove();
        try {
            job.get();
        } catch (Exception ex) {
            Job.log("ERROR", ex.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();
}

From source file:jjj.asap.sas.models1.job.BuildRegressionModels.java

License:Open Source License

/**
 * Builds the regression models on every dataset in the input bucket and waits for all jobs.
 *
 * <p>Models, in submission order per dataset:
 * <ol>
 *   <li>LinearRegression with M5 attribute selection;</li>
 *   <li>AdditiveRegression over DecisionStump, 1000 iterations;</li>
 *   <li>AdditiveRegression over REPTree, 100 iterations;</li>
 *   <li>RandomSubSpace over LinearRegression (no attribute selection), 30 iterations.</li>
 * </ol>
 * Each model is tagged with its class name. Single-classifier enhancers append
 * {@code -<base classifier>} to the tag. Failed jobs are logged as {@code ERROR}.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception if a job cannot be set up
 */
@Override
protected void run() throws Exception {

    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    LinearRegression m5Regression = new LinearRegression();
    m5Regression.setAttributeSelectionMethod(M5);

    AdditiveRegression stumpBoost = new AdditiveRegression();
    stumpBoost.setClassifier(new DecisionStump());
    stumpBoost.setNumIterations(1000);

    AdditiveRegression treeBoost = new AdditiveRegression();
    treeBoost.setClassifier(new REPTree());
    treeBoost.setNumIterations(100);

    LinearRegression plainRegression = new LinearRegression();
    plainRegression.setAttributeSelectionMethod(NONE);

    RandomSubSpace subSpace = new RandomSubSpace();
    subSpace.setClassifier(plainRegression);
    subSpace.setNumIterations(30);

    // Submission order matters for the job queue, so the list order is fixed here.
    List<Classifier> templates = new ArrayList<Classifier>();
    templates.add(m5Regression);
    templates.add(stumpBoost);
    templates.add(treeBoost);
    templates.add(subSpace);

    Job.startService();
    final Queue<Future<Object>> outstanding = new LinkedList<Future<Object>>();

    for (String dsn : Bucket.getBucketItems("datasets", this.inputBucket)) {
        for (Classifier template : templates) {
            StringBuilder label = new StringBuilder(template.getClass().getSimpleName());
            if (template instanceof SingleClassifierEnhancer) {
                Classifier base = ((SingleClassifierEnhancer) template).getClassifier();
                label.append("-").append(base.getClass().getSimpleName());
            }

            Classifier copy = AbstractClassifier.makeCopy(template);
            outstanding.add(Job.submit(new RegressionModelBuilder(dsn, label.toString(), copy,
                    this.outputBucket)));
        }
    }

    // Wait for all jobs; a failing job is logged and the rest continue.
    Progress progress = new Progress(outstanding.size(), getClass().getSimpleName());
    while (!outstanding.isEmpty()) {
        Future<Object> job = outstanding.remove();
        try {
            job.get();
        } catch (Exception problem) {
            Job.log("ERROR", problem.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();
}

From source file:jjj.asap.sas.models1.job.RGramModels.java

License:Open Source License

/**
 * Builds the regression models on every dataset in the input bucket and waits for all jobs.
 *
 * <p>Models, in submission order per dataset:
 * <ol>
 *   <li>LinearRegression with M5 attribute selection;</li>
 *   <li>RandomSubSpace over LinearRegression (no attribute selection), 30 iterations;</li>
 *   <li>AdditiveRegression over DecisionStump, 1000 iterations;</li>
 *   <li>AdditiveRegression over REPTree, 100 iterations;</li>
 *   <li>a default PLSClassifier.</li>
 * </ol>
 * Tags are the model class name, plus {@code -<base classifier>} for single-classifier
 * enhancers. Failed jobs are logged as {@code ERROR}.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception if a job cannot be set up
 */
@Override
protected void run() throws Exception {

    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    List<Classifier> candidates = new ArrayList<Classifier>();

    LinearRegression m5Regression = new LinearRegression();
    m5Regression.setAttributeSelectionMethod(M5);
    candidates.add(m5Regression);

    LinearRegression unselected = new LinearRegression();
    unselected.setAttributeSelectionMethod(NONE);
    RandomSubSpace subSpace = new RandomSubSpace();
    subSpace.setClassifier(unselected);
    subSpace.setNumIterations(30);
    candidates.add(subSpace);

    AdditiveRegression stumpBoost = new AdditiveRegression();
    stumpBoost.setClassifier(new DecisionStump());
    stumpBoost.setNumIterations(1000);
    candidates.add(stumpBoost);

    AdditiveRegression treeBoost = new AdditiveRegression();
    treeBoost.setClassifier(new REPTree());
    treeBoost.setNumIterations(100);
    candidates.add(treeBoost);

    candidates.add(new PLSClassifier());

    Job.startService();
    final Queue<Future<Object>> inFlight = new LinkedList<Future<Object>>();

    for (String dsn : Bucket.getBucketItems("datasets", this.inputBucket)) {
        for (Classifier candidate : candidates) {
            String simpleName = candidate.getClass().getSimpleName();
            String tag = (candidate instanceof SingleClassifierEnhancer)
                    ? simpleName + "-"
                            + ((SingleClassifierEnhancer) candidate).getClassifier().getClass().getSimpleName()
                    : simpleName;

            Classifier copy = AbstractClassifier.makeCopy(candidate);
            inFlight.add(Job.submit(new RegressionModelBuilder(dsn, tag, copy, this.outputBucket)));
        }
    }

    // Wait for completion; individual failures are logged, not rethrown.
    Progress progress = new Progress(inFlight.size(), getClass().getSimpleName());
    while (!inFlight.isEmpty()) {
        Future<Object> pending = inFlight.remove();
        try {
            pending.get();
        } catch (Exception failure) {
            Job.log("ERROR", failure.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();
}