Example usage for weka.classifiers.meta AdaBoostM1 buildClassifier

List of usage examples for weka.classifiers.meta AdaBoostM1 buildClassifier

Introduction

On this page you can find example usage for weka.classifiers.meta AdaBoostM1 buildClassifier.

Prototype

public void buildClassifier(Instances data) throws Exception 

Source Link

Document

Method used to build the classifier.

Usage

From source file:classify.Classifier.java

/**
 * Entry point: loads "no_missing_values.csv", trains a boosted decision-stump
 * classifier and a logistic-regression classifier on it, and prints the
 * 10-fold cross-validation results for each model.
 *
 * @param args the command line arguments (unused)
 */
public static void main(String[] args) {
    try {
        // Read in the data; the class label is the last attribute.
        DataSource input = new DataSource("no_missing_values.csv");
        Instances data = input.getDataSet();
        missingValuesRows(data);

        setAttributeValues(data);
        data.setClassIndex(data.numAttributes() - 1);

        DecimalFormat nf = new DecimalFormat("0.000");

        // Boosting: 25 rounds of AdaBoost.M1 over decision stumps.
        AdaBoostM1 boosting = new AdaBoostM1();
        boosting.setNumIterations(25);
        boosting.setClassifier(new DecisionStump());
        boosting.buildClassifier(data);

        // Evaluate with 10-fold cross validation (fixed seed => reproducible folds).
        Evaluation e1 = new Evaluation(data);
        e1.crossValidateModel(boosting, data, 10, new Random(1));
        System.out.println("Results of Boosting with Decision Stumps:");
        System.out.println(boosting.toString());
        printCrossValidation(e1, nf);

        System.out.println();
        System.out.println();
        System.out.println();

        // Logistic regression on the same data, same evaluation protocol.
        Logistic l = new Logistic();
        l.buildClassifier(data);

        e1 = new Evaluation(data);
        e1.crossValidateModel(l, data, 10, new Random(1));
        System.out.println("Results of Logistic Regression:");
        System.out.println(l.toString());
        printCrossValidation(e1, nf);

    } catch (Exception ex) {
        // Data couldn't be read or a model failed to build; report the cause
        // instead of swallowing it, then end the program.
        System.out.println("Exception thrown, program ending.");
        ex.printStackTrace();
    }
}

/**
 * Prints the cross-validation summary (accuracy, weighted per-class rates and
 * the confusion matrix) of an already-populated evaluation. Shared by both
 * model reports above.
 *
 * @param e1 evaluation already filled by crossValidateModel
 * @param nf percentage formatter (three decimal places)
 */
private static void printCrossValidation(Evaluation e1, DecimalFormat nf) {
    System.out.println("Results of Cross Validation:");
    System.out.println("Number of correctly classified instances: " + e1.correct() + " ("
            + nf.format(e1.pctCorrect()) + "%)");
    System.out.println("Number of incorrectly classified instances: " + e1.incorrect() + " ("
            + nf.format(e1.pctIncorrect()) + "%)");

    System.out.println("TP Rate: " + nf.format(e1.weightedTruePositiveRate() * 100) + "%");
    System.out.println("FP Rate: " + nf.format(e1.weightedFalsePositiveRate() * 100) + "%");
    System.out.println("Precision: " + nf.format(e1.weightedPrecision() * 100) + "%");
    System.out.println("Recall: " + nf.format(e1.weightedRecall() * 100) + "%");

    System.out.println();
    System.out.println("Confusion Matrix:");
    // Fetch the matrix once instead of re-computing it on every access.
    double[][] matrix = e1.confusionMatrix();
    for (int i = 0; i < matrix.length; i++) {
        for (int j = 0; j < matrix[0].length; j++) {
            System.out.print(matrix[i][j] + "   ");
        }
        System.out.println();
    }
}

From source file:org.mcennis.graphrat.algorithm.machinelearning.MultiInstanceSVM.java

License:Open Source License

/**
 * Generate music predictions for a user as follows:
 * Calculate all artists A present in the data set.
 * Create a data set containing two numeric attributes (typically generated by the
 * AddBasicInterestLink and AddMusicLinks algorithms), a boolean for every artist
 * and a boolean class variable.  These fields are populated as follows
 * <br>
 * For each artist, generate a 2-class classifier.
 * <br>
 * For every user, for every friend of the user:
 * First two fields are the interest and music link (0 if absent).
 * The artist fields are the music listened to by the friend
 * The final field is whether or not the user listens to the music specified.
 *
 * For memory reasons, not all training data is used.  The maximum number of
 * positive instances is read from parameter[10]; a value of 0 disables the cap.
 */
