Example usage for weka.classifiers.functions LinearRegression LinearRegression

List of usage examples for the weka.classifiers.functions.LinearRegression constructor LinearRegression()

Introduction

This page collects usage examples for the no-argument constructor of weka.classifiers.functions.LinearRegression.

Prototype

public LinearRegression() 

Source Link

Usage

From source file:data.RegressionDrop.java

/**
 * Fits a linear regression to the data in {@code NumOfDroppedByYear.arff}, prints the
 * model, and prints the model's prediction for the last instance of the data set.
 *
 * <p>The last attribute of the data set is used as the class attribute.
 *
 * @throws Exception if the ARFF file cannot be read, or if building or applying the
 *                   model fails
 */
public void regression() throws Exception {
    // Load the data. try-with-resources closes the reader even if parsing fails.
    Instances data;
    try (BufferedReader reader = new BufferedReader(new FileReader("NumOfDroppedByYear.arff"))) {
        data = new Instances(reader);
    }
    data.setClassIndex(data.numAttributes() - 1);

    // Build the model.
    // Original author's note: the last instance, whose class is missing, is not used.
    LinearRegression model = new LinearRegression();
    model.buildClassifier(data);
    System.out.println(model);

    // Classify the last instance; the cast truncates the prediction toward zero.
    Instance num = data.lastInstance();
    int people = (int) model.classifyInstance(num);
    System.out.println("NumOfDropped (" + num + "): " + people);
}

From source file:dkpro.similarity.experiments.sts2013.util.Evaluator.java

License:Open Source License

