List of usage examples for weka.classifiers Classifier buildClassifier
public abstract void buildClassifier(Instances data) throws Exception;
From source file:adams.flow.transformer.WekaClassifierOptimizer.java
License:Open Source License
/** * Executes the flow item./* ww w .j a v a 2 s . co m*/ * * @return null if everything is fine, otherwise error message */ @Override protected String doExecute() { String result; Instances data; weka.classifiers.Classifier cls; weka.classifiers.Classifier best; result = null; try { // determine best classifier data = (Instances) m_InputToken.getPayload(); cls = (weka.classifiers.Classifier) OptionUtils.shallowCopy(m_Optimizer); cls.buildClassifier(data); if (cls instanceof GridSearch) { best = ((GridSearch) cls).getBestClassifier(); } else if (cls instanceof MultiSearch) { best = ((MultiSearch) cls).getBestClassifier(); } else { best = null; result = "Unhandled optimizer: " + m_Optimizer.getClass().getName(); } // broadcast result if (best != null) m_OutputToken = new Token(best); } catch (Exception e) { m_OutputToken = null; result = handleException("Failed to optimize: ", e); } return result; }
From source file:adams.flow.transformer.WekaCrossValidationEvaluator.java
License:Open Source License
/**
 * Executes the flow item: cross-validates the configured classifier on the
 * incoming dataset via a {@code WekaCrossValidationExecution}, optionally
 * builds a final model on the full data, and forwards either an evaluation
 * container or the textual output buffer.
 *
 * @return null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances data;
    weka.classifiers.Classifier cls;
    weka.classifiers.Classifier model;
    int[] indices;

    indices = null;
    data = null;

    try {
        // evaluate classifier
        cls = getClassifierInstance();
        if (cls == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");
        if (isLoggingEnabled())
            getLogger().info(OptionUtils.getCommandLine(cls));

        data = (Instances) m_InputToken.getPayload();

        // delegate the actual cross-validation to a (possibly multi-threaded) execution object
        m_CrossValidation = new WekaCrossValidationExecution();
        m_CrossValidation.setJobRunnerSetup(m_JobRunnerSetup);
        m_CrossValidation.setClassifier(cls);
        m_CrossValidation.setData(data);
        m_CrossValidation.setFolds(m_Folds);
        m_CrossValidation.setSeed(m_Seed);
        m_CrossValidation.setUseViews(m_UseViews);
        m_CrossValidation.setDiscardPredictions(m_DiscardPredictions);
        m_CrossValidation.setNumThreads(m_NumThreads);
        m_CrossValidation.setOutput(m_Output);
        m_CrossValidation.setGenerator((CrossValidationFoldGenerator) OptionUtils.shallowCopy(m_Generator));
        result = m_CrossValidation.execute();

        if (!m_CrossValidation.isStopped()) {
            indices = m_CrossValidation.getOriginalIndices();
            if (m_CrossValidation.isSingleThreaded()) {
                if (m_Output instanceof Null) {
                    // no textual output configured -> always emit the evaluation container
                    m_OutputToken = new Token(new WekaEvaluationContainer(m_CrossValidation.getEvaluation()));
                } else {
                    if (m_CrossValidation.getOutputBuffer() != null)
                        m_OutputBuffer.append(m_CrossValidation.getOutputBuffer().toString());
                    // a container is required when a final model must be attached later
                    if (m_AlwaysUseContainer || m_FinalModel)
                        m_OutputToken = new Token(new WekaEvaluationContainer(m_CrossValidation.getEvaluation(),
                                null, m_Output.getBuffer().toString()));
                    else
                        m_OutputToken = new Token(m_Output.getBuffer().toString());
                }
            } else {
                // parallel execution always yields a container (no single output buffer)
                m_OutputToken = new Token(new WekaEvaluationContainer(m_CrossValidation.getEvaluation()));
            }

            // build model on the full dataset, if requested
            if (m_OutputToken.hasPayload(WekaEvaluationContainer.class)) {
                if (m_FinalModel) {
                    model = ObjectCopyHelper.copyObject(cls);
                    model.buildClassifier(data);
                    m_OutputToken.getPayload(WekaEvaluationContainer.class)
                            .setValue(WekaEvaluationContainer.VALUE_MODEL, model);
                }
            }
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to cross-validate classifier: ", e);
    }

    // attach test data and fold indices to the container (outside the try so it
    // also runs when execute() returned an error message but produced a token)
    if (m_OutputToken != null) {
        if (m_OutputToken.hasPayload(WekaEvaluationContainer.class)) {
            m_OutputToken.getPayload(WekaEvaluationContainer.class)
                    .setValue(WekaEvaluationContainer.VALUE_TESTDATA, data);
            if (indices != null)
                m_OutputToken.getPayload(WekaEvaluationContainer.class)
                        .setValue(WekaEvaluationContainer.VALUE_ORIGINALINDICES, indices);
        }
        updateProvenance(m_OutputToken);
    }

    return result;
}
From source file:adams.flow.transformer.WekaTrainClassifier.java
License:Open Source License
/**
 * Executes the flow item: trains the classifier either in batch mode (payload
 * is a full {@code Instances} dataset) or incrementally (payload is a single
 * {@code Instance} fed to an {@code UpdateableClassifier}).
 *
 * @return null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result = null;

    try {
        Object payload = (m_InputToken == null) ? null : m_InputToken.getPayload();

        if (payload instanceof Instances) {
            // batch training: build once on the complete dataset
            weka.classifiers.Classifier batchCls = getClassifierInstance();
            Instances trainData = (Instances) payload;
            batchCls.buildClassifier(trainData);
            m_OutputToken = new Token(new WekaModelContainer(batchCls, new Instances(trainData, 0), trainData));
        } else if (payload instanceof Instance) {
            // incremental training: one instance at a time
            weka.classifiers.Classifier incrCls = null;
            if (m_IncrementalClassifier == null) {
                incrCls = getClassifierInstance();
                if (!(incrCls instanceof UpdateableClassifier))
                    result = m_Classifier + "/" + incrCls.getClass().getName()
                            + " is not an incremental classifier!";
            }
            if (result == null) {
                Instance row = (Instance) payload;
                if (m_IncrementalClassifier == null) {
                    // first instance: initialize the incremental classifier
                    m_IncrementalClassifier = incrCls;
                    if (m_SkipBuild) {
                        ((UpdateableClassifier) m_IncrementalClassifier).updateClassifier(row);
                    } else {
                        // seed the model with a one-row dataset
                        Instances seed = new Instances(row.dataset(), 1);
                        seed.add((Instance) row.copy());
                        m_IncrementalClassifier.buildClassifier(seed);
                    }
                } else {
                    ((UpdateableClassifier) m_IncrementalClassifier).updateClassifier(row);
                }
                m_OutputToken = new Token(
                        new WekaModelContainer(m_IncrementalClassifier, new Instances(row.dataset(), 0)));
            }
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to process data:", e);
    }

    if (m_OutputToken != null)
        updateProvenance(m_OutputToken);

    return result;
}
From source file:adams.flow.transformer.WekaTrainTestSetEvaluator.java
License:Open Source License
/** * Executes the flow item./* ww w . ja v a2 s . c o m*/ * * @return null if everything is fine, otherwise error message */ @Override protected String doExecute() { String result; Instances train; Instances test; weka.classifiers.Classifier cls; Evaluation eval; WekaTrainTestSetContainer cont; result = null; test = null; try { // cross-validate classifier cls = getClassifierInstance(); if (cls == null) throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!"); cont = (WekaTrainTestSetContainer) m_InputToken.getPayload(); train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN); test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST); cls.buildClassifier(train); initOutputBuffer(); m_Output.setHeader(train); eval = new Evaluation(train); eval.setDiscardPredictions(m_DiscardPredictions); eval.evaluateModel(cls, test, m_Output); // broadcast result if (m_Output instanceof Null) { m_OutputToken = new Token(new WekaEvaluationContainer(eval, cls)); } else { if (m_AlwaysUseContainer) m_OutputToken = new Token( new WekaEvaluationContainer(eval, cls, m_Output.getBuffer().toString())); else m_OutputToken = new Token(m_Output.getBuffer().toString()); } } catch (Exception e) { m_OutputToken = null; result = handleException("Failed to evaluate: ", e); } if (m_OutputToken != null) { if (m_OutputToken.getPayload() instanceof WekaEvaluationContainer) { if (test != null) ((WekaEvaluationContainer) m_OutputToken.getPayload()) .setValue(WekaEvaluationContainer.VALUE_TESTDATA, test); } updateProvenance(m_OutputToken); } return result; }
From source file:adams.ml.model.classification.WekaClassifier.java
License:Open Source License
/**
 * Builds a classification model from the data.
 *
 * @param data the data to use for building the model
 * @return the generated model
 * @throws Exception if the build fails
 */
@Override
protected ClassificationModel doBuildModel(Dataset data) throws Exception {
    // convert the ADAMS dataset into Weka format
    Instances wekaData = WekaConverter.toInstances(data);

    // work on a copy so the template classifier stays untouched
    weka.classifiers.Classifier copy =
            (weka.classifiers.Classifier) OptionUtils.shallowCopy(m_Classifier);
    if (copy == null)
        throw new Exception(
                "Failed to create shallow copy of classifier: " + OptionUtils.getCommandLine(m_Classifier));

    copy.buildClassifier(wekaData);
    return new WekaClassificationModel(copy, data, wekaData);
}
From source file:adams.ml.model.regression.WekaRegressor.java
License:Open Source License
/** * Builds a model from the data./* w ww .j av a 2 s . c o m*/ * * @param data the data to use for building the model * @return the generated model * @throws Exception if the build fails */ @Override protected RegressionModel doBuildModel(Dataset data) throws Exception { Instances inst; weka.classifiers.Classifier classifier; inst = WekaConverter.toInstances(data); classifier = (weka.classifiers.Classifier) OptionUtils.shallowCopy(m_Classifier); if (classifier == null) throw new Exception( "Failed to create shallow copy of classifier: " + OptionUtils.getCommandLine(m_Classifier)); classifier.buildClassifier(inst); return new WekaRegressionModel(classifier, data, inst); }
From source file:adams.multiprocess.WekaCrossValidationExecution.java
License:Open Source License
/**
 * Performs the cross-validation: single-threaded folds are evaluated in-line
 * into one {@code Evaluation}; otherwise folds are dispatched as jobs to a
 * job runner and the per-fold evaluations are aggregated afterwards.
 *
 * @return null if everything is fine, otherwise error message
 */
public String execute() {
    MessageCollection result;
    Evaluation eval;
    AggregateEvaluations evalAgg;
    int folds;
    CrossValidationFoldGenerator generator;
    JobList<WekaCrossValidationJob> list;
    WekaCrossValidationJob job;
    WekaTrainTestSetContainer cont;
    int i;
    int current;
    int[] indices;
    Instances train;
    Instances test;
    Classifier cls;

    result = new MessageCollection();
    indices = null;
    m_Evaluation = null;
    m_Evaluations = null;

    try {
        // evaluate classifier
        if (m_Classifier == null)
            throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");
        if (isLoggingEnabled())
            getLogger().info(OptionUtils.getCommandLine(m_Classifier));

        m_ActualNumThreads = Performance.determineNumThreads(m_NumThreads);

        // configure the fold generator (works on a shallow copy of the template)
        generator = (CrossValidationFoldGenerator) OptionUtils.shallowCopy(m_Generator);
        generator.setData(m_Data);
        generator.setNumFolds(m_Folds);
        generator.setSeed(m_Seed);
        generator.setStratify(true);
        generator.setUseViews(m_UseViews);
        generator.initializeIterator();
        folds = generator.getActualNumFolds();

        if ((m_ActualNumThreads == 1) && !m_SeparateFolds) {
            // --- single-threaded path: evaluate all folds into one Evaluation ---
            initOutputBuffer();
            if (m_Output != null) {
                m_Output.setHeader(m_Data);
                m_Output.printHeader();
            }
            eval = new Evaluation(m_Data);
            eval.setDiscardPredictions(m_DiscardPredictions);
            current = 0;
            while (generator.hasNext()) {
                if (isStopped())
                    break;
                if (m_StatusMessageHandler != null)
                    m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '"
                            + m_Data.relationName() + "' using " + OptionUtils.getCommandLine(m_Classifier));
                cont = generator.next();
                train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
                test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
                // fresh classifier copy per fold so folds do not influence each other
                cls = (Classifier) OptionUtils.shallowCopy(m_Classifier);
                cls.buildClassifier(train);
                eval.setPriors(train);
                eval.evaluateModel(cls, test, m_Output);
                current++;
            }
            if (m_Output != null)
                m_Output.printFooter();
            // only publish the evaluation if we were not stopped mid-way
            if (!isStopped())
                m_Evaluation = eval;
        } else {
            // --- parallel path: one job per fold, aggregated afterwards ---
            if (m_DiscardPredictions)
                throw new IllegalStateException(
                        "Cannot discard predictions in parallel mode, as they are used for aggregating the statistics!");
            if (m_JobRunnerSetup == null)
                m_JobRunner = new LocalJobRunner<WekaCrossValidationJob>();
            else
                m_JobRunner = m_JobRunnerSetup.newInstance();
            if (m_JobRunner instanceof ThreadLimiter)
                ((ThreadLimiter) m_JobRunner).setNumThreads(m_NumThreads);
            list = new JobList<>();
            while (generator.hasNext()) {
                cont = generator.next();
                job = new WekaCrossValidationJob((Classifier) OptionUtils.shallowCopy(m_Classifier),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN),
                        (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST),
                        (Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER),
                        m_DiscardPredictions, m_StatusMessageHandler);
                list.add(job);
            }
            // NOTE: stop() blocks until all queued jobs have finished
            m_JobRunner.add(list);
            m_JobRunner.start();
            m_JobRunner.stop();

            // aggregate data
            if (!isStopped()) {
                evalAgg = new AggregateEvaluations();
                m_Evaluations = new Evaluation[m_JobRunner.getJobs().size()];
                for (i = 0; i < m_JobRunner.getJobs().size(); i++) {
                    job = (WekaCrossValidationJob) m_JobRunner.getJobs().get(i);
                    if (job.getEvaluation() == null) {
                        result.add("Fold #" + (i + 1) + " failed to evaluate"
                                + (job.hasExecutionError() ? job.getExecutionError() : "?"));
                        break;
                    }
                    evalAgg.add(job.getEvaluation());
                    m_Evaluations[i] = job.getEvaluation();
                    job.cleanUp();
                }
                m_Evaluation = evalAgg.aggregated();
                if (m_Evaluation == null) {
                    if (evalAgg.hasLastError())
                        result.add(evalAgg.getLastError());
                    else
                        result.add("Failed to aggregate evaluations!");
                }
            }
            list.cleanUp();
            m_JobRunner.cleanUp();
            m_JobRunner = null;
        }

        // fold indices are only available when predictions were kept
        if (!m_DiscardPredictions)
            indices = generator.crossValidationIndices();
    } catch (Exception e) {
        result.add(Utils.handleException(this, "Failed to cross-validate classifier: ", e));
    }

    m_OriginalIndices = indices;

    if (result.isEmpty())
        return null;
    else
        return result.toString();
}
From source file:algoritmogeneticocluster.NewClass.java
/**
 * Trains the given classifier on the training set and evaluates it against
 * the testing set.
 *
 * @param model       the (untrained) classifier to use
 * @param trainingSet the data to train on
 * @param testingSet  the data to evaluate against
 * @return the resulting evaluation
 * @throws Exception if training or evaluation fails
 */
public static Evaluation classify(Classifier model, Instances trainingSet, Instances testingSet)
        throws Exception {
    Evaluation result = new Evaluation(trainingSet);
    model.buildClassifier(trainingSet);
    result.evaluateModel(model, testingSet);
    return result;
}
From source file:asap.CrossValidation.java
/**
 * Performs single-threaded n-fold cross-validation on the dataset loaded from
 * the given file, then trains the classifier on the full data and optionally
 * serializes that final model to disk.
 *
 * @param dataInput       file to load the instances from
 * @param classIndex      "first", "last", a 1-based index, or an attribute name
 * @param removeIndices   attribute index range to remove before training
 * @param cls             classifier template; trained on the full data at the end
 * @param seed            seed for randomizing the data
 * @param folds           number of cross-validation folds
 * @param modelOutputFile file to serialize the final model to; empty = skip
 * @return a textual summary of the setup and the cross-validation results
 * @throws Exception if loading, filtering, training or evaluation fails
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {
    PerformanceCounters.startTimer("cross-validation ST");
    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;
    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        // numeric (1-based) index first, attribute name as fallback
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    // drop unwanted attributes before anything else sees the data
    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];
    // pre-compute the splits and a fresh classifier copy per fold
    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }
    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");

    // TODO(author note kept): this fold loop is a candidate for parallelization
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];
        // trainCV/testCV splits as above are the ones used by the
        // StratifiedRemoveFolds filter; the Explorer/Experimenter would use:
        // Instances train = randData.trainCV(folds, n, rand);
        // build and evaluate classifier
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }
    // final model: trained on the complete (filtered) dataset
    cls.buildClassifier(data);
    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");

    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");
    return out;
}
From source file:at.aictopic1.sentimentanalysis.machinelearning.impl.TwitterClassifer.java
public void trainModel() { Instances trainingData = loadTrainingData(); System.out.println("Class attribute: " + trainingData.classAttribute().toString()); // Partition dataset into training and test sets RemovePercentage filter = new RemovePercentage(); filter.setPercentage(10);//from ww w.ja v a2 s.c om Instances testData = null; // Split in training and testdata try { filter.setInputFormat(trainingData); testData = Filter.useFilter(trainingData, filter); } catch (Exception ex) { //Logger.getLogger(Trainer.class.getName()).log(Level.SEVERE, null, ex); System.out.println("Error getting testData: " + ex.toString()); } // Train the classifier Classifier model = (Classifier) new NaiveBayes(); try { // Save the model to fil // serialize model weka.core.SerializationHelper.write(modelDir + algorithm + ".model", model); } catch (Exception ex) { Logger.getLogger(TwitterClassifer.class.getName()).log(Level.SEVERE, null, ex); } // Set the local model this.trainedModel = model; try { model.buildClassifier(trainingData); } catch (Exception ex) { //Logger.getLogger(Trainer.class.getName()).log(Level.SEVERE, null, ex); System.out.println("Error training model: " + ex.toString()); } try { // Evaluate model Evaluation test = new Evaluation(trainingData); test.evaluateModel(model, testData); System.out.println(test.toSummaryString()); } catch (Exception ex) { //Logger.getLogger(Trainer.class.getName()).log(Level.SEVERE, null, ex); System.out.println("Error evaluating model: " + ex.toString()); } }