Example usage for weka.core Instances numAttributes

List of usage examples for weka.core Instances numAttributes

Introduction

On this page you can find example usages of weka.core.Instances.numAttributes().

Prototype


public int numAttributes()

Source Link

Document

Returns the number of attributes.

Usage

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Evaluates a classifier with the options given in an array of
 * strings. <p/>/*from   w ww  .  ja  v a2  s . co m*/
 *
 * Valid options are: <p/>
 *
 * -t name of training file <br/>
 * Name of the file with the training data. (required) <p/>
 *
 * -T name of test file <br/>
 * Name of the file with the test data. If missing a cross-validation 
 * is performed. <p/>
 *
 * -c class index <br/>
 * Index of the class attribute (1, 2, ...; default: last). <p/>
 *
 * -x number of folds <br/>
 * The number of folds for the cross-validation (default: 10). <p/>
 *
 * -no-cv <br/>
 * No cross validation.  If no test file is provided, no evaluation
 * is done. <p/>
 * 
 * -split-percentage percentage <br/>
 * Sets the percentage for the train/test set split, e.g., 66. <p/>
 * 
 * -preserve-order <br/>
 * Preserves the order in the percentage split instead of randomizing
 * the data first with the seed value ('-s'). <p/>
 *
 * -s seed <br/>
 * Random number seed for the cross-validation and percentage split
 * (default: 1). <p/>
 *
 * -m file with cost matrix <br/>
 * The name of a file containing a cost matrix. <p/>
 *
 * -l filename <br/>
 * Loads classifier from the given file. In case the filename ends with
 * ".xml",a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
 *
 * -d filename <br/>
 * Saves classifier built from the training data into the given file. In case 
 * the filename ends with ".xml" the options are saved XML, not the model. <p/>
 *
 * -v <br/>
 * Outputs no statistics for the training data. <p/>
 *
 * -o <br/>
 * Outputs statistics only, not the classifier. <p/>
 * 
 * -i <br/>
 * Outputs detailed information-retrieval statistics per class. <p/>
 *
 * -k <br/>
 * Outputs information-theoretic statistics. <p/>
 *
 * -p range <br/>
 * Outputs predictions for test instances (or the train instances if no test
 * instances provided and -no-cv is used), along with the attributes in the specified range 
 * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
 *
 * -distribution <br/>
 * Outputs the distribution instead of only the prediction
 * in conjunction with the '-p' option (only nominal classes). <p/>
 *
 * -r <br/>
 * Outputs cumulative margin distribution (and nothing else). <p/>
 *
 * -g <br/> 
 * Only for classifiers that implement "Graphable." Outputs
 * the graph representation of the classifier (and nothing
 * else). <p/>
 *
 * -xml filename | xml-string <br/>
 * Retrieves the options from the XML-data instead of the command line. <p/>
 *
 * @param classifier machine learning classifier
 * @param options the array of string containing the options
 * @throws Exception if model could not be evaluated successfully
 * @return a string describing the results 
 */
public static String evaluateModel(Classifier classifier, String[] options) throws Exception {

    Instances train = null, tempTrain, test = null, template = null;
    int seed = 1, folds = 10, classIndex = -1;
    boolean noCrossValidation = false;
    String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString,
            objectInputFileName, objectOutputFileName, attributeRangeString;
    boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false,
            printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false;
    StringBuffer text = new StringBuffer();
    DataSource trainSource = null, testSource = null;
    ObjectInputStream objectInputStream = null;
    BufferedInputStream xmlInputStream = null;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    Range attributesToOutput = null;
    long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0;
    String xml = "";
    String[] optionsTmp = null;
    Classifier classifierBackup;
    Classifier classifierClassifications = null;
    boolean printDistribution = false;
    int actualClassIndex = -1; // 0-based class index
    String splitPercentageString = "";
    int splitPercentage = -1;
    boolean preserveOrder = false;
    boolean trainSetPresent = false;
    boolean testSetPresent = false;
    String thresholdFile;
    String thresholdLabel;
    StringBuffer predsBuff = null; // predictions from cross-validation

    // help requested?
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {

        // global info requested as well?
        boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options);

        throw new Exception("\nHelp requested." + makeOptionString(classifier, globalInfo));
    }

    try {
        // do we get the input from XML instead of normal parameters?
        xml = Utils.getOption("xml", options);
        if (!xml.equals(""))
            options = new XMLOptions(xml).toArray();

        // is the input model only the XML-Options, i.e. w/o built model?
        optionsTmp = new String[options.length];
        for (int i = 0; i < options.length; i++)
            optionsTmp[i] = options[i];

        String tmpO = Utils.getOption('l', optionsTmp);
        //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
        if (tmpO.endsWith(".xml")) {
            // try to load file as PMML first
            boolean success = false;
            try {
                //PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO);
                //if (pmmlModel instanceof PMMLClassifier) {
                //classifier = ((PMMLClassifier)pmmlModel);
                // success = true;
                //}
            } catch (IllegalArgumentException ex) {
                success = false;
            }
            if (!success) {
                // load options from serialized data  ('-l' is automatically erased!)
                XMLClassifier xmlserial = new XMLClassifier();
                Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));

                // merge options
                optionsTmp = new String[options.length + cl.getOptions().length];
                System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
                System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
                options = optionsTmp;
            }
        }

        noCrossValidation = Utils.getFlag("no-cv", options);
        // Get basic options (options the same for all schemes)
        classIndexString = Utils.getOption('c', options);
        if (classIndexString.length() != 0) {
            if (classIndexString.equals("first"))
                classIndex = 1;
            else if (classIndexString.equals("last"))
                classIndex = -1;
            else
                classIndex = Integer.parseInt(classIndexString);
        }
        trainFileName = Utils.getOption('t', options);
        objectInputFileName = Utils.getOption('l', options);
        objectOutputFileName = Utils.getOption('d', options);
        testFileName = Utils.getOption('T', options);
        foldsString = Utils.getOption('x', options);
        if (foldsString.length() != 0) {
            folds = Integer.parseInt(foldsString);
        }
        seedString = Utils.getOption('s', options);
        if (seedString.length() != 0) {
            seed = Integer.parseInt(seedString);
        }
        if (trainFileName.length() == 0) {
            if (objectInputFileName.length() == 0) {
                throw new Exception("No training file and no object " + "input file given.");
            }
            if (testFileName.length() == 0) {
                throw new Exception("No training file and no test " + "file given.");
            }
        } else if ((objectInputFileName.length() != 0)
                && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) {
            throw new Exception("Classifier not incremental, or no " + "test file provided: can't "
                    + "use both train and model file.");
        }
        try {
            if (trainFileName.length() != 0) {
                trainSetPresent = true;
                trainSource = new DataSource(trainFileName);
            }
            if (testFileName.length() != 0) {
                testSetPresent = true;
                testSource = new DataSource(testFileName);
            }
            if (objectInputFileName.length() != 0) {
                if (objectInputFileName.endsWith(".xml")) {
                    // if this is the case then it means that a PMML classifier was
                    // successfully loaded earlier in the code
                    objectInputStream = null;
                    xmlInputStream = null;
                } else {
                    InputStream is = new FileInputStream(objectInputFileName);
                    if (objectInputFileName.endsWith(".gz")) {
                        is = new GZIPInputStream(is);
                    }
                    // load from KOML?
                    if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) {
                        objectInputStream = new ObjectInputStream(is);
                        xmlInputStream = null;
                    } else {
                        objectInputStream = null;
                        xmlInputStream = new BufferedInputStream(is);
                    }
                }
            }
        } catch (Exception e) {
            throw new Exception("Can't open file " + e.getMessage() + '.');
        }
        if (testSetPresent) {
            template = test = testSource.getStructure();
            if (classIndex != -1) {
                test.setClassIndex(classIndex - 1);
            } else {
                if ((test.classIndex() == -1) || (classIndexString.length() != 0))
                    test.setClassIndex(test.numAttributes() - 1);
            }
            actualClassIndex = test.classIndex();
        } else {
            // percentage split
            splitPercentageString = Utils.getOption("split-percentage", options);
            if (splitPercentageString.length() != 0) {
                if (foldsString.length() != 0)
                    throw new Exception("Percentage split cannot be used in conjunction with "
                            + "cross-validation ('-x').");
                splitPercentage = Integer.parseInt(splitPercentageString);
                if ((splitPercentage <= 0) || (splitPercentage >= 100))
                    throw new Exception("Percentage split value needs be >0 and <100.");
            } else {
                splitPercentage = -1;
            }
            preserveOrder = Utils.getFlag("preserve-order", options);
            if (preserveOrder) {
                if (splitPercentage == -1)
                    throw new Exception("Percentage split ('-percentage-split') is missing.");
            }
            // create new train/test sources
            if (splitPercentage > 0) {
                testSetPresent = true;
                Instances tmpInst = trainSource.getDataSet(actualClassIndex);
                if (!preserveOrder)
                    tmpInst.randomize(new Random(seed));
                int trainSize = tmpInst.numInstances() * splitPercentage / 100;
                int testSize = tmpInst.numInstances() - trainSize;
                Instances trainInst = new Instances(tmpInst, 0, trainSize);
                Instances testInst = new Instances(tmpInst, trainSize, testSize);
                trainSource = new DataSource(trainInst);
                testSource = new DataSource(testInst);
                template = test = testSource.getStructure();
                if (classIndex != -1) {
                    test.setClassIndex(classIndex - 1);
                } else {
                    if ((test.classIndex() == -1) || (classIndexString.length() != 0))
                        test.setClassIndex(test.numAttributes() - 1);
                }
                actualClassIndex = test.classIndex();
            }
        }
        if (trainSetPresent) {
            template = train = trainSource.getStructure();
            if (classIndex != -1) {
                train.setClassIndex(classIndex - 1);
            } else {
                if ((train.classIndex() == -1) || (classIndexString.length() != 0))
                    train.setClassIndex(train.numAttributes() - 1);
            }
            actualClassIndex = train.classIndex();
            if ((testSetPresent) && !test.equalHeaders(train)) {
                throw new IllegalArgumentException("Train and test file not compatible!");
            }
        }
        if (template == null) {
            throw new Exception("No actual dataset provided to use as template");
        }
        costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses());

        classStatistics = Utils.getFlag('i', options);
        noOutput = Utils.getFlag('o', options);
        trainStatistics = !Utils.getFlag('v', options);
        printComplexityStatistics = Utils.getFlag('k', options);
        printMargins = Utils.getFlag('r', options);
        printGraph = Utils.getFlag('g', options);
        sourceClass = Utils.getOption('z', options);
        printSource = (sourceClass.length() != 0);
        printDistribution = Utils.getFlag("distribution", options);
        thresholdFile = Utils.getOption("threshold-file", options);
        thresholdLabel = Utils.getOption("threshold-label", options);

        // Check -p option
        try {
            attributeRangeString = Utils.getOption('p', options);
        } catch (Exception e) {
            throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. "
                    + "It now expects a parameter specifying a range of attributes "
                    + "to list with the predictions. Use '-p 0' for none.");
        }
        if (attributeRangeString.length() != 0) {
            printClassifications = true;
            noOutput = true;
            if (!attributeRangeString.equals("0"))
                attributesToOutput = new Range(attributeRangeString);
        }

        if (!printClassifications && printDistribution)
            throw new Exception("Cannot print distribution without '-p' option!");

        // if no training file given, we don't have any priors
        if ((!trainSetPresent) && (printComplexityStatistics))
            throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

        // If a model file is given, we can't process 
        // scheme-specific options
        if (objectInputFileName.length() != 0) {
            Utils.checkForRemainingOptions(options);
        } else {

            // Set options for classifier
            if (classifier instanceof OptionHandler) {
                for (int i = 0; i < options.length; i++) {
                    if (options[i].length() != 0) {
                        if (schemeOptionsText == null) {
                            schemeOptionsText = new StringBuffer();
                        }
                        if (options[i].indexOf(' ') != -1) {
                            schemeOptionsText.append('"' + options[i] + "\" ");
                        } else {
                            schemeOptionsText.append(options[i] + " ");
                        }
                    }
                }
                ((OptionHandler) classifier).setOptions(options);
            }
        }
        Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
        throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier, false));
    }

    // Setup up evaluation objects
    Evaluation_D trainingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix);
    Evaluation_D testingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix);

    // disable use of priors if no training file given
    if (!trainSetPresent)
        testingEvaluation.useNoPriors();

    if (objectInputFileName.length() != 0) {
        // Load classifier from file
        if (objectInputStream != null) {
            classifier = (Classifier) objectInputStream.readObject();
            // try and read a header (if present)
            Instances savedStructure = null;
            try {
                savedStructure = (Instances) objectInputStream.readObject();
            } catch (Exception ex) {
                // don't make a fuss
            }
            if (savedStructure != null) {
                // test for compatibility with template
                if (!template.equalHeaders(savedStructure)) {
                    throw new Exception("training and test set are not compatible");
                }
            }
            objectInputStream.close();
        } else if (xmlInputStream != null) {
            // whether KOML is available has already been checked (objectInputStream would null otherwise)!
            classifier = (Classifier) KOML.read(xmlInputStream);
            xmlInputStream.close();
        }
    }

    // backup of fully setup classifier for cross-validation
    classifierBackup = Classifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) && (testSetPresent || noCrossValidation)
            && (costMatrix == null) && (trainSetPresent)) {
        // Build classifier incrementally
        trainingEvaluation.setPriors(train);
        testingEvaluation.setPriors(train);
        trainTimeStart = System.currentTimeMillis();
        if (objectInputFileName.length() == 0) {
            classifier.buildClassifier(train);
        }
        Instance trainInst;
        while (trainSource.hasMoreElements(train)) {
            trainInst = trainSource.nextElement(train);
            trainingEvaluation.updatePriors(trainInst);
            testingEvaluation.updatePriors(trainInst);
            ((UpdateableClassifier) classifier).updateClassifier(trainInst);
        }
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
        // Build classifier in one go
        tempTrain = trainSource.getDataSet(actualClassIndex);
        trainingEvaluation.setPriors(tempTrain);
        testingEvaluation.setPriors(tempTrain);
        trainTimeStart = System.currentTimeMillis();
        classifier.buildClassifier(tempTrain);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (printClassifications)
        classifierClassifications = Classifier.makeCopy(classifier);

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
        OutputStream os = new FileOutputStream(objectOutputFileName);
        // binary
        if (!(objectOutputFileName.endsWith(".xml")
                || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
            if (objectOutputFileName.endsWith(".gz")) {
                os = new GZIPOutputStream(os);
            }
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
            objectOutputStream.writeObject(classifier);
            if (template != null) {
                objectOutputStream.writeObject(template);
            }
            objectOutputStream.flush();
            objectOutputStream.close();
        }
        // KOML/XML
        else {
            BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
            if (objectOutputFileName.endsWith(".xml")) {
                XMLSerialization xmlSerial = new XMLClassifier();
                xmlSerial.write(xmlOutputStream, classifier);
            } else
            // whether KOML is present has already been checked
            // if not present -> ".koml" is interpreted as binary - see above
            if (objectOutputFileName.endsWith(".koml")) {
                KOML.write(xmlOutputStream, classifier);
            }
            xmlOutputStream.close();
        }
    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable) && (printGraph)) {
        return ((Drawable) classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable) && (printSource)) {
        return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output model
    if (!(noOutput || printMargins)) {
        if (classifier instanceof OptionHandler) {
            if (schemeOptionsText != null) {
                text.append("\nOptions: " + schemeOptionsText);
                text.append("\n");
            }
        }
        text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
        text.append("\n=== Evaluation Cost Matrix ===\n\n");
        text.append(costMatrix.toString());
    }

    // Output test instance predictions only
    if (printClassifications) {
        DataSource source = testSource;
        predsBuff = new StringBuffer();
        // no test set -> use train set
        if (source == null && noCrossValidation) {
            source = trainSource;
            predsBuff.append("\n=== Predictions on training data ===\n\n");
        } else {
            predsBuff.append("\n=== Predictions on test data ===\n\n");
        }
        if (source != null) {
            /*      return printClassifications(classifierClassifications, new Instances(template, 0),
                    source, actualClassIndex + 1, attributesToOutput,
                    printDistribution); */
            printClassifications(classifierClassifications, new Instances(template, 0), source,
                    actualClassIndex + 1, attributesToOutput, printDistribution, predsBuff);
            //        return predsText.toString();
        }
    }

    // Compute error estimate from training data
    if ((trainStatistics) && (trainSetPresent)) {

        if ((classifier instanceof UpdateableClassifier) && (testSetPresent) && (costMatrix == null)) {

            // Classifier was trained incrementally, so we have to 
            // reset the source.
            trainSource.reset();

            // Incremental testing
            train = trainSource.getStructure(actualClassIndex);
            testTimeStart = System.currentTimeMillis();
            Instance trainInst;
            while (trainSource.hasMoreElements(train)) {
                trainInst = trainSource.nextElement(train);
                trainingEvaluation.evaluateModelOnce((Classifier) classifier, trainInst);
            }
            testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        } else {
            testTimeStart = System.currentTimeMillis();
            trainingEvaluation.evaluateModel(classifier, trainSource.getDataSet(actualClassIndex));
            testTimeElapsed = System.currentTimeMillis() - testTimeStart;
        }

        // Print the results of the training evaluation
        if (printMargins) {
            return trainingEvaluation.toCumulativeMarginDistributionString();
        } else {
            if (!printClassifications) {
                text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2)
                        + " seconds");

                if (splitPercentage > 0)
                    text.append("\nTime taken to test model on training split: ");
                else
                    text.append("\nTime taken to test model on training data: ");
                text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds");

                if (splitPercentage > 0)
                    text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " split ===\n",
                            printComplexityStatistics));
                else
                    text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n",
                            printComplexityStatistics));

                if (template.classAttribute().isNominal()) {
                    if (classStatistics) {
                        text.append("\n\n" + trainingEvaluation.toClassDetailsString());
                    }
                    if (!noCrossValidation)
                        text.append("\n\n" + trainingEvaluation.toMatrixString());
                }
            }
        }
    }

    // Compute proper error estimates
    if (testSource != null) {
        // Testing is on the supplied test data
        testSource.reset();
        test = testSource.getStructure(test.classIndex());
        Instance testInst;
        while (testSource.hasMoreElements(test)) {
            testInst = testSource.nextElement(test);
            testingEvaluation.evaluateModelOnceAndRecordPrediction((Classifier) classifier, testInst);
        }

        if (splitPercentage > 0) {
            if (!printClassifications) {
                text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test split ===\n",
                        printComplexityStatistics));
            }
        } else {
            if (!printClassifications) {
                text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n",
                        printComplexityStatistics));
            }
        }

    } else if (trainSource != null) {
        if (!noCrossValidation) {
            // Testing is via cross-validation on training data
            Random random = new Random(seed);
            // use untrained (!) classifier for cross-validation
            classifier = Classifier.makeCopy(classifierBackup);
            if (!printClassifications) {
                testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex),
                        folds, random);
                if (template.classAttribute().isNumeric()) {
                    text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n",
                            printComplexityStatistics));
                } else {
                    text.append("\n\n\n" + testingEvaluation.toSummaryString(
                            "=== Stratified " + "cross-validation ===\n", printComplexityStatistics));
                }
            } else {
                predsBuff = new StringBuffer();
                predsBuff.append("\n=== Predictions under cross-validation ===\n\n");
                testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex),
                        folds, random, predsBuff, attributesToOutput, new Boolean(printDistribution));
                /*          if (template.classAttribute().isNumeric()) {
                            text.append("\n\n\n" + testingEvaluation.
                toSummaryString("=== Cross-validation ===\n",
                                printComplexityStatistics));
                          } else {
                            text.append("\n\n\n" + testingEvaluation.
                toSummaryString("=== Stratified " + 
                                "cross-validation ===\n",
                                printComplexityStatistics));
                          } */
            }
        }
    }
    if (template.classAttribute().isNominal()) {
        if (classStatistics && !noCrossValidation && !printClassifications) {
            text.append("\n\n" + testingEvaluation.toClassDetailsString());
        }
        if (!noCrossValidation && !printClassifications)
            text.append("\n\n" + testingEvaluation.toMatrixString());

    }

    // predictions from cross-validation?
    if (predsBuff != null) {
        text.append("\n" + predsBuff);
    }

    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
        int labelIndex = 0;
        if (thresholdLabel.length() != 0)
            labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
        if (labelIndex == -1)
            throw new IllegalArgumentException("Class label '" + thresholdLabel + "' is unknown!");
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
        DataSink.write(thresholdFile, result);
    }

    return text.toString();
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Appends the header line of the predictions table to the supplied buffer.
 * The column titles depend on whether the class attribute is nominal and,
 * for nominal classes, on whether the full distribution is printed. When
 * an attribute range is given, its upper limit is set to the last attribute
 * index and the names of the selected non-class attributes are appended in
 * parentheses, separated by commas.
 *
 * @param test structure of the test set to print predictions for
 * @param attributesToOutput indices of the attributes to output, or
 * <code>null</code> to list no attribute names
 * @param printDistribution whether the complete distribution is printed for
 * nominal classes, not just the predicted value
 * @param text the StringBuffer to print to
 */
protected static void printClassificationsHeader(Instances test, Range attributesToOutput,
        boolean printDistribution, StringBuffer text) {
    // pick the column titles matching the class type / output mode
    String columns;
    if (test.classAttribute().isNominal()) {
        columns = printDistribution
                ? " inst#     actual  predicted error distribution"
                : " inst#     actual  predicted error prediction";
    } else {
        columns = " inst#     actual  predicted      error";
    }
    text.append(columns);

    if (attributesToOutput != null) {
        attributesToOutput.setUpper(test.numAttributes() - 1);
        text.append(" (");
        boolean needsComma = false;
        for (int attIndex = 0; attIndex < test.numAttributes(); attIndex++) {
            // the class attribute is never listed, even if it is in the range
            if (attIndex != test.classIndex() && attributesToOutput.isInRange(attIndex)) {
                if (needsComma) {
                    text.append(",");
                }
                text.append(test.attribute(attIndex).name());
                needsComma = true;
            }
        }
        text.append(")");
    }
    text.append("\n");
}

From source file:cotraining.copy.Evaluation_D.java

License:Open Source License

/**
 * Appends the predictions for all instances of the given test source to the
 * supplied StringBuffer, preceded by a header line. Does nothing when no
 * test source is given.
 *
 * @param classifier      the classifier to use
 * @param train      the training data (not used by this method)
 * @param testSource      the test set
 * @param classIndex      the class index (1-based); if -1 the class index
 *             stored in the data file is kept, falling back to the
 *             last attribute when the file defines none
 * @param attributesToOutput   the indices of the attributes to output
 * @param printDistribution   prints the complete distribution for nominal 
 *             classes, not just the predicted value
 * @param text                StringBuffer to hold the printed predictions
 * @throws Exception       if test file cannot be opened
 */
public static void printClassifications(Classifier classifier, Instances train, DataSource testSource,
        int classIndex, Range attributesToOutput, boolean printDistribution, StringBuffer text)
        throws Exception {

    if (testSource == null) {
        return;
    }

    // determine the class attribute on the header of the test data
    Instances structure = testSource.getStructure();
    if (classIndex == -1) {
        if (structure.classIndex() == -1) {
            structure.setClassIndex(structure.numAttributes() - 1);
        }
    } else {
        structure.setClassIndex(classIndex - 1);
    }

    // header line first
    printClassificationsHeader(structure, attributesToOutput, printDistribution, text);

    // then one prediction per instance, read incrementally from the start
    testSource.reset();
    structure = testSource.getStructure(structure.classIndex());
    for (int row = 0; testSource.hasMoreElements(structure); row++) {
        Instance current = testSource.nextElement(structure);
        text.append(predictionText(classifier, current, row, attributesToOutput, printDistribution));
    }
}

From source file:cs.man.ac.uk.predict.Predictor.java

License:Open Source License

public static void makePredictionsEnsembleNew(String trainPath, String testPath, String resultPath) {
    System.out.println("Training set: " + trainPath);
    System.out.println("Test set: " + testPath);

    /**/*  w w w .j av  a  2s  .c o  m*/
     * The ensemble classifiers. This is a heterogeneous ensemble.
     */
    J48 learner1 = new J48();
    SMO learner2 = new SMO();
    NaiveBayes learner3 = new NaiveBayes();
    MultilayerPerceptron learner5 = new MultilayerPerceptron();

    System.out.println("Training Ensemble.");
    long startTime = System.nanoTime();
    try {
        BufferedReader reader = new BufferedReader(new FileReader(trainPath));
        Instances data = new Instances(reader);
        data.setClassIndex(data.numAttributes() - 1);
        System.out.println("Training data length: " + data.numInstances());

        learner1.buildClassifier(data);
        learner2.buildClassifier(data);
        learner3.buildClassifier(data);
        learner5.buildClassifier(data);

        long endTime = System.nanoTime();
        long nanoseconds = endTime - startTime;
        double seconds = (double) nanoseconds / 1000000000.0;
        System.out.println("Training Ensemble completed in " + nanoseconds + " (ns) or " + seconds + " (s).");
    } catch (IOException e) {
        System.out.println("Could not train Ensemble classifier IOException on training data file.");
    } catch (Exception e) {
        System.out.println("Could not train Ensemble classifier Exception building model.");
    }

    try {
        String line = "";

        // Read the file and display it line by line. 
        BufferedReader in = null;

        // Read in and store each positive prediction in the tree map.
        try {
            //open stream to file
            in = new BufferedReader(new FileReader(testPath));

            while ((line = in.readLine()) != null) {
                if (line.toLowerCase().contains("@data"))
                    break;
            }
        } catch (Exception e) {
        }

        // A different ARFF loader used here (compared to above) as
        // the ARFF file may be extremely large. In which case the whole
        // file cannot be read in. Instead it is read in incrementally.
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File(testPath));

        Instances data = loader.getStructure();
        data.setClassIndex(data.numAttributes() - 1);

        System.out.println("Ensemble Classifier is ready.");
        System.out.println("Testing on all instances avaialable.");

        startTime = System.nanoTime();

        int instanceNumber = 0;

        // label instances
        Instance current;

        while ((current = loader.getNextInstance(data)) != null) {
            instanceNumber += 1;
            line = in.readLine();

            double classification1 = learner1.classifyInstance(current);
            double classification2 = learner2.classifyInstance(current);
            double classification3 = learner3.classifyInstance(current);
            double classification5 = learner5.classifyInstance(current);

            // All classifiers must agree. This is a very primitive ensemble strategy!
            if (classification1 == 1 && classification2 == 1 && classification3 == 1 && classification5 == 1) {
                if (line != null) {
                    //System.out.println("Instance: "+instanceNumber+"\t"+line);
                    //System.in.read();
                }
                Writer.append(resultPath, instanceNumber + "\n");
            }
        }

        in.close();

        System.out.println("Test set instances: " + instanceNumber);

        long endTime = System.nanoTime();
        long duration = endTime - startTime;
        double seconds = (double) duration / 1000000000.0;

        System.out.println("Testing Ensemble completed in " + duration + " (ns) or " + seconds + " (s).");
    } catch (Exception e) {
        System.out.println("Could not test Ensemble classifier due to an error.");
    }
}

