List of usage examples for the constructor of weka.classifiers.meta.FilteredClassifier
public FilteredClassifier()
From source file: dkpro.similarity.experiments.rte.util.Evaluator.java
License: Open Source License
public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset) throws Exception { Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the train instances and get the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff", "-o", MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" }); Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff"); train.setClassIndex(train.numAttributes() - 1); // Add IDs to the test instances and get the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + testDataset.toString() + ".arff", "-o", MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" }); Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff"); test.setClassIndex(test.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data test.randomize(random);//from w w w . j av a2 s . 
co m // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Prepare the output buffer AbstractOutput output = new PlainText(); output.setBuffer(new StringBuffer()); output.setHeader(test); output.setAttributes("first"); Evaluation eval = new Evaluation(train); eval.evaluateModel(filteredClassifier, test, output); // Convert predictions to CSV // Format: inst#, actual, predicted, error, probability, (ID) String[] scores = new String[new Double(eval.numInstances()).intValue()]; double[] probabilities = new double[new Double(eval.numInstances()).intValue()]; for (String line : output.getBuffer().toString().split("\n")) { String[] linesplit = line.split("\\s+"); // If there's been an error, the length of linesplit is 6, otherwise 5, // due to the error flag "+" int id; String expectedValue, classification; double probability; if (line.contains("+")) { id = Integer.parseInt(linesplit[6].substring(1, linesplit[6].length() - 1)); expectedValue = linesplit[2].substring(2); classification = linesplit[3].substring(2); probability = Double.parseDouble(linesplit[5]); } else { id = Integer.parseInt(linesplit[5].substring(1, linesplit[5].length() - 1)); expectedValue = linesplit[2].substring(2); classification = linesplit[3].substring(2); probability = Double.parseDouble(linesplit[4]); } scores[id - 1] = classification; probabilities[id - 1] = probability; } System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); // Output 
classifications StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".csv"), sb.toString()); // Output probabilities sb = new StringBuilder(); for (Double probability : probabilities) sb.append(probability.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".probabilities.csv"), sb.toString()); // Output predictions FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".predictions.txt"), output.getBuffer().toString()); // Output meta information sb = new StringBuilder(); sb.append(classifier.toString() + LF); sb.append(eval.toSummaryString() + LF); sb.append(eval.toMatrixString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".meta.txt"), sb.toString()); }
From source file: dkpro.similarity.experiments.rte.util.Evaluator.java
License: Open Source License
public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset) throws Exception { // Set parameters int folds = 10; Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random);/*from w w w . j a va 2 s . com*/ // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter // Filter logFilter = new LogFilter(); // logFilter.setInputFormat(train); // train = Filter.useFilter(train, logFilter); // logFilter.setInputFormat(test); // test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(filteredClassifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(classifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier 
Instances pred = Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) predictedData = new Instances(pred, 0); for (int j = 0; j < pred.numInstances(); j++) predictedData.add(pred.instance(j)); } System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); // Prepare output scores String[] scores = new String[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; String value = predInst.stringValue(predInst.attribute(valueIdx)); scores[id] = value; } // Output classifications StringBuilder sb = new StringBuilder(); for (String score : scores) sb.append(score.toString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".csv"), sb.toString()); // Output prediction arff DataSink.write(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".predicted.arff", predictedData); // Output meta information sb = new StringBuilder(); sb.append(baseClassifier.toString() + LF); sb.append(eval.toSummaryString() + LF); sb.append(eval.toMatrixString() + LF); FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + dataset.toString() + ".meta.txt"), sb.toString()); }
From source file: dkpro.similarity.experiments.sts2013.util.Evaluator.java
License: Open Source License
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception { for (Dataset dataset : datasets) { // Set parameters int folds = 10; Classifier baseClassifier = new LinearRegression(); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff" }); Instances data = DataSource.read( MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff"); data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random);//from w w w . ja v a2s . com // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter Filter logFilter = new LogFilter(); logFilter.setInputFormat(train); train = Filter.useFilter(train, logFilter); logFilter.setInputFormat(test); test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(classifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(classifier); filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); 
filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) { predictedData = new Instances(pred, 0); } for (int j = 0; j < pred.numInstances(); j++) { predictedData.add(pred.instance(j)); } } // Prepare output scores double[] scores = new double[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; double value = predInst.value(predInst.attribute(valueIdx)); scores[id] = value; // Limit to interval [0;5] if (scores[id] > 5.0) { scores[id] = 5.0; } if (scores[id] < 0.0) { scores[id] = 0.0; } } // Output StringBuilder sb = new StringBuilder(); for (Double score : scores) { sb.append(score.toString() + LF); } FileUtils.writeStringToFile( new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"), sb.toString()); } }
From source file: dkpro.similarity.experiments.sts2013baseline.util.Evaluator.java
License: Open Source License
public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception { for (Dataset dataset : datasets) { // Set parameters int folds = 10; Classifier baseClassifier = new LinearRegression(); // Set up the random number generator long seed = new Date().getTime(); Random random = new Random(seed); // Add IDs to the instances AddID.main(new String[] { "-i", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff" }); String location = MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff"; Instances data = DataSource.read(location); if (data == null) { throw new IOException("Could not load data from: " + location); }/*from ww w . j av a 2s . c om*/ data.setClassIndex(data.numAttributes() - 1); // Instantiate the Remove filter Remove removeIDFilter = new Remove(); removeIDFilter.setAttributeIndices("first"); // Randomize the data data.randomize(random); // Perform cross-validation Instances predictedData = null; Evaluation eval = new Evaluation(data); for (int n = 0; n < folds; n++) { Instances train = data.trainCV(folds, n, random); Instances test = data.testCV(folds, n); // Apply log filter Filter logFilter = new LogFilter(); logFilter.setInputFormat(train); train = Filter.useFilter(train, logFilter); logFilter.setInputFormat(test); test = Filter.useFilter(test, logFilter); // Copy the classifier Classifier classifier = AbstractClassifier.makeCopy(baseClassifier); // Instantiate the FilteredClassifier FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(removeIDFilter); filteredClassifier.setClassifier(classifier); // Build the classifier filteredClassifier.buildClassifier(train); // Evaluate eval.evaluateModel(classifier, test); // Add predictions AddClassification filter = new AddClassification(); filter.setClassifier(classifier); 
filter.setOutputClassification(true); filter.setOutputDistribution(false); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = Filter.useFilter(test, filter); // performs predictions on test set if (predictedData == null) { predictedData = new Instances(pred, 0); } for (int j = 0; j < pred.numInstances(); j++) { predictedData.add(pred.instance(j)); } } // Prepare output scores double[] scores = new double[predictedData.numInstances()]; for (Instance predInst : predictedData) { int id = new Double(predInst.value(predInst.attribute(0))).intValue() - 1; int valueIdx = predictedData.numAttributes() - 2; double value = predInst.value(predInst.attribute(valueIdx)); scores[id] = value; // Limit to interval [0;5] if (scores[id] > 5.0) { scores[id] = 5.0; } if (scores[id] < 0.0) { scores[id] = 0.0; } } // Output StringBuilder sb = new StringBuilder(); for (Double score : scores) { sb.append(score.toString() + LF); } FileUtils.writeStringToFile( new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"), sb.toString()); } }
From source file: elh.eus.absa.WekaWrapper.java
License: Open Source License
/** * @param traindata/* ww w . ja va2 s.co m*/ * @param testdata * @param id : whether the first attribute represents de instance id and should be filtered out for classifying * @throws Exception */ public WekaWrapper(Instances traindata, Instances testdata, boolean id) throws Exception { // classifier weka.classifiers.functions.SMO SVM = new weka.classifiers.functions.SMO(); SVM.setOptions(weka.core.Utils.splitOptions("-C 1.0 -L 0.0010 -P 1.0E-12 -N 0 -V -1 -W 1 " + "-K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\"")); setTraindata(traindata); setTestdata(testdata); // first attribute reflects instance id, delete it when building the classifier if (id) { //filter Remove rm = new Remove(); rm.setAttributeIndices("1"); // remove 1st attribute // meta-classifier FilteredClassifier fc = new FilteredClassifier(); fc.setFilter(rm); fc.setClassifier(SVM); setMLclass(fc); } else { setMLclass(SVM); } }
From source file: elh.eus.absa.WekaWrapper.java
License: Open Source License
public void filterAttribute(String index) throws Exception { //filter/* ww w . j a va2s .c om*/ Remove rm = new Remove(); rm.setAttributeIndices(index); // remove 1st attribute indexes start from 1 // meta-classifier FilteredClassifier fc = new FilteredClassifier(); fc.setFilter(rm); fc.setClassifier(this.MLclass); setMLclass(fc); }
From source file: gov.va.chir.tagline.TagLineEvaluator.java
License: Open Source License
private void setup(final Collection<Document> documents, final Feature... features) throws IOException { if (!DatasetUtil.hasLabels(documents)) { throw new IllegalArgumentException("All lines for training must have a label."); }/*from w w w .jav a 2 s .com*/ // Setup extractor for feature calculation final Extractor extractor = new Extractor(); if (features != null && features.length > 0) { extractor.addFeatures(features); } else { extractor.addFeatures(Extractor.getDefaultFeatures()); } // Setup any features that require the entire corpus extractor.setupCorpusProcessors(documents); // Calculate features at both document and line level for (Document document : documents) { extractor.calculateFeatureValues(document); } // Create dataset instances = DatasetUtil.createDataset(documents); // Remove IDs from dataset final Remove remove = new Remove(); remove.setAttributeIndicesArray(new int[] { instances.attribute(DatasetUtil.DOC_ID).index(), instances.attribute(DatasetUtil.LINE_ID).index() }); fc = new FilteredClassifier(); fc.setFilter(remove); }
From source file: gov.va.chir.tagline.TagLineScorer.java
License: Open Source License
public TagLineScorer(final TagLineModel tagLineModel) throws Exception { this.tagLineModel = tagLineModel; // Setup extractor for feature calculation extractor = new Extractor(); extractor.addFeatures(this.tagLineModel.getFeatures()); // @TODO - check if this code is necessary AND if this means a classifier can only be used one (since we remove stuff) classAttr = tagLineModel.getHeader().attribute(DatasetUtil.LABEL); this.tagLineModel.getHeader().setClass(classAttr); lineIdAttr = this.tagLineModel.getHeader().attribute(DatasetUtil.LINE_ID); final Attribute docIdAttr = this.tagLineModel.getHeader().attribute(DatasetUtil.DOC_ID); // Remove IDs from dataset (match training) final Remove remove = new Remove(); remove.setAttributeIndicesArray(new int[] { docIdAttr.index(), lineIdAttr.index() }); remove.setInputFormat(this.tagLineModel.getHeader()); fc = new FilteredClassifier(); fc.setFilter(remove);//from w w w .jav a2 s. c o m fc.setClassifier(this.tagLineModel.getModel()); }
From source file: gov.va.chir.tagline.TagLineTrainer.java
License: Open Source License
public void train(final Collection<Document> documents, final Feature... features) throws Exception { if (!DatasetUtil.hasLabels(documents)) { throw new IllegalArgumentException("All lines for training must have a label."); }/*w w w . java 2 s.com*/ // Setup extractor for feature calculation extractor = new Extractor(); if (features != null && features.length > 0) { extractor.addFeatures(features); } else { extractor.addFeatures(Extractor.getDefaultFeatures()); } // Setup any features that require the entire corpus extractor.setupCorpusProcessors(documents); // Calculate features at both document and line level for (Document document : documents) { extractor.calculateFeatureValues(document); } // Create dataset instances = DatasetUtil.createDataset(documents); // Remove IDs from dataset final Remove remove = new Remove(); remove.setAttributeIndicesArray(new int[] { instances.attribute(DatasetUtil.DOC_ID).index(), instances.attribute(DatasetUtil.LINE_ID).index() }); final FilteredClassifier fc = new FilteredClassifier(); fc.setFilter(remove); fc.setClassifier(tagLineModel.getModel()); // Train model fc.buildClassifier(instances); tagLineModel.setModel(fc.getClassifier()); }
From source file: kea.KEAFilter.java
License: Open Source License
/** * Builds the classifier./*w w w . j av a 2 s . com*/ */ private void buildClassifier() throws Exception { // Generate input format for classifier FastVector atts = new FastVector(); for (int i = 0; i < getInputFormat().numAttributes(); i++) { if (i == m_DocumentAtt) { atts.addElement(new Attribute("TFxIDF")); atts.addElement(new Attribute("First_occurrence")); if (m_KFused) { atts.addElement(new Attribute("Keyphrase_frequency")); } } else if (i == m_KeyphrasesAtt) { FastVector vals = new FastVector(2); vals.addElement("False"); vals.addElement("True"); atts.addElement(new Attribute("Keyphrase?", vals)); } } m_ClassifierData = new Instances("ClassifierData", atts, 0); m_ClassifierData.setClassIndex(m_NumFeatures); if (m_Debug) { System.err.println("--- Converting instances for classifier"); } // Convert pending input instances into data for classifier for (int i = 0; i < getInputFormat().numInstances(); i++) { Instance current = getInputFormat().instance(i); // Get the key phrases for the document String keyphrases = current.stringValue(m_KeyphrasesAtt); HashMap hashKeyphrases = getGivenKeyphrases(keyphrases, false); HashMap hashKeysEval = getGivenKeyphrases(keyphrases, true); // Get the phrases for the document HashMap hash = new HashMap(); int length = getPhrases(hash, current.stringValue(m_DocumentAtt)); // Compute the feature values for each phrase and // add the instance to the data for the classifier Iterator it = hash.keySet().iterator(); while (it.hasNext()) { String phrase = (String) it.next(); FastVector phraseInfo = (FastVector) hash.get(phrase); double[] vals = featVals(phrase, phraseInfo, true, hashKeysEval, hashKeyphrases, length); Instance inst = new Instance(current.weight(), vals); m_ClassifierData.add(inst); } } if (m_Debug) { System.err.println("--- Building classifier"); } // Build classifier FilteredClassifier fclass = new FilteredClassifier(); fclass.setClassifier(new NaiveBayesSimple()); fclass.setFilter(new Discretize()); m_Classifier = 
fclass; m_Classifier.buildClassifier(m_ClassifierData); if (m_Debug) { System.err.println(m_Classifier); } // Save space m_ClassifierData = new Instances(m_ClassifierData, 0); }