public void execute(Graph g) {
    artists = g.getActor((String) parameter[2].getValue());
    java.util.Arrays.sort(artists);
    user = g.getActor((String) parameter[7].getValue());
    fireChange(Scheduler.SET_ALGORITHM_COUNT, artists.length);
    int totalPerFile = countTotal(g);
    for (int i = 0; i < artists.length; ++i) {
        try {
            // Log progress every tenth artist.
            if (i % 10 == 0) {
                Logger.getLogger(MultiInstanceSVM.class.getName()).log(Level.INFO,
                        "Evaluating for artist " + artists[i].getID() + " " + i + " of " + artists.length);
                fireChange(Scheduler.SET_ALGORITHM_PROGRESS, i);
            }
            Instances dataSet = createDataSet(artists);
            int totalThisArtist = totalYes(g, artists[i]);
            // Subsample positive instances when they exceed the configured cap.
            // The divisor was previously hard-coded to 160 (former FIXME); it is
            // now the cap from parameter[10], which the guard below already uses.
            int positiveSkipCount = 1;
            int maxPositive = ((Integer) parameter[10].getValue()).intValue();
            if ((maxPositive != 0) && (totalThisArtist > maxPositive)) {
                positiveSkipCount = (totalThisArtist / maxPositive) + 1;
            }
            if (totalThisArtist > 0) {
                // Skip negative instances proportionally so the class balance
                // follows the ratio configured in parameter[11].
                int skipValue = (int) ((((Double) parameter[11].getValue()).doubleValue() * totalPerFile)
                        / (totalThisArtist / positiveSkipCount));
                if (skipValue <= 0) {
                    skipValue = 1;
                }
                if (!(Boolean) parameter[6].getValue()) {
                    skipValue = 1;
                }
                addInstances(g, dataSet, artists[i], skipValue, positiveSkipCount);
                AdaBoostM1 classifier = new AdaBoostM1();
                try {
                    Logger.getLogger(MultiInstanceSVM.class.getName()).log(Level.FINER, "Building Classifier");
                    classifier.buildClassifier(dataSet);
                    Logger.getLogger(MultiInstanceSVM.class.getName()).log(Level.FINER, "Evaluating Classifier");
                    evaluateClassifier(classifier, dataSet, g, artists[i]);
                } catch (Exception ex) {
                    // A failure for one artist must not abort the whole run.
                    ex.printStackTrace();
                }
                classifier = null; // release before building the next model
            } else {
                Logger.getLogger(MultiInstanceSVM.class.getName()).log(Level.WARNING,
                        "Artist '" + artists[i].getID() + "' has no users listening to them");
            }
            dataSet = null; // help GC: per-artist data sets can be very large
        } catch (java.lang.OutOfMemoryError e) {
            // Deliberate best-effort: skip artists whose data set cannot fit.
            Logger.getLogger(MultiInstanceSVM.class.getName()).log(Level.WARNING,
                    "Artist " + artists[i].getID() + " (" + i + ") ran out of memory");
        }
    }
}

From source file:org.mcennis.graphrat.algorithm.machinelearning.SVM.java

License:Open Source License

/**
 * Builds and evaluates one boosted 2-class classifier per artist in the
 * graph, subsampling positive and negative training instances according to
 * the configured parameters. Artists whose data set exhausts memory are
 * skipped and reported on stderr.
 */
public void execute(Graph g) {
    artists = g.getActor((String) parameter[2].getValue());
    java.util.Arrays.sort(artists);
    user = g.getActor((String) parameter[7].getValue());
    fireChange(Scheduler.SET_ALGORITHM_COUNT, artists.length);
    final int totalPerFile = user.length;
    for (int i = 0; i < artists.length; ++i) {
        try {
            // Report progress every tenth artist.
            if (i % 10 == 0) {
                System.out.println(
                        "Evaluating for artist " + artists[i].getID() + " " + i + " of " + artists.length);
                fireChange(Scheduler.SET_ALGORITHM_PROGRESS, i);
            }
            Instances trainingSet = createDataSet(artists);
            int positivesForArtist = g.getLinkByDestination((String) parameter[3].getValue(),
                    artists[i]).length;
            // Thin out positive instances once they exceed the cap in parameter[10].
            int positiveStride = 1;
            int positiveCap = ((Integer) parameter[10].getValue()).intValue();
            if (positiveCap != 0 && positivesForArtist > positiveCap) {
                positiveStride = (positivesForArtist / 160) + 1;
            }
            if (positivesForArtist > 0) {
                // Negative stride keeps the class ratio configured in parameter[11];
                // parameter[6] == false forces every negative instance to be kept.
                int negativeStride = (int) ((((Double) parameter[11].getValue()).doubleValue() * totalPerFile)
                        / (positivesForArtist / positiveStride));
                if (negativeStride <= 0) {
                    negativeStride = 1;
                }
                if (!(Boolean) parameter[6].getValue()) {
                    negativeStride = 1;
                }
                addInstances(g, trainingSet, artists[i], negativeStride, positiveStride);
                AdaBoostM1 booster = new AdaBoostM1();
                try {
                    booster.buildClassifier(trainingSet);
                    evaluateClassifier(booster, trainingSet, g, artists[i]);
                } catch (Exception ex) {
                    // One artist failing must not abort the remaining artists.
                    ex.printStackTrace();
                }
                booster = null; // release before the next iteration
            }
            trainingSet = null; // help GC: data sets can be very large
        } catch (java.lang.OutOfMemoryError e) {
            System.err.println("Artist " + artists[i].getID() + " (" + i + ") ran out of memory");
            System.gc();
        }
    }
}