From source file:cs.man.ac.uk.predict.Predictor.java

License:Open Source License

public static void makePredictionsEnsembleStream(String trainPath, String testPath, String resultPath) {
    System.out.println("Training set: " + trainPath);
    System.out.println("Test set: " + testPath);

    /**//from   w w  w  .  j av a 2s . c  om
     * The ensemble classifiers. This is a heterogeneous ensemble.
     */
    J48 learner1 = new J48();
    SMO learner2 = new SMO();
    NaiveBayes learner3 = new NaiveBayes();
    MultilayerPerceptron learner5 = new MultilayerPerceptron();

    System.out.println("Training Ensemble.");
    long startTime = System.nanoTime();
    try {
        BufferedReader reader = new BufferedReader(new FileReader(trainPath));
        Instances data = new Instances(reader);
        data.setClassIndex(data.numAttributes() - 1);
        System.out.println("Training data length: " + data.numInstances());

        learner1.buildClassifier(data);
        learner2.buildClassifier(data);
        learner3.buildClassifier(data);
        learner5.buildClassifier(data);

        long endTime = System.nanoTime();
        long nanoseconds = endTime - startTime;
        double seconds = (double) nanoseconds / 1000000000.0;
        System.out.println("Training Ensemble completed in " + nanoseconds + " (ns) or " + seconds + " (s).");
    } catch (IOException e) {
        System.out.println("Could not train Ensemble classifier IOException on training data file.");
    } catch (Exception e) {
        System.out.println("Could not train Ensemble classifier Exception building model.");
    }

    try {
        // A different ARFF loader used here (compared to above) as
        // the ARFF file may be extremely large. In which case the whole
        // file cannot be read in. Instead it is read in incrementally.
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File(testPath));

        Instances data = loader.getStructure();
        data.setClassIndex(data.numAttributes() - 1);

        System.out.println("Ensemble Classifier is ready.");
        System.out.println("Testing on all instances avaialable.");

        startTime = System.nanoTime();

        int instanceNumber = 0;

        // label instances
        Instance current;

        while ((current = loader.getNextInstance(data)) != null) {
            instanceNumber += 1;

            double classification1 = learner1.classifyInstance(current);
            double classification2 = learner2.classifyInstance(current);
            double classification3 = learner3.classifyInstance(current);
            double classification5 = learner5.classifyInstance(current);

            // All classifiers must agree. This is a very primitive ensemble strategy!
            if (classification1 == 1 && classification2 == 1 && classification3 == 1 && classification5 == 1) {
                Writer.append(resultPath, instanceNumber + "\n");
            }
        }

        System.out.println("Test set instances: " + instanceNumber);

        long endTime = System.nanoTime();
        long duration = endTime - startTime;
        double seconds = (double) duration / 1000000000.0;

        System.out.println("Testing Ensemble completed in " + duration + " (ns) or " + seconds + " (s).");
    } catch (Exception e) {
        System.out.println("Could not test Ensemble classifier due to an error.");
    }
}

