List of usage examples for weka.core.Instances.randomize
public void randomize(Random random)
From source file: com.reactivetechnologies.analytics.core.eval.StackingWithBuiltClassifiers.java
License: Open Source License
/** * Buildclassifier selects a classifier from the set of classifiers * by minimising error on the training data. * * @param data the training data to be used for generating the * boosted classifier./*from w w w. jav a2 s . c om*/ * @throws Exception if the classifier could not be built successfully */ @Override public void buildClassifier(Instances data) throws Exception { if (m_MetaClassifier == null) { throw new IllegalArgumentException("No meta classifier has been set"); } // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class Instances newData = new Instances(data); m_BaseFormat = new Instances(data, 0); newData.deleteWithMissingClass(); Random random = new Random(m_Seed); newData.randomize(random); if (newData.classAttribute().isNominal()) { newData.stratify(m_NumFolds); } // Create meta level generateMetaLevel(newData, random); /** Changed here */ // DO NOT Rebuilt all the base classifiers on the full training data /*for (int i = 0; i < m_Classifiers.length; i++) { getClassifier(i).buildClassifier(newData); }*/ /** End change */ }
From source file: core.classifier.MyFirstClassifier.java
License: Open Source License
/** * Method for building the classifier. Implements a one-against-one * wrapper for multi-class problems./*from w ww . j a v a2 s.co m*/ * * @param insts the set of training instances * @throws Exception if the classifier can't be built successfully */ public void buildClassifier(Instances insts) throws Exception { if (!m_checksTurnedOff) { // can classifier handle the data? getCapabilities().testWithFail(insts); // remove instances with missing class insts = new Instances(insts); insts.deleteWithMissingClass(); /* Removes all the instances with weight equal to 0. MUST be done since condition (8) of Keerthi's paper is made with the assertion Ci > 0 (See equation (3a). */ Instances data = new Instances(insts, insts.numInstances()); for (int i = 0; i < insts.numInstances(); i++) { if (insts.instance(i).weight() > 0) data.add(insts.instance(i)); } if (data.numInstances() == 0) { throw new Exception("No training instances left after removing " + "instances with weight 0!"); } insts = data; } if (!m_checksTurnedOff) { m_Missing = new ReplaceMissingValues(); m_Missing.setInputFormat(insts); insts = Filter.useFilter(insts, m_Missing); } else { m_Missing = null; } if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) { boolean onlyNumeric = true; if (!m_checksTurnedOff) { for (int i = 0; i < insts.numAttributes(); i++) { if (i != insts.classIndex()) { if (!insts.attribute(i).isNumeric()) { onlyNumeric = false; break; } } } } if (!onlyNumeric) { m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(insts); insts = Filter.useFilter(insts, m_NominalToBinary); } else { m_NominalToBinary = null; } } else { m_NominalToBinary = null; } if (m_filterType == FILTER_STANDARDIZE) { m_Filter = new Standardize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else if (m_filterType == FILTER_NORMALIZE) { m_Filter = new Normalize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else { m_Filter = 
null; } m_classIndex = insts.classIndex(); m_classAttribute = insts.classAttribute(); m_KernelIsLinear = (m_kernel instanceof PolyKernel) && (((PolyKernel) m_kernel).getExponent() == 1.0); // Generate subsets representing each class Instances[] subsets = new Instances[insts.numClasses()]; for (int i = 0; i < insts.numClasses(); i++) { subsets[i] = new Instances(insts, insts.numInstances()); } for (int j = 0; j < insts.numInstances(); j++) { Instance inst = insts.instance(j); subsets[(int) inst.classValue()].add(inst); } for (int i = 0; i < insts.numClasses(); i++) { subsets[i].compactify(); } // Build the binary classifiers Random rand = new Random(m_randomSeed); m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()]; for (int i = 0; i < insts.numClasses(); i++) { for (int j = i + 1; j < insts.numClasses(); j++) { m_classifiers[i][j] = new BinarySMO(); m_classifiers[i][j].setKernel(Kernel.makeCopy(getKernel())); Instances data = new Instances(insts, insts.numInstances()); for (int k = 0; k < subsets[i].numInstances(); k++) { data.add(subsets[i].instance(k)); } for (int k = 0; k < subsets[j].numInstances(); k++) { data.add(subsets[j].instance(k)); } data.compactify(); data.randomize(rand); m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels, m_numFolds, m_randomSeed); } } }
From source file: core.ClusterEvaluationEX.java
License: Open Source License
/** * Evaluates a clusterer with the options given in an array of * strings. It takes the string indicated by "-t" as training file, the * string indicated by "-T" as test file. * If the test file is missing, a stratified ten-fold * cross-validation is performed (distribution clusterers only). * Using "-x" you can change the number of * folds to be used, and using "-s" the random seed. * If the "-p" option is present it outputs the classification for * each test instance. If you provide the name of an object file using * "-l", a clusterer will be loaded from the given file. If you provide the * name of an object file using "-d", the clusterer built from the * training data will be saved to the given file. * * @param clusterer machine learning clusterer * @param options the array of string containing the options * @throws Exception if model could not be evaluated successfully * @return a string describing the results *//*w ww.j a v a2 s . c o m*/ public static String evaluateClusterer(Clusterer clusterer, String[] options) throws Exception { int seed = 1, folds = 10; boolean doXval = false; Instances train = null; Random random; String trainFileName, testFileName, seedString, foldsString; String objectInputFileName, objectOutputFileName, attributeRangeString; String graphFileName; String[] savedOptions = null; boolean printClusterAssignments = false; Range attributesToOutput = null; StringBuffer text = new StringBuffer(); int theClass = -1; // class based evaluation of clustering boolean updateable = (clusterer instanceof UpdateableClusterer); DataSource source = null; Instance inst; if (Utils.getFlag('h', options) || Utils.getFlag("help", options)) { // global info requested as well? boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options); throw new Exception("Help requested." 
+ makeOptionString(clusterer, globalInfo)); } try { // Get basic options (options the same for all clusterers //printClusterAssignments = Utils.getFlag('p', options); objectInputFileName = Utils.getOption('l', options); objectOutputFileName = Utils.getOption('d', options); trainFileName = Utils.getOption('t', options); testFileName = Utils.getOption('T', options); graphFileName = Utils.getOption('g', options); // Check -p option try { attributeRangeString = Utils.getOption('p', options); } catch (Exception e) { throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " + "It now expects a parameter specifying a range of attributes " + "to list with the predictions. Use '-p 0' for none."); } if (attributeRangeString.length() != 0) { printClusterAssignments = true; if (!attributeRangeString.equals("0")) attributesToOutput = new Range(attributeRangeString); } if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object " + "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test file given."); } } else { if ((objectInputFileName.length() != 0) && (printClusterAssignments == false)) { throw new Exception("Can't use both train and model file " + "unless -p specified."); } } seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); doXval = true; } } catch (Exception e) { throw new Exception('\n' + e.getMessage() + makeOptionString(clusterer, false)); } try { if (trainFileName.length() != 0) { source = new DataSource(trainFileName); train = source.getStructure(); String classString = Utils.getOption('c', options); if (classString.length() != 0) { if (classString.compareTo("last") == 0) theClass = train.numAttributes(); else if (classString.compareTo("first") == 0) 
theClass = 1; else theClass = Integer.parseInt(classString); if (theClass != -1) { if (doXval || testFileName.length() != 0) throw new Exception("Can only do class based evaluation on the " + "training data"); if (objectInputFileName.length() != 0) throw new Exception("Can't load a clusterer and do class based " + "evaluation"); if (objectOutputFileName.length() != 0) throw new Exception("Can't do class based evaluation and save clusterer"); } } else { // if the dataset defines a class attribute, use it if (train.classIndex() != -1) { theClass = train.classIndex() + 1; System.err .println("Note: using class attribute from dataset, i.e., attribute #" + theClass); } } if (theClass != -1) { if (theClass < 1 || theClass > train.numAttributes()) throw new Exception("Class is out of range!"); if (!train.attribute(theClass - 1).isNominal()) throw new Exception("Class must be nominal!"); train.setClassIndex(theClass - 1); } } } catch (Exception e) { throw new Exception("ClusterEvaluation: " + e.getMessage() + '.'); } // Save options if (options != null) { savedOptions = new String[options.length]; System.arraycopy(options, 0, savedOptions, 0, options.length); } if (objectInputFileName.length() != 0) Utils.checkForRemainingOptions(options); // Set options for clusterer if (clusterer instanceof OptionHandler) ((OptionHandler) clusterer).setOptions(options); Utils.checkForRemainingOptions(options); Instances trainHeader = train; if (objectInputFileName.length() != 0) { // Load the clusterer from file // clusterer = (Clusterer) SerializationHelper.read(objectInputFileName); java.io.ObjectInputStream ois = new java.io.ObjectInputStream( new java.io.BufferedInputStream(new java.io.FileInputStream(objectInputFileName))); clusterer = (Clusterer) ois.readObject(); // try and get the training header try { trainHeader = (Instances) ois.readObject(); } catch (Exception ex) { // don't moan if we cant } } else { // Build the clusterer if no object file provided if (theClass == -1) { if 
(updateable) { clusterer.buildClusterer(source.getStructure()); while (source.hasMoreElements(train)) { inst = source.nextElement(train); ((UpdateableClusterer) clusterer).updateClusterer(inst); } ((UpdateableClusterer) clusterer).updateFinished(); } else { clusterer.buildClusterer(source.getDataSet()); } } else { Remove removeClass = new Remove(); removeClass.setAttributeIndices("" + theClass); removeClass.setInvertSelection(false); removeClass.setInputFormat(train); if (updateable) { Instances clusterTrain = Filter.useFilter(train, removeClass); clusterer.buildClusterer(clusterTrain); trainHeader = clusterTrain; while (source.hasMoreElements(train)) { inst = source.nextElement(train); removeClass.input(inst); removeClass.batchFinished(); Instance clusterTrainInst = removeClass.output(); ((UpdateableClusterer) clusterer).updateClusterer(clusterTrainInst); } ((UpdateableClusterer) clusterer).updateFinished(); } else { Instances clusterTrain = Filter.useFilter(source.getDataSet(), removeClass); clusterer.buildClusterer(clusterTrain); trainHeader = clusterTrain; } ClusterEvaluationEX ce = new ClusterEvaluationEX(); ce.setClusterer(clusterer); ce.evaluateClusterer(train, trainFileName); return "\n\n=== Clustering stats for training data ===\n\n" + ce.clusterResultsToString(); } } /* Output cluster predictions only (for the test data if specified, otherwise for the training data */ if (printClusterAssignments) { return printClusterings(clusterer, trainFileName, testFileName, attributesToOutput); } text.append(clusterer.toString()); text.append( "\n\n=== Clustering stats for training data ===\n\n" + printClusterStats(clusterer, trainFileName)); if (testFileName.length() != 0) { // check header compatibility DataSource test = new DataSource(testFileName); Instances testStructure = test.getStructure(); if (!trainHeader.equalHeaders(testStructure)) { throw new Exception("Training and testing data are not compatible\n"); } text.append("\n\n=== Clustering stats for testing 
data ===\n\n" + printClusterStats(clusterer, testFileName)); } if ((clusterer instanceof DensityBasedClusterer) && (doXval == true) && (testFileName.length() == 0) && (objectInputFileName.length() == 0)) { // cross validate the log likelihood on the training data random = new Random(seed); random.setSeed(seed); train = source.getDataSet(); train.randomize(random); text.append(crossValidateModel(clusterer.getClass().getName(), train, folds, savedOptions, random)); } // Save the clusterer if an object output file is provided if (objectOutputFileName.length() != 0) { //SerializationHelper.write(objectOutputFileName, clusterer); saveClusterer(objectOutputFileName, clusterer, trainHeader); } // If classifier is drawable output string describing graph if ((clusterer instanceof Drawable) && (graphFileName.length() != 0)) { BufferedWriter writer = new BufferedWriter(new FileWriter(graphFileName)); writer.write(((Drawable) clusterer).graph()); writer.newLine(); writer.flush(); writer.close(); } return text.toString(); }
From source file: core.ClusterEvaluationEX.java
License: Open Source License
/** * Perform a cross-validation for DensityBasedClusterer on a set of instances. * * @param clusterer the clusterer to use * @param data the training data/*from ww w . j a v a2 s . c om*/ * @param numFolds number of folds of cross validation to perform * @param random random number seed for cross-validation * @return the cross-validated log-likelihood * @throws Exception if an error occurs */ public static double crossValidateModel(DensityBasedClusterer clusterer, Instances data, int numFolds, Random random) throws Exception { Instances train, test; double foldAv = 0; ; data = new Instances(data); data.randomize(random); // double sumOW = 0; for (int i = 0; i < numFolds; i++) { // Build and test clusterer train = data.trainCV(numFolds, i, random); clusterer.buildClusterer(train); test = data.testCV(numFolds, i); for (int j = 0; j < test.numInstances(); j++) { try { foldAv += ((DensityBasedClusterer) clusterer).logDensityForInstance(test.instance(j)); // sumOW += test.instance(j).weight(); // double temp = Utils.sum(tempDist); } catch (Exception ex) { // unclustered instances } } } // return foldAv / sumOW; return foldAv / data.numInstances(); }
From source file: cotraining.copy.Evaluation_D.java
License: Open Source License
/** * Performs a (stratified if class is nominal) cross-validation * for a classifier on a set of instances. Now performs * a deep copy of the classifier before each call to * buildClassifier() (just in case the classifier is not * initialized properly)./*from w w w . ja v a 2 s .c o m*/ * * @param classifier the classifier with any options set. * @param data the data on which the cross-validation is to be * performed * @param numFolds the number of folds for the cross-validation * @param random random number generator for randomization * @param forPredictionsString varargs parameter that, if supplied, is * expected to hold a StringBuffer to print predictions to, * a Range of attributes to output and a Boolean (true if the distribution * is to be printed) * @throws Exception if a classifier could not be generated * successfully or the class is not defined */ public void crossValidateModel(Classifier classifier, Instances data, int numFolds, Random random, Object... forPredictionsPrinting) throws Exception { // Make a copy of the data we can reorder data = new Instances(data); data.randomize(random); if (data.classAttribute().isNominal()) { data.stratify(numFolds); } // We assume that the first element is a StringBuffer, the second a Range (attributes // to output) and the third a Boolean (whether or not to output a distribution instead // of just a classification) if (forPredictionsPrinting.length > 0) { // print the header first StringBuffer buff = (StringBuffer) forPredictionsPrinting[0]; Range attsToOutput = (Range) forPredictionsPrinting[1]; boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue(); printClassificationsHeader(data, attsToOutput, printDist, buff); } // Do the folds for (int i = 0; i < numFolds; i++) { Instances train = data.trainCV(numFolds, i, random); setPriors(train); Classifier copiedClassifier = Classifier.makeCopy(classifier); copiedClassifier.buildClassifier(train); Instances test = data.testCV(numFolds, i); 
evaluateModel(copiedClassifier, test, forPredictionsPrinting); } m_NumFolds = numFolds; }
From source file: cotraining.copy.Evaluation_D.java
License: Open Source License
/** * Evaluates a classifier with the options given in an array of * strings. <p/>//from w ww .j a v a 2 s .c o m * * Valid options are: <p/> * * -t name of training file <br/> * Name of the file with the training data. (required) <p/> * * -T name of test file <br/> * Name of the file with the test data. If missing a cross-validation * is performed. <p/> * * -c class index <br/> * Index of the class attribute (1, 2, ...; default: last). <p/> * * -x number of folds <br/> * The number of folds for the cross-validation (default: 10). <p/> * * -no-cv <br/> * No cross validation. If no test file is provided, no evaluation * is done. <p/> * * -split-percentage percentage <br/> * Sets the percentage for the train/test set split, e.g., 66. <p/> * * -preserve-order <br/> * Preserves the order in the percentage split instead of randomizing * the data first with the seed value ('-s'). <p/> * * -s seed <br/> * Random number seed for the cross-validation and percentage split * (default: 1). <p/> * * -m file with cost matrix <br/> * The name of a file containing a cost matrix. <p/> * * -l filename <br/> * Loads classifier from the given file. In case the filename ends with * ".xml",a PMML file is loaded or, if that fails, options are loaded from XML. <p/> * * -d filename <br/> * Saves classifier built from the training data into the given file. In case * the filename ends with ".xml" the options are saved XML, not the model. <p/> * * -v <br/> * Outputs no statistics for the training data. <p/> * * -o <br/> * Outputs statistics only, not the classifier. <p/> * * -i <br/> * Outputs detailed information-retrieval statistics per class. <p/> * * -k <br/> * Outputs information-theoretic statistics. <p/> * * -p range <br/> * Outputs predictions for test instances (or the train instances if no test * instances provided and -no-cv is used), along with the attributes in the specified range * (and nothing else). Use '-p 0' if no attributes are desired. 
<p/> * * -distribution <br/> * Outputs the distribution instead of only the prediction * in conjunction with the '-p' option (only nominal classes). <p/> * * -r <br/> * Outputs cumulative margin distribution (and nothing else). <p/> * * -g <br/> * Only for classifiers that implement "Graphable." Outputs * the graph representation of the classifier (and nothing * else). <p/> * * -xml filename | xml-string <br/> * Retrieves the options from the XML-data instead of the command line. <p/> * * @param classifier machine learning classifier * @param options the array of string containing the options * @throws Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateModel(Classifier classifier, String[] options) throws Exception { Instances train = null, tempTrain, test = null, template = null; int seed = 1, folds = 10, classIndex = -1; boolean noCrossValidation = false; String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString, objectInputFileName, objectOutputFileName, attributeRangeString; boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false; StringBuffer text = new StringBuffer(); DataSource trainSource = null, testSource = null; ObjectInputStream objectInputStream = null; BufferedInputStream xmlInputStream = null; CostMatrix costMatrix = null; StringBuffer schemeOptionsText = null; Range attributesToOutput = null; long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0; String xml = ""; String[] optionsTmp = null; Classifier classifierBackup; Classifier classifierClassifications = null; boolean printDistribution = false; int actualClassIndex = -1; // 0-based class index String splitPercentageString = ""; int splitPercentage = -1; boolean preserveOrder = false; boolean trainSetPresent = 
false; boolean testSetPresent = false; String thresholdFile; String thresholdLabel; StringBuffer predsBuff = null; // predictions from cross-validation // help requested? if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) { // global info requested as well? boolean globalInfo = Utils.getFlag("synopsis", options) || Utils.getFlag("info", options); throw new Exception("\nHelp requested." + makeOptionString(classifier, globalInfo)); } try { // do we get the input from XML instead of normal parameters? xml = Utils.getOption("xml", options); if (!xml.equals("")) options = new XMLOptions(xml).toArray(); // is the input model only the XML-Options, i.e. w/o built model? optionsTmp = new String[options.length]; for (int i = 0; i < options.length; i++) optionsTmp[i] = options[i]; String tmpO = Utils.getOption('l', optionsTmp); //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) { if (tmpO.endsWith(".xml")) { // try to load file as PMML first boolean success = false; try { //PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO); //if (pmmlModel instanceof PMMLClassifier) { //classifier = ((PMMLClassifier)pmmlModel); // success = true; //} } catch (IllegalArgumentException ex) { success = false; } if (!success) { // load options from serialized data ('-l' is automatically erased!) 
XMLClassifier xmlserial = new XMLClassifier(); Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options)); // merge options optionsTmp = new String[options.length + cl.getOptions().length]; System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length); System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length); options = optionsTmp; } } noCrossValidation = Utils.getFlag("no-cv", options); // Get basic options (options the same for all schemes) classIndexString = Utils.getOption('c', options); if (classIndexString.length() != 0) { if (classIndexString.equals("first")) classIndex = 1; else if (classIndexString.equals("last")) classIndex = -1; else classIndex = Integer.parseInt(classIndexString); } trainFileName = Utils.getOption('t', options); objectInputFileName = Utils.getOption('l', options); objectOutputFileName = Utils.getOption('d', options); testFileName = Utils.getOption('T', options); foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); } seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object " + "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test " + "file given."); } } else if ((objectInputFileName.length() != 0) && ((!(classifier instanceof UpdateableClassifier)) || (testFileName.length() == 0))) { throw new Exception("Classifier not incremental, or no " + "test file provided: can't " + "use both train and model file."); } try { if (trainFileName.length() != 0) { trainSetPresent = true; trainSource = new DataSource(trainFileName); } if (testFileName.length() != 0) { testSetPresent = true; testSource = new DataSource(testFileName); } if (objectInputFileName.length() != 0) { if 
(objectInputFileName.endsWith(".xml")) { // if this is the case then it means that a PMML classifier was // successfully loaded earlier in the code objectInputStream = null; xmlInputStream = null; } else { InputStream is = new FileInputStream(objectInputFileName); if (objectInputFileName.endsWith(".gz")) { is = new GZIPInputStream(is); } // load from KOML? if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) { objectInputStream = new ObjectInputStream(is); xmlInputStream = null; } else { objectInputStream = null; xmlInputStream = new BufferedInputStream(is); } } } } catch (Exception e) { throw new Exception("Can't open file " + e.getMessage() + '.'); } if (testSetPresent) { template = test = testSource.getStructure(); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { if ((test.classIndex() == -1) || (classIndexString.length() != 0)) test.setClassIndex(test.numAttributes() - 1); } actualClassIndex = test.classIndex(); } else { // percentage split splitPercentageString = Utils.getOption("split-percentage", options); if (splitPercentageString.length() != 0) { if (foldsString.length() != 0) throw new Exception("Percentage split cannot be used in conjunction with " + "cross-validation ('-x')."); splitPercentage = Integer.parseInt(splitPercentageString); if ((splitPercentage <= 0) || (splitPercentage >= 100)) throw new Exception("Percentage split value needs be >0 and <100."); } else { splitPercentage = -1; } preserveOrder = Utils.getFlag("preserve-order", options); if (preserveOrder) { if (splitPercentage == -1) throw new Exception("Percentage split ('-percentage-split') is missing."); } // create new train/test sources if (splitPercentage > 0) { testSetPresent = true; Instances tmpInst = trainSource.getDataSet(actualClassIndex); if (!preserveOrder) tmpInst.randomize(new Random(seed)); int trainSize = tmpInst.numInstances() * splitPercentage / 100; int testSize = tmpInst.numInstances() - trainSize; Instances trainInst = new 
Instances(tmpInst, 0, trainSize); Instances testInst = new Instances(tmpInst, trainSize, testSize); trainSource = new DataSource(trainInst); testSource = new DataSource(testInst); template = test = testSource.getStructure(); if (classIndex != -1) { test.setClassIndex(classIndex - 1); } else { if ((test.classIndex() == -1) || (classIndexString.length() != 0)) test.setClassIndex(test.numAttributes() - 1); } actualClassIndex = test.classIndex(); } } if (trainSetPresent) { template = train = trainSource.getStructure(); if (classIndex != -1) { train.setClassIndex(classIndex - 1); } else { if ((train.classIndex() == -1) || (classIndexString.length() != 0)) train.setClassIndex(train.numAttributes() - 1); } actualClassIndex = train.classIndex(); if ((testSetPresent) && !test.equalHeaders(train)) { throw new IllegalArgumentException("Train and test file not compatible!"); } } if (template == null) { throw new Exception("No actual dataset provided to use as template"); } costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses()); classStatistics = Utils.getFlag('i', options); noOutput = Utils.getFlag('o', options); trainStatistics = !Utils.getFlag('v', options); printComplexityStatistics = Utils.getFlag('k', options); printMargins = Utils.getFlag('r', options); printGraph = Utils.getFlag('g', options); sourceClass = Utils.getOption('z', options); printSource = (sourceClass.length() != 0); printDistribution = Utils.getFlag("distribution", options); thresholdFile = Utils.getOption("threshold-file", options); thresholdLabel = Utils.getOption("threshold-label", options); // Check -p option try { attributeRangeString = Utils.getOption('p', options); } catch (Exception e) { throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " + "It now expects a parameter specifying a range of attributes " + "to list with the predictions. 
Use '-p 0' for none."); } if (attributeRangeString.length() != 0) { printClassifications = true; noOutput = true; if (!attributeRangeString.equals("0")) attributesToOutput = new Range(attributeRangeString); } if (!printClassifications && printDistribution) throw new Exception("Cannot print distribution without '-p' option!"); // if no training file given, we don't have any priors if ((!trainSetPresent) && (printComplexityStatistics)) throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!"); // If a model file is given, we can't process // scheme-specific options if (objectInputFileName.length() != 0) { Utils.checkForRemainingOptions(options); } else { // Set options for classifier if (classifier instanceof OptionHandler) { for (int i = 0; i < options.length; i++) { if (options[i].length() != 0) { if (schemeOptionsText == null) { schemeOptionsText = new StringBuffer(); } if (options[i].indexOf(' ') != -1) { schemeOptionsText.append('"' + options[i] + "\" "); } else { schemeOptionsText.append(options[i] + " "); } } } ((OptionHandler) classifier).setOptions(options); } } Utils.checkForRemainingOptions(options); } catch (Exception e) { throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier, false)); } // Setup up evaluation objects Evaluation_D trainingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix); Evaluation_D testingEvaluation = new Evaluation_D(new Instances(template, 0), costMatrix); // disable use of priors if no training file given if (!trainSetPresent) testingEvaluation.useNoPriors(); if (objectInputFileName.length() != 0) { // Load classifier from file if (objectInputStream != null) { classifier = (Classifier) objectInputStream.readObject(); // try and read a header (if present) Instances savedStructure = null; try { savedStructure = (Instances) objectInputStream.readObject(); } catch (Exception ex) { // don't make a fuss } if (savedStructure != null) { // test for 
compatibility with template if (!template.equalHeaders(savedStructure)) { throw new Exception("training and test set are not compatible"); } } objectInputStream.close(); } else if (xmlInputStream != null) { // whether KOML is available has already been checked (objectInputStream would null otherwise)! classifier = (Classifier) KOML.read(xmlInputStream); xmlInputStream.close(); } } // backup of fully setup classifier for cross-validation classifierBackup = Classifier.makeCopy(classifier); // Build the classifier if no object file provided if ((classifier instanceof UpdateableClassifier) && (testSetPresent || noCrossValidation) && (costMatrix == null) && (trainSetPresent)) { // Build classifier incrementally trainingEvaluation.setPriors(train); testingEvaluation.setPriors(train); trainTimeStart = System.currentTimeMillis(); if (objectInputFileName.length() == 0) { classifier.buildClassifier(train); } Instance trainInst; while (trainSource.hasMoreElements(train)) { trainInst = trainSource.nextElement(train); trainingEvaluation.updatePriors(trainInst); testingEvaluation.updatePriors(trainInst); ((UpdateableClassifier) classifier).updateClassifier(trainInst); } trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } else if (objectInputFileName.length() == 0) { // Build classifier in one go tempTrain = trainSource.getDataSet(actualClassIndex); trainingEvaluation.setPriors(tempTrain); testingEvaluation.setPriors(tempTrain); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(tempTrain); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } // backup of fully trained classifier for printing the classifications if (printClassifications) classifierClassifications = Classifier.makeCopy(classifier); // Save the classifier if an object output file is provided if (objectOutputFileName.length() != 0) { OutputStream os = new FileOutputStream(objectOutputFileName); // binary if (!(objectOutputFileName.endsWith(".xml") || 
(objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) { if (objectOutputFileName.endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); if (template != null) { objectOutputStream.writeObject(template); } objectOutputStream.flush(); objectOutputStream.close(); } // KOML/XML else { BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os); if (objectOutputFileName.endsWith(".xml")) { XMLSerialization xmlSerial = new XMLClassifier(); xmlSerial.write(xmlOutputStream, classifier); } else // whether KOML is present has already been checked // if not present -> ".koml" is interpreted as binary - see above if (objectOutputFileName.endsWith(".koml")) { KOML.write(xmlOutputStream, classifier); } xmlOutputStream.close(); } } // If classifier is drawable output string describing graph if ((classifier instanceof Drawable) && (printGraph)) { return ((Drawable) classifier).graph(); } // Output the classifier as equivalent source if ((classifier instanceof Sourcable) && (printSource)) { return wekaStaticWrapper((Sourcable) classifier, sourceClass); } // Output model if (!(noOutput || printMargins)) { if (classifier instanceof OptionHandler) { if (schemeOptionsText != null) { text.append("\nOptions: " + schemeOptionsText); text.append("\n"); } } text.append("\n" + classifier.toString() + "\n"); } if (!printMargins && (costMatrix != null)) { text.append("\n=== Evaluation Cost Matrix ===\n\n"); text.append(costMatrix.toString()); } // Output test instance predictions only if (printClassifications) { DataSource source = testSource; predsBuff = new StringBuffer(); // no test set -> use train set if (source == null && noCrossValidation) { source = trainSource; predsBuff.append("\n=== Predictions on training data ===\n\n"); } else { predsBuff.append("\n=== Predictions on test data ===\n\n"); } if (source != null) { /* return 
printClassifications(classifierClassifications, new Instances(template, 0), source, actualClassIndex + 1, attributesToOutput, printDistribution); */ printClassifications(classifierClassifications, new Instances(template, 0), source, actualClassIndex + 1, attributesToOutput, printDistribution, predsBuff); // return predsText.toString(); } } // Compute error estimate from training data if ((trainStatistics) && (trainSetPresent)) { if ((classifier instanceof UpdateableClassifier) && (testSetPresent) && (costMatrix == null)) { // Classifier was trained incrementally, so we have to // reset the source. trainSource.reset(); // Incremental testing train = trainSource.getStructure(actualClassIndex); testTimeStart = System.currentTimeMillis(); Instance trainInst; while (trainSource.hasMoreElements(train)) { trainInst = trainSource.nextElement(train); trainingEvaluation.evaluateModelOnce((Classifier) classifier, trainInst); } testTimeElapsed = System.currentTimeMillis() - testTimeStart; } else { testTimeStart = System.currentTimeMillis(); trainingEvaluation.evaluateModel(classifier, trainSource.getDataSet(actualClassIndex)); testTimeElapsed = System.currentTimeMillis() - testTimeStart; } // Print the results of the training evaluation if (printMargins) { return trainingEvaluation.toCumulativeMarginDistributionString(); } else { if (!printClassifications) { text.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); if (splitPercentage > 0) text.append("\nTime taken to test model on training split: "); else text.append("\nTime taken to test model on training data: "); text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds"); if (splitPercentage > 0) text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " split ===\n", printComplexityStatistics)); else text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training" + " data ===\n", printComplexityStatistics)); if 
(template.classAttribute().isNominal()) { if (classStatistics) { text.append("\n\n" + trainingEvaluation.toClassDetailsString()); } if (!noCrossValidation) text.append("\n\n" + trainingEvaluation.toMatrixString()); } } } } // Compute proper error estimates if (testSource != null) { // Testing is on the supplied test data testSource.reset(); test = testSource.getStructure(test.classIndex()); Instance testInst; while (testSource.hasMoreElements(test)) { testInst = testSource.nextElement(test); testingEvaluation.evaluateModelOnceAndRecordPrediction((Classifier) classifier, testInst); } if (splitPercentage > 0) { if (!printClassifications) { text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test split ===\n", printComplexityStatistics)); } } else { if (!printClassifications) { text.append("\n\n" + testingEvaluation.toSummaryString("=== Error on test data ===\n", printComplexityStatistics)); } } } else if (trainSource != null) { if (!noCrossValidation) { // Testing is via cross-validation on training data Random random = new Random(seed); // use untrained (!) classifier for cross-validation classifier = Classifier.makeCopy(classifierBackup); if (!printClassifications) { testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex), folds, random); if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation.toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation.toSummaryString( "=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } } else { predsBuff = new StringBuffer(); predsBuff.append("\n=== Predictions under cross-validation ===\n\n"); testingEvaluation.crossValidateModel(classifier, trainSource.getDataSet(actualClassIndex), folds, random, predsBuff, attributesToOutput, new Boolean(printDistribution)); /* if (template.classAttribute().isNumeric()) { text.append("\n\n\n" + testingEvaluation. 
toSummaryString("=== Cross-validation ===\n", printComplexityStatistics)); } else { text.append("\n\n\n" + testingEvaluation. toSummaryString("=== Stratified " + "cross-validation ===\n", printComplexityStatistics)); } */ } } } if (template.classAttribute().isNominal()) { if (classStatistics && !noCrossValidation && !printClassifications) { text.append("\n\n" + testingEvaluation.toClassDetailsString()); } if (!noCrossValidation && !printClassifications) text.append("\n\n" + testingEvaluation.toMatrixString()); } // predictions from cross-validation? if (predsBuff != null) { text.append("\n" + predsBuff); } if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) { int labelIndex = 0; if (thresholdLabel.length() != 0) labelIndex = template.classAttribute().indexOfValue(thresholdLabel); if (labelIndex == -1) throw new IllegalArgumentException("Class label '" + thresholdLabel + "' is unknown!"); ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex); DataSink.write(thresholdFile, result); } return text.toString(); }
From source file: cs.man.ac.uk.classifiers.GetAUC.java
License: Open Source License
/** * Computes the AUC for the supplied stream learner. * @return the AUC as a double value.//from www .j a v a 2 s . c o m */ private static double validate5x2CVStream() { try { // Other options int runs = 5; int folds = 2; double AUC_SUM = 0; // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { System.out.println("Stratifying..."); randData.stratify(folds); } for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); Distribution testDistribution = new Distribution(test); ArffSaver trainSaver = new ArffSaver(); trainSaver.setInstances(train); trainSaver.setFile(new File(trainPath)); trainSaver.writeBatch(); ArffSaver testSaver = new ArffSaver(); testSaver.setInstances(test); double[][] dist = testDistribution.matrix(); int negativeClassSize = (int) dist[0][0]; int positiveClassSize = (int) dist[0][1]; double balance = (double) positiveClassSize / (double) negativeClassSize; String tempTestPath = testPath.replace(".arff", "_" + positiveClassSize + "_" + negativeClassSize + "_" + balance + "_1.0.arff");// [Test-n-Set-n]_[+]_[-]_[K]_[L]; testSaver.setFile(new File(tempTestPath)); testSaver.writeBatch(); ARFFFile file = new ARFFFile(tempTestPath, CLASS_INDEX, new DebugLogger(false)); file.createMetaData(); HoeffdingTreeTester streamClassifier = new HoeffdingTreeTester(trainPath, tempTestPath, CLASS_INDEX, new String[] { "0", "1" }, new DebugLogger(true)); streamClassifier.train(); System.in.read(); //AUC_SUM += streamClassifier.getROCExternalData("",(int)testDistribution.perClass(1),(int)testDistribution.perClass(0)); streamClassifier.testStatic(homeDirectory + "/FuckSakeTest.txt"); String[] files = Common.getFilePaths(scratch); for (int j = 0; j < files.length; j++) Common.fileDelete(files[j]); } } return AUC_SUM / 
((double) runs * (double) folds); } catch (Exception e) { System.out.println("Exception validating data!"); e.printStackTrace(); return 0; } }
From source file: cs.man.ac.uk.classifiers.GetAUC.java
License: Open Source License
/** * Computes the AUC for the supplied learner. * @return the AUC as a double value.// w ww . j ava 2 s . c om */ @SuppressWarnings("unused") private static double validate5x2CV() { try { // other options int runs = 5; int folds = 2; double AUC_SUM = 0; // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) { System.out.println("Stratifying..."); randData.stratify(folds); } Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier String[] options = { "-U", "-A" }; J48 classifier = new J48(); //HTree classifier = new HTree(); classifier.setOptions(options); classifier.buildClassifier(train); eval.evaluateModel(classifier, test); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve vmc = new ThresholdVisualizePanel(); AUC_SUM += ThresholdCurve.getROCArea(result); System.out.println("AUC: " + ThresholdCurve.getROCArea(result) + " \tAUC SUM: " + AUC_SUM); } } return AUC_SUM / ((double) runs * (double) folds); } catch (Exception e) { System.out.println("Exception validating data!"); return 0; } }
From source file: de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.java
License: Open Source License
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception { // Set parameters int folds = 10; Classifier baseClassifier = getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random); // Perform cross-validation Instances predictedData = null;//from w w w .ja v a 2s . c o m Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(filteredClassifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(filteredClassifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = 
Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) predictedData = new Instances(pred, 0); for (int j = 0; j < pred.numInstances(); j++) predictedData.add(pred.instance(j)); } // Prepare output classification String[] scores = new String[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; String value = predInst.stringValue(predInst.attribute(valueIdx)); scores[id] = value; } // Output StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile( new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv"), sb.toString()); }
From source file: de.unidue.langtech.grading.tc.LearningCurveTask.java
License: Open Source License
@Override public void execute(TaskContext aContext) throws Exception { boolean multiLabel = false; for (Integer numberInstances : NUMBER_OF_TRAINING_INSTANCES) { for (int iteration = 0; iteration < ITERATIONS; iteration++) { File arffFileTrain = new File( aContext.getStorageLocation(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY) .getPath() + "/" + TRAINING_DATA_FILENAME); File arffFileTest = new File( aContext.getStorageLocation(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY).getPath() + "/" + TRAINING_DATA_FILENAME); Instances trainData = TaskUtils.getInstances(arffFileTrain, multiLabel); Instances testData = TaskUtils.getInstances(arffFileTest, multiLabel); if (numberInstances > trainData.size()) { continue; }/*from ww w. j av a 2 s. co m*/ Classifier cl = AbstractClassifier.forName(classificationArguments.get(0), classificationArguments.subList(1, classificationArguments.size()).toArray(new String[0])); Instances copyTestData = new Instances(testData); trainData = WekaUtils.removeOutcomeId(trainData, multiLabel); testData = WekaUtils.removeOutcomeId(testData, multiLabel); Random generator = new Random(); generator.setSeed(System.nanoTime()); trainData.randomize(generator); // remove fraction of training data that should not be used for training for (int i = trainData.size() - 1; i >= numberInstances; i--) { trainData.delete(i); } // file to hold prediction results File evalOutput = new File( aContext.getStorageLocation(TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE).getPath() + "/" + EVALUATION_DATA_FILENAME + "_" + numberInstances + "_" + iteration); // train the classifier on the train set split - not necessary in multilabel setup, but // in single label setup cl.buildClassifier(trainData); weka.core.SerializationHelper.write(evalOutput.getAbsolutePath(), WekaUtils.getEvaluationSinglelabel(cl, trainData, testData)); testData = WekaUtils.getPredictionInstancesSingleLabel(testData, cl); testData = WekaUtils.addOutcomeId(testData, copyTestData, false); 
// // Write out the predictions // DataSink.write(aContext.getStorageLocation(TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE) // .getAbsolutePath() + "/" + PREDICTIONS_FILENAME + "_" + trainPercent, testData); } } }