/**
 * For each dataset, runs a 10-fold cross-validation with a {@link LinearRegression}
 * classifier and writes one predicted score per instance, clipped to [0, 5], to
 * {@code OUTPUT_DIR/<mode>/<dataset>.csv}.
 *
 * <p>The input is read from {@code MODELS_DIR/<mode>/<dataset>.arff}; a copy with an ID
 * attribute is written next to it as {@code <dataset>-plusIDs.arff}. The ID is used to
 * put the predictions back into the original instance order after shuffling.
 *
 * @param mode     selects the sub-directory (lower-cased {@code toString()}) of the
 *                 model and output directories
 * @param datasets the datasets to process
 * @throws Exception if the data cannot be loaded or any Weka operation or file write fails
 */
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Set up the random number generator (time-seeded, so runs are not reproducible)
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Build the shared path pieces once instead of repeating the concatenation
        String modeDir = mode.toString().toLowerCase();
        String datasetBase = MODELS_DIR + "/" + modeDir + "/" + dataset.toString();
        String plusIdsLocation = datasetBase + "-plusIDs.arff";

        // Add IDs to the instances
        AddID.main(new String[] { "-i", datasetBase + ".arff", "-o", plusIdsLocation });
        Instances data = DataSource.read(plusIdsLocation);

        // Fail with a clear message instead of a NullPointerException further down
        if (data == null) {
            throw new java.io.IOException("Could not load data from: " + plusIdsLocation);
        }

        data.setClassIndex(data.numAttributes() - 1);

        // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate
            // NOTE(review): the evaluation result is never read in this method, and the
            // inner classifier (not the FilteredClassifier) is evaluated on data that
            // still carries the ID attribute - confirm this is intended.
            eval.evaluateModel(classifier, test);

            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                // Empty copy of the prediction header; instances are appended below
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores
        double[] scores = new double[predictedData.numInstances()];

        // Second-to-last attribute; presumably the classification column added by
        // AddClassification, followed by the error flag - TODO confirm.
        int valueIdx = predictedData.numAttributes() - 2;

        for (Instance predInst : predictedData) {
            // IDs are 1-based (hence the -1); they restore the pre-shuffle order
            int id = ((int) predInst.value(predInst.attribute(0))) - 1;

            double value = predInst.value(predInst.attribute(valueIdx));

            // Limit to interval [0;5]
            if (value > 5.0) {
                value = 5.0;
            }
            if (value < 0.0) {
                value = 0.0;
            }
            scores[id] = value;
        }

        // Output: one score per line
        StringBuilder sb = new StringBuilder();
        for (double score : scores) {
            sb.append(score).append(LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + modeDir + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}

From source file:dkpro.similarity.experiments.sts2013baseline.util.Evaluator.java

License:Open Source License

/**
 * For each dataset, runs a 10-fold cross-validation with a {@link LinearRegression}
 * classifier and writes one predicted score per instance, clipped to [0, 5], to
 * {@code OUTPUT_DIR/<mode>/<dataset>.csv}.
 *
 * <p>The input is read from {@code MODELS_DIR/<mode>/<dataset>.arff}; a copy with an ID
 * attribute is written next to it as {@code <dataset>-plusIDs.arff}. The ID is used to
 * put the predictions back into the original instance order after shuffling.
 *
 * @param mode     selects the sub-directory (lower-cased {@code toString()}) of the
 *                 model and output directories
 * @param datasets the datasets to process
 * @throws IOException if the ID-augmented data set cannot be loaded
 * @throws Exception   if any Weka operation or the file write fails
 */
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Set up the random number generator (time-seeded, so runs are not reproducible)
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Build the shared path pieces once instead of repeating the concatenation
        String modeDir = mode.toString().toLowerCase();
        String datasetBase = MODELS_DIR + "/" + modeDir + "/" + dataset.toString();
        String location = datasetBase + "-plusIDs.arff";

        // Add IDs to the instances
        AddID.main(new String[] { "-i", datasetBase + ".arff", "-o", location });

        Instances data = DataSource.read(location);

        if (data == null) {
            throw new IOException("Could not load data from: " + location);
        }

        data.setClassIndex(data.numAttributes() - 1);

        // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate
            // NOTE(review): the evaluation result is never read in this method, and the
            // inner classifier (not the FilteredClassifier) is evaluated on data that
            // still carries the ID attribute - confirm this is intended.
            eval.evaluateModel(classifier, test);

            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                // Empty copy of the prediction header; instances are appended below
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores
        double[] scores = new double[predictedData.numInstances()];

        // Second-to-last attribute; presumably the classification column added by
        // AddClassification, followed by the error flag - TODO confirm.
        int valueIdx = predictedData.numAttributes() - 2;

        for (Instance predInst : predictedData) {
            // IDs are 1-based (hence the -1); they restore the pre-shuffle order
            int id = ((int) predInst.value(predInst.attribute(0))) - 1;

            double value = predInst.value(predInst.attribute(valueIdx));

            // Limit to interval [0;5]
            if (value > 5.0) {
                value = 5.0;
            }
            if (value < 0.0) {
                value = 0.0;
            }
            scores[id] = value;
        }

        // Output: one score per line
        StringBuilder sb = new StringBuilder();
        for (double score : scores) {
            sb.append(score).append(LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + modeDir + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}

From source file:edu.utexas.cs.tactex.utils.RegressionUtils.java

License:Open Source License

/**
 * Creates a {@link LinearRegression} whose attribute selection method is set to
 * {@code SELECTION_NONE} and whose elimination of colinear attributes is switched off.
 *
 * @return the configured, not yet trained, regression
 */
public static LinearRegression createLinearRegression() {
    final SelectedTag noSelection = new SelectedTag(LinearRegression.SELECTION_NONE,
            LinearRegression.TAGS_SELECTION);

    final LinearRegression regression = new LinearRegression();
    regression.setAttributeSelectionMethod(noSelection);
    regression.setEliminateColinearAttributes(false);
    // For debug output, call setDebug(true) on the regression before returning it.
    return regression;
}

From source file:edu.washington.cs.knowitall.utilities.Classifier.java

License:Open Source License

/**
 * Replaces {@code classifier} with a new {@link LinearRegression} and trains it on the
 * instances read from the given file.
 *
 * <p>On any failure the stack trace is printed and the JVM is terminated.
 *
 * @param trainingFilename path of the training data file, parsed by {@code setupInstances}
 */
protected void setupTraining(String trainingFilename) {
    // Set up the classifier from the training data.
    classifier = new LinearRegression();
    // try-with-resources closes the reader (and the underlying file stream) on both the
    // success and the failure path; previously it was never closed.
    try (InputStreamReader trainingReader = new InputStreamReader(new FileInputStream(trainingFilename))) {
        Instances trainingInstances = setupInstances(trainingReader);
        classifier.buildClassifier(trainingInstances);
    } catch (Exception e) {
        e.printStackTrace();
        // NOTE(review): exit status 0 reports success to the calling process even though
        // training failed; kept for backward compatibility - consider a non-zero status.
        System.exit(0);
    }

}

From source file:epsi.i5.datamining.Weka.java

/**
 * Loads {@code src/epsi/i5/data/<fileOne>.arff} and prints evaluation results for several
 * Weka learners.
 *
 * <p>With the first attribute as class, NaiveBayes, J48, DecisionTable and OneR are
 * evaluated and their confusion matrix and percentage of correct predictions printed.
 * With the second attribute as class, M5Rules and LinearRegression are evaluated and
 * their correlation coefficient printed.
 *
 * @throws FileNotFoundException if the ARFF file does not exist
 * @throws IOException           if the ARFF file cannot be read
 * @throws Exception             if building or evaluating a model fails
 */
public void excutionAlgo() throws FileNotFoundException, IOException, Exception {
    Instances data;
    // try-with-resources closes the reader even when parsing the ARFF file throws
    try (BufferedReader reader = new BufferedReader(new FileReader("src/epsi/i5/data/" + fileOne + ".arff"))) {
        data = new Instances(reader);
    }

    // First attribute as class: learners reported with a confusion matrix
    data.setClass(data.attribute(0));
    printMatrixReport("******** Naive Bayes ********", buildAndEvaluate(new NaiveBayes(), data));
    System.out.println("");
    printMatrixReport("************ J48 ************", buildAndEvaluate(new J48(), data));
    System.out.println("");
    printMatrixReport("******* DecisionTable *******", buildAndEvaluate(new DecisionTable(), data));
    System.out.println("");
    printMatrixReport("************ OneR ***********", buildAndEvaluate(new OneR(), data));

    // Second attribute as class (the original note here read "Polarit", presumably
    // polarity): learners reported with a correlation coefficient
    data.setClass(data.attribute(1));
    System.out.println("");
    printCorrelationReport("********** M5Rules **********", buildAndEvaluate(new M5Rules(), data));

    System.out.println("");
    printCorrelationReport("********** linearR **********", buildAndEvaluate(new LinearRegression(), data));
}

/**
 * Builds the classifier on {@code data}, then runs a 10-fold cross-validation (seed 1)
 * followed by an evaluation on {@code data} itself, both into the same Evaluation.
 */
private static Evaluation buildAndEvaluate(weka.classifiers.Classifier classifier, Instances data)
        throws Exception {
    classifier.buildClassifier(data);
    Evaluation evaluation = new Evaluation(data);
    evaluation.crossValidateModel(classifier, data, 10, new Random(1));
    // NOTE(review): this adds training-set predictions to the same Evaluation that already
    // holds the cross-validation results, so the reported figures mix both. Kept to
    // preserve the existing output - confirm whether it is intended.
    evaluation.evaluateModel(classifier, data);
    return evaluation;
}

/** Prints a banner, the confusion matrix and the percentage of correct predictions. */
private static void printMatrixReport(String banner, Evaluation evaluation) throws Exception {
    System.out.println("*****************************");
    System.out.println(banner);
    System.out.println(evaluation.toMatrixString());
    System.out.println("*****************************");
    System.out.println("**** Pourcentage Correct ****");
    System.out.println(evaluation.pctCorrect());
}

/** Prints a banner followed by the correlation coefficient of the evaluation. */
private static void printCorrelationReport(String banner, Evaluation evaluation) throws Exception {
    System.out.println("*****************************");
    System.out.println(banner);
    System.out.println(evaluation.correlationCoefficient());
}

From source file:GroupProject.DMChartUI.java

/**
 * Converts {@code studentTemp.csv} to ARFF, trains a {@link LinearRegression} on the
 * attribute chosen in {@code dataSelector} and shows the resulting model in
 * {@code equationDisplayArea}.
 *
 * <p>Only the selector positions 14, 15, 18 and 19 (1-based) are accepted; for any other
 * choice a hint is displayed and no model is built. Failures are logged rather than
 * propagated.
 */
private void buildLinearModel() {
    CSVtoArff converter = new CSVtoArff();
    try {
        converter.convert("studentTemp.csv", "studentTemp.arff");
    } catch (IOException ex) {
        Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    // try-with-resources closes the reader; previously two readers were opened on the
    // same file and neither was closed (the second data set was never used).
    Instances students = null;
    try (BufferedReader reader = new BufferedReader(new FileReader("studentTemp.arff"))) {
        students = new Instances(reader);
    } catch (IOException ex) {
        Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    int target = dataSelector.getSelectedIndex() + 1;
    System.out.printf("this is the target: %d\n", target);
    if (target != 14 && target != 15 && target != 18 && target != 19) {
        System.out.println("Please select a numerical category");
        equationDisplayArea.setText("Please select a numerical category");
        return;
    }

    if (students == null) {
        // Loading failed and was logged above; stop instead of dereferencing null.
        return;
    }

    //set target
    students.setClassIndex(target);

    // Drop the attributes at these (0-based) indices before training
    Remove remove = new Remove();
    int[] toremove = { 0, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17 };
    remove.setAttributeIndicesArray(toremove);

    Instances instNew = null;
    try {
        remove.setInputFormat(students);
        instNew = Filter.useFilter(students, remove);
    } catch (Exception ex) {
        Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
    }

    Lmodel = new LinearRegression();
    try {
        System.out.println("im building the model");
        // If filtering failed, instNew is null; the resulting exception is logged below.
        Lmodel.buildClassifier(instNew);
        System.out.println("I finished building the model");
    } catch (Exception ex) {
        Logger.getLogger(DMChartUI.class.getName()).log(Level.SEVERE, null, ex);
    }
    equationDisplayArea.setText(Lmodel.toString());

}

From source file:jjj.asap.sas.models1.job.BuildRegressionModels.java

License:Open Source License

/**
 * Builds one regression model per (dataset, prototype classifier) pair.
 *
 * <p>The prototypes are: LinearRegression with the {@code M5} selection method, boosted
 * decision stumps, boosted REPTrees, and a RandomSubSpace ensemble over
 * LinearRegression with the {@code NONE} selection method. Each pair is submitted as a
 * {@code RegressionModelBuilder} job; failures of individual jobs are logged and do not
 * stop the remaining jobs.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception             if listing the datasets or scheduling a job fails
 */
@Override
protected void run() throws Exception {

    // validate args
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // create prototype classifiers
    List<Classifier> models = new ArrayList<Classifier>();

    LinearRegression m5 = new LinearRegression();
    m5.setAttributeSelectionMethod(M5);

    LinearRegression lr = new LinearRegression();
    lr.setAttributeSelectionMethod(NONE);

    RandomSubSpace rss = new RandomSubSpace();
    rss.setClassifier(lr);
    rss.setNumIterations(30);

    AdditiveRegression boostedStumps = new AdditiveRegression();
    boostedStumps.setClassifier(new DecisionStump());
    boostedStumps.setNumIterations(1000);

    AdditiveRegression boostedTrees = new AdditiveRegression();
    boostedTrees.setClassifier(new REPTree());
    boostedTrees.setNumIterations(100);

    models.add(m5);
    models.add(boostedStumps);
    models.add(boostedTrees);
    models.add(rss);

    // init multi-threading
    Job.startService();
    try {
        final Queue<Future<Object>> queue = new LinkedList<Future<Object>>();

        // get the input from the bucket
        List<String> names = Bucket.getBucketItems("datasets", this.inputBucket);
        for (String dsn : names) {

            for (Classifier model : models) {

                // Wrappers are tagged "<Wrapper>-<Base>", plain classifiers "<Classifier>"
                final String tag;
                if (model instanceof SingleClassifierEnhancer) {
                    tag = model.getClass().getSimpleName() + "-"
                            + ((SingleClassifierEnhancer) model).getClassifier().getClass().getSimpleName();
                } else {
                    tag = model.getClass().getSimpleName();
                }

                // Each job gets its own copy so the prototypes are never trained
                queue.add(Job.submit(new RegressionModelBuilder(dsn, tag, AbstractClassifier.makeCopy(model),
                        this.outputBucket)));
            }
        }

        // wait on complete
        Progress progress = new Progress(queue.size(), this.getClass().getSimpleName());
        while (!queue.isEmpty()) {
            try {
                queue.remove().get();
            } catch (Exception e) {
                Job.log("ERROR", e.toString());
            }
            progress.tick();
        }
        progress.done();
    } finally {
        // Stop the service even when scheduling or waiting throws, so it is not left
        // running after a failed run.
        Job.stopService();
    }

}

From source file:jjj.asap.sas.models1.job.RGramModels.java

License:Open Source License

/**
 * Builds one regression model per (dataset, prototype classifier) pair.
 *
 * <p>The prototypes are, in this order: LinearRegression with the {@code M5} selection
 * method, a RandomSubSpace ensemble over LinearRegression with the {@code NONE}
 * selection method, boosted decision stumps, boosted REPTrees, and a PLSClassifier.
 * Each pair is submitted as a {@code RegressionModelBuilder} job; failures of
 * individual jobs are logged and do not stop the remaining jobs.
 *
 * @throws FileNotFoundException if the input or output bucket does not exist
 * @throws Exception             if listing the datasets or scheduling a job fails
 */
@Override
protected void run() throws Exception {

    // validate args
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // create prototype classifiers
    List<Classifier> models = new ArrayList<Classifier>();

    LinearRegression m5 = new LinearRegression();
    m5.setAttributeSelectionMethod(M5);

    models.add(m5);

    LinearRegression lr = new LinearRegression();
    lr.setAttributeSelectionMethod(NONE);

    RandomSubSpace rss = new RandomSubSpace();
    rss.setClassifier(lr);
    rss.setNumIterations(30);

    models.add(rss);

    AdditiveRegression boostedStumps = new AdditiveRegression();
    boostedStumps.setClassifier(new DecisionStump());
    boostedStumps.setNumIterations(1000);

    AdditiveRegression boostedTrees = new AdditiveRegression();
    boostedTrees.setClassifier(new REPTree());
    boostedTrees.setNumIterations(100);

    models.add(boostedStumps);
    models.add(boostedTrees);

    models.add(new PLSClassifier());

    // init multi-threading
    Job.startService();
    try {
        final Queue<Future<Object>> queue = new LinkedList<Future<Object>>();

        // get the input from the bucket
        List<String> names = Bucket.getBucketItems("datasets", this.inputBucket);
        for (String dsn : names) {

            for (Classifier model : models) {

                // Wrappers are tagged "<Wrapper>-<Base>", plain classifiers "<Classifier>"
                final String tag;
                if (model instanceof SingleClassifierEnhancer) {
                    tag = model.getClass().getSimpleName() + "-"
                            + ((SingleClassifierEnhancer) model).getClassifier().getClass().getSimpleName();
                } else {
                    tag = model.getClass().getSimpleName();
                }

                // Each job gets its own copy so the prototypes are never trained
                queue.add(Job.submit(new RegressionModelBuilder(dsn, tag, AbstractClassifier.makeCopy(model),
                        this.outputBucket)));
            }
        }

        // wait on complete
        Progress progress = new Progress(queue.size(), this.getClass().getSimpleName());
        while (!queue.isEmpty()) {
            try {
                queue.remove().get();
            } catch (Exception e) {
                Job.log("ERROR", e.toString());
            }
            progress.tick();
        }
        progress.done();
    } finally {
        // Stop the service even when scheduling or waiting throws, so it is not left
        // running after a failed run.
        Job.stopService();
    }

}

From source file:meka.classifiers.multilabel.Maniac.java

License:Open Source License

/**
 * Supplies the default classifier: a {@code CR} whose base classifier is a
 * {@link LinearRegression}. Linear regression is used as the base because this
 * classifier uses numeric values in the compressed labels.
 *
 * @return a new {@code CR} wrapping a new {@code LinearRegression}
 */
protected Classifier getDefaultClassifier() {
    final CR defaultClassifier = new CR();
    defaultClassifier.setClassifier(new LinearRegression());
    return defaultClassifier;
}