From source file:cs.man.ac.uk.predict.Predictor.java

License:Open Source License

public static void makePredictionsJ48(String trainPath, String testPath, String resultPath) {
    /**//  w w  w  .  j a  v a2  s .co  m
     * The decision tree classifier.
     */
    J48 learner = new J48();

    System.out.println("Training set: " + trainPath);
    System.out.println("Test set: " + testPath);

    System.out.println("Training J48");
    long startTime = System.nanoTime();
    try {
        BufferedReader reader = new BufferedReader(new FileReader(trainPath));
        Instances data = new Instances(reader);
        data.setClassIndex(data.numAttributes() - 1);
        System.out.println("Training data length: " + data.numInstances());
        learner.buildClassifier(data);

        long endTime = System.nanoTime();
        long nanoseconds = endTime - startTime;
        double seconds = (double) nanoseconds / 1000000000.0;
        System.out.println("Training J48 completed in " + nanoseconds + " (ns) or " + seconds + " (s)");
    } catch (IOException e) {
        System.out.println("Could not train J48 classifier IOException on training data file");
    } catch (Exception e) {
        System.out.println("Could not train J48 classifier Exception building model");
    }

    try {
        // Prepare data for testing
        //BufferedReader reader = new BufferedReader( new FileReader(testPath));
        //Instances data = new Instances(reader);
        //data.setClassIndex(data.numAttributes() - 1);

        ArffLoader loader = new ArffLoader();
        loader.setFile(new File(testPath));
        Instances data = loader.getStructure();
        data.setClassIndex(data.numAttributes() - 1);

        System.out.println("J48 Classifier is ready.");
        System.out.println("Testing on all instances avaialable.");
        System.out.println("Test set instances: " + data.numInstances());

        startTime = System.nanoTime();

        int instanceNumber = 0;

        // label instances
        Instance current;

        //for (int i = 0; i < data.numInstances(); i++) 
        while ((current = loader.getNextInstance(data)) != null) {
            instanceNumber += 1;

            //double classification = learner.classifyInstance(data.instance(i));
            double classification = learner.classifyInstance(current);
            //String instanceClass= Double.toString(data.instance(i).classValue());

            if (classification == 1)// Predicted positive, actually negative
            {
                Writer.append(resultPath, instanceNumber + "\n");
            }
        }

        long endTime = System.nanoTime();
        long duration = endTime - startTime;
        double seconds = (double) duration / 1000000000.0;

        System.out.println("Testing J48 completed in " + duration + " (ns) or " + seconds + " (s)");
    } catch (Exception e) {
        System.out.println("Could not test J48 classifier due to an error");
    }
}

From source file:cz.vse.fis.keg.entityclassifier.core.salience.EntitySaliencer.java

License:Open Source License

/**
 * Classifies the salience of the entities of one document and stores the
 * result on the types of the given entities.
 *
 * Mentions that share at least one entity URI are first merged into a single
 * record with an occurrence count. One instance per merged record is then
 * classified, and every Type whose URI belongs to the record receives a
 * Salience holding the predicted class label, the confidence of that class
 * and a score of 0.5 * P(less_salient) + P(most_salient), both rounded to
 * three decimals.
 *
 * Any exception raised on the way is logged at SEVERE level and not
 * propagated.
 *
 * @param entities the entity mentions of the document; their types are
 *                 updated in place
 */
public void computeSalience(List<Entity> entities) {
    try {
        if (!initialized) {
            initialize();
            initialized = true;
        }

        ArrayList<SEntity> processedEntities = new ArrayList<SEntity>();

        for (Entity e : entities) {
            SEntity entityMention = new SEntity();
            entityMention.setBeginIndex(e.getStartOffset().intValue());
            entityMention.setEntityType(e.getEntityType());

            ArrayList<Type> types = e.getTypes();
            ArrayList<String> loggedURIs = new ArrayList<String>();

            if (types != null) {
                for (Type t : types) {
                    String entityURI = t.getEntityURI();

                    if (!loggedURIs.contains(entityURI)) {
                        loggedURIs.add(entityURI);
                        entityMention.getUrls().add(entityURI);
                    }
                }
            }

            boolean entityAlreadyLogged = false;

            // Look for the first already processed entity that shares a URI
            // with this mention; only that one absorbs the mention.
            for (SEntity sEntity : processedEntities) {
                boolean sharesUri = false;

                for (String knownUri : sEntity.getUrls()) {
                    for (String mentionUri : entityMention.getUrls()) {
                        if (knownUri.equals(mentionUri)) {
                            sharesUri = true;
                            break;
                        }
                    }
                    if (sharesUri) {
                        break;
                    }
                }

                if (sharesUri) {
                    entityAlreadyLogged = true;
                    sEntity.setNumOccurrences(sEntity.getNumOccurrences() + 1);

                    // Remember every URI under which the entity was mentioned.
                    for (String uri : entityMention.getUrls()) {
                        if (!sEntity.getUrls().contains(uri)) {
                            sEntity.getUrls().add(uri);
                        }
                    }
                    break;
                }
            }

            // Entity seen for first time in the document.
            if (!entityAlreadyLogged) {
                entityMention.setNumOccurrences(1);
                processedEntities.add(entityMention);
            }
        }

        // Preparing the test data container.
        FastVector attributes = new FastVector(6);
        attributes.add(new Attribute("beginIndex"));
        attributes.add(new Attribute("numUniqueEntitiesInDoc"));
        attributes.add(new Attribute("numOfOccurrencesOfEntityInDoc"));
        attributes.add(new Attribute("numOfEntityMentionsInDoc"));

        FastVector entityTypeNominalAttVal = new FastVector(2);
        entityTypeNominalAttVal.addElement("named_entity");
        entityTypeNominalAttVal.addElement("common_entity");

        Attribute entityTypeAtt = new Attribute("type", entityTypeNominalAttVal);
        attributes.add(entityTypeAtt);
        FastVector classNominalAttVal = new FastVector(3);
        classNominalAttVal.addElement("not_salient");
        classNominalAttVal.addElement("less_salient");
        classNominalAttVal.addElement("most_salient");
        Attribute classAtt = new Attribute("class", classNominalAttVal);
        attributes.add(classAtt);
        Instances evalData = new Instances("MyRelation", attributes, 0);

        evalData.setClassIndex(evalData.numAttributes() - 1);

        for (int i = 0; i < processedEntities.size(); i++) {
            SEntity processed = processedEntities.get(i);

            // Map the entity type onto the nominal values of the "type" attribute.
            // NOTE(review): any other type leaves the value as "", which is not a
            // declared nominal value of that attribute - confirm this cannot occur.
            String entityType = "";
            if (processed.getEntityType().equals("named entity")) {
                entityType = "named_entity";
            } else if (processed.getEntityType().equals("common entity")) {
                entityType = "common_entity";
            }

            Instance inst = new DenseInstance(6);
            inst.setValue(evalData.attribute(0), processed.getBeginIndex()); // begin index
            inst.setValue(evalData.attribute(1), processedEntities.size()); // num of unique entities in doc
            inst.setValue(evalData.attribute(2), processed.getNumOccurrences()); // num of entity occurrences in doc
            inst.setValue(evalData.attribute(3), entities.size()); // num of entity mentions in doc
            inst.setValue(evalData.attribute(4), entityType); // type of the entity
            evalData.add(inst);
        }

        // Created once: the format does not depend on the entity being scored.
        DecimalFormat df = new DecimalFormat("0.000");

        for (int i = 0; i < processedEntities.size(); i++) {
            SEntity sEntity = processedEntities.get(i);
            int classIndex = (int) classifier.classifyInstance(evalData.get(i));
            String classLabel = evalData.firstInstance().classAttribute().value(classIndex);
            double pred[] = classifier.distributionForInstance(evalData.get(i));
            double probability = pred[classIndex];

            double salienceScore = pred[1] * 0.5 + pred[2];
            sEntity.setSalienceScore(salienceScore);
            sEntity.setSalienceConfidence(probability);
            sEntity.setSalienceClass(classLabel);

            // Round to three decimals once per entity instead of once per
            // matching type; the values are identical for all of them.
            double fProbability = df.parse(df.format(probability)).doubleValue();
            double fSalience = df.parse(df.format(salienceScore)).doubleValue();

            for (Entity e : entities) {
                ArrayList<Type> types = e.getTypes();
                if (types != null) {
                    for (Type t : types) {
                        if (sEntity.getUrls().contains(t.getEntityURI())) {
                            Salience s = new Salience();
                            s.setClassLabel(classLabel);
                            s.setConfidence(fProbability);
                            s.setScore(fSalience);
                            t.setSalience(s);
                        }
                    }
                }
            }
        }

    } catch (Exception ex) {
        Logger.getLogger(EntitySaliencer.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:cz.vse.fis.keg.entityclassifier.core.salience.EntitySaliencer.java

License:Open Source License

/**
 * Loads the salience training data set referenced by
 * Settings.SALIENCE_DATASET, uses its last attribute as the class and trains
 * a RandomForest on it, which is stored in the classifier field.
 *
 * Failures are logged at SEVERE level and not propagated. The success message
 * is printed only when the model was actually built.
 */
private void trainModel() {

    BufferedReader reader = null;

    try {

        URL fileURL = THDController.getInstance().getClass().getResource(Settings.SALIENCE_DATASET);
        if (fileURL == null) {
            // Fail with a meaningful message instead of a NullPointerException.
            throw new FileNotFoundException("Salience dataset resource not found: " + Settings.SALIENCE_DATASET);
        }
        File arrfFile = new File(fileURL.getFile());

        reader = new BufferedReader(new FileReader(arrfFile));
        Instances data = new Instances(reader);
        data.setClassIndex(data.numAttributes() - 1);

        classifier = new RandomForest();

        // Train the classifier.
        classifier.buildClassifier(data);

        System.out.println("Model was successfully trained.");

    } catch (Exception ex) {
        Logger.getLogger(EntitySaliencer.class.getName()).log(Level.SEVERE, null, ex);
    } finally {
        // The reader is still null when the data set could not be opened.
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException ex) {
                Logger.getLogger(EntitySaliencer.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
}

From source file:data.generation.target.utils.PrincipalComponents.java

License:Open Source License

/**
 * Builds a textual report of the analysis: the covariance or correlation
 * matrix, one line per retained component with its eigenvalue, proportion and
 * cumulative proportion, and the eigenvector table.
 *
 * @return a summary of the analysis.
 */
private String principalComponentsSummary() {
    StringBuilder summary = new StringBuilder();
    Instances transformedHeader = null;
    int vectorCount = 0;

    try {
        transformedHeader = setOutputFormat();
        if (transformedHeader.classIndex() < 0) {
            vectorCount = transformedHeader.numAttributes();
        } else {
            vectorCount = transformedHeader.numAttributes() - 1;
        }
    } catch (Exception ignored) {
        // Left unhandled as before: vectorCount stays 0, so the component
        // loops below are skipped and the header is never dereferenced.
    }

    summary.append(m_center ? "Covariance " : "Correlation ");
    summary.append("matrix\n").append(matrixToString(m_correlation)).append("\n\n");

    summary.append("eigenvalue\tproportion\tcumulative\n");
    double runningTotal = 0.0;
    // rank 0 is the component with the largest position in the sorted index.
    for (int rank = 0; rank < vectorCount; rank++) {
        double eigenvalue = m_eigenvalues[m_sortedEigens[m_numAttribs - 1 - rank]];
        runningTotal += eigenvalue;
        summary.append(Utils.doubleToString(eigenvalue, 9, 5)).append("\t");
        summary.append(Utils.doubleToString((eigenvalue / m_sumOfEigenValues), 9, 5)).append("\t");
        summary.append(Utils.doubleToString((runningTotal / m_sumOfEigenValues), 9, 5)).append("\t");
        summary.append(transformedHeader.attribute(rank).name()).append("\n");
    }

    summary.append("\nEigenvectors\n");
    for (int column = 1; column <= vectorCount; column++) {
        summary.append(" V").append(column).append('\t');
    }
    summary.append("\n");

    // One row per original attribute, one column per retained component.
    for (int row = 0; row < m_numAttribs; row++) {
        for (int rank = 0; rank < vectorCount; rank++) {
            int sortedPosition = m_numAttribs - 1 - rank;
            summary.append(Utils.doubleToString(m_eigenvectors[row][m_sortedEigens[sortedPosition]], 7, 4));
            summary.append("\t");
        }
        summary.append(m_trainInstances.attribute(row).name()).append('\n');
    }

    if (m_transBackToOriginal) {
        summary.append("\nPC space transformed back to original space.\n");
        summary.append("(Note: can't evaluate attributes in the original space)\n");
    }
    return summary.toString();
}

From source file:data.generation.target.utils.PrincipalComponents.java

License:Open Source License

/**
 * Set up the header for the PC->original space dataset
 * /*from w w  w.  j a va 2 s.co  m*/
 * @return            the output format
 * @throws Exception  if something goes wrong
 */
private Instances setOutputFormatOriginal() throws Exception {
    FastVector attributes = new FastVector();

    for (int i = 0; i < m_numAttribs; i++) {
        String att = m_trainInstances.attribute(i).name();
        attributes.addElement(new Attribute(att));
    }

    if (m_hasClass) {
        attributes.addElement(m_trainHeader.classAttribute().copy());
    }

    Instances outputFormat = new Instances(m_trainHeader.relationName() + "->PC->original space", attributes,
            0);

    // set the class to be the last attribute if necessary
    if (m_hasClass) {
        outputFormat.setClassIndex(outputFormat.numAttributes() - 1);
    }

    return outputFormat;
}