List of usage examples for the weka.classifiers.bayes.NaiveBayes constructor NaiveBayes()
NaiveBayes
From source file: matres.MatResUI.java
private void doClassification() { J48 m_treeResiko;//from ww w .j a v a 2 s . com J48 m_treeAksi; NaiveBayes m_nbResiko; NaiveBayes m_nbAksi; FastVector m_fvInstanceRisks; FastVector m_fvInstanceActions; InputStream isRiskTree = getClass().getResourceAsStream("data/ResikoTree.model"); InputStream isRiskNB = getClass().getResourceAsStream("data/ResikoNB.model"); InputStream isActionTree = getClass().getResourceAsStream("data/AksiTree.model"); InputStream isActionNB = getClass().getResourceAsStream("data/AksiNB.model"); m_treeResiko = new J48(); m_treeAksi = new J48(); m_nbResiko = new NaiveBayes(); m_nbAksi = new NaiveBayes(); try { //m_treeResiko = (J48) weka.core.SerializationHelper.read("ResikoTree.model"); m_treeResiko = (J48) weka.core.SerializationHelper.read(isRiskTree); //m_nbResiko = (NaiveBayes) weka.core.SerializationHelper.read("ResikoNB.model"); m_nbResiko = (NaiveBayes) weka.core.SerializationHelper.read(isRiskNB); //m_treeAksi = (J48) weka.core.SerializationHelper.read("AksiTree.model"); m_treeAksi = (J48) weka.core.SerializationHelper.read(isActionTree); //m_nbAksi = (NaiveBayes) weka.core.SerializationHelper.read("AksiNB.model"); m_nbAksi = (NaiveBayes) weka.core.SerializationHelper.read(isActionNB); } catch (Exception ex) { Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex); } System.out.println("Setting up an Instance..."); // Values for LIKELIHOOD OF OCCURRENCE FastVector fvLO = new FastVector(5); fvLO.addElement("> 10 in 1 year"); fvLO.addElement("1 - 10 in 1 year"); fvLO.addElement("1 in 1 year to 1 in 10 years"); fvLO.addElement("1 in 10 years to 1 in 100 years"); fvLO.addElement("1 in more than 100 years"); // Values for SAFETY FastVector fvSafety = new FastVector(5); fvSafety.addElement("near miss"); fvSafety.addElement("first aid injury, medical aid injury"); fvSafety.addElement("lost time injury / temporary disability"); fvSafety.addElement("permanent disability"); fvSafety.addElement("fatality"); // Values for EXTRA 
FUEL COST FastVector fvEFC = new FastVector(5); fvEFC.addElement("< 100 million rupiah"); fvEFC.addElement("0,1 - 1 billion rupiah"); fvEFC.addElement("1 - 10 billion rupiah"); fvEFC.addElement("10 - 100 billion rupiah"); fvEFC.addElement("> 100 billion rupiah"); // Values for SYSTEM RELIABILITY FastVector fvSR = new FastVector(5); fvSR.addElement("< 100 MWh"); fvSR.addElement("0,1 - 1 GWh"); fvSR.addElement("1 - 10 GWh"); fvSR.addElement("10 - 100 GWh"); fvSR.addElement("> 100 GWh"); // Values for EQUIPMENT COST FastVector fvEC = new FastVector(5); fvEC.addElement("< 50 million rupiah"); fvEC.addElement("50 - 500 million rupiah"); fvEC.addElement("0,5 - 5 billion rupiah"); fvEC.addElement("5 -50 billion rupiah"); fvEC.addElement("> 50 billion rupiah"); // Values for CUSTOMER SATISFACTION SOCIAL FACTOR FastVector fvCSSF = new FastVector(5); fvCSSF.addElement("Complaint from the VIP customer"); fvCSSF.addElement("Complaint from industrial customer"); fvCSSF.addElement("Complaint from community"); fvCSSF.addElement("Complaint from community that have potential riot"); fvCSSF.addElement("High potential riot"); // Values for RISK FastVector fvRisk = new FastVector(4); fvRisk.addElement("Low"); fvRisk.addElement("Moderate"); fvRisk.addElement("High"); fvRisk.addElement("Extreme"); // Values for ACTION FastVector fvAction = new FastVector(3); fvAction.addElement("Life Extension Program"); fvAction.addElement("Repair/Refurbish"); fvAction.addElement("Replace/Run to Fail + Investment"); // Defining Attributes, including Class(es) Attributes Attribute attrLO = new Attribute("LO", fvLO); Attribute attrSafety = new Attribute("Safety", fvSafety); Attribute attrEFC = new Attribute("EFC", fvEFC); Attribute attrSR = new Attribute("SR", fvSR); Attribute attrEC = new Attribute("EC", fvEC); Attribute attrCSSF = new Attribute("CSSF", fvCSSF); Attribute attrRisk = new Attribute("Risk", fvRisk); Attribute attrAction = new Attribute("Action", fvAction); m_fvInstanceRisks = new 
FastVector(7); m_fvInstanceRisks.addElement(attrLO); m_fvInstanceRisks.addElement(attrSafety); m_fvInstanceRisks.addElement(attrEFC); m_fvInstanceRisks.addElement(attrSR); m_fvInstanceRisks.addElement(attrEC); m_fvInstanceRisks.addElement(attrCSSF); m_fvInstanceRisks.addElement(attrRisk); m_fvInstanceActions = new FastVector(7); m_fvInstanceActions.addElement(attrLO); m_fvInstanceActions.addElement(attrSafety); m_fvInstanceActions.addElement(attrEFC); m_fvInstanceActions.addElement(attrSR); m_fvInstanceActions.addElement(attrEC); m_fvInstanceActions.addElement(attrCSSF); m_fvInstanceActions.addElement(attrAction); Instances dataRisk = new Instances("A-Risk-instance-to-classify", m_fvInstanceRisks, 0); Instances dataAction = new Instances("An-Action-instance-to-classify", m_fvInstanceActions, 0); double[] riskValues = new double[dataRisk.numAttributes()]; double[] actionValues = new double[dataRisk.numAttributes()]; String strLO = (String) m_cmbLO.getSelectedItem(); String strSafety = (String) m_cmbSafety.getSelectedItem(); String strEFC = (String) m_cmbEFC.getSelectedItem(); String strSR = (String) m_cmbSR.getSelectedItem(); String strEC = (String) m_cmbEC.getSelectedItem(); String strCSSF = (String) m_cmbCSSF.getSelectedItem(); Instance instRisk = new DenseInstance(7); Instance instAction = new DenseInstance(7); if (strLO.equals("-- none --")) { instRisk.setMissing(0); instAction.setMissing(0); } else { instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(0), strLO); instAction.setValue((Attribute) m_fvInstanceActions.elementAt(0), strLO); } if (strSafety.equals("-- none --")) { instRisk.setMissing(1); instAction.setMissing(1); } else { instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(1), strSafety); instAction.setValue((Attribute) m_fvInstanceActions.elementAt(1), strSafety); } if (strEFC.equals("-- none --")) { instRisk.setMissing(2); instAction.setMissing(2); } else { instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(2), strEFC); 
instAction.setValue((Attribute) m_fvInstanceActions.elementAt(2), strEFC); } if (strSR.equals("-- none --")) { instRisk.setMissing(3); instAction.setMissing(3); } else { instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(3), strSR); instAction.setValue((Attribute) m_fvInstanceActions.elementAt(3), strSR); } if (strEC.equals("-- none --")) { instRisk.setMissing(4); instAction.setMissing(4); } else { instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(4), strEC); instAction.setValue((Attribute) m_fvInstanceActions.elementAt(4), strEC); } if (strCSSF.equals("-- none --")) { instRisk.setMissing(5); instAction.setMissing(5); } else { instAction.setValue((Attribute) m_fvInstanceActions.elementAt(5), strCSSF); instRisk.setValue((Attribute) m_fvInstanceRisks.elementAt(5), strCSSF); } instRisk.setMissing(6); instAction.setMissing(6); dataRisk.add(instRisk); instRisk.setDataset(dataRisk); dataRisk.setClassIndex(dataRisk.numAttributes() - 1); dataAction.add(instAction); instAction.setDataset(dataAction); dataAction.setClassIndex(dataAction.numAttributes() - 1); System.out.println("Instance Resiko: " + dataRisk.instance(0)); System.out.println("\tNum Attributes : " + dataRisk.numAttributes()); System.out.println("\tNum instances : " + dataRisk.numInstances()); System.out.println("Instance Action: " + dataAction.instance(0)); System.out.println("\tNum Attributes : " + dataAction.numAttributes()); System.out.println("\tNum instances : " + dataAction.numInstances()); int classIndexRisk = 0; int classIndexAction = 0; String strClassRisk = null; String strClassAction = null; try { //classIndexRisk = (int) m_treeResiko.classifyInstance(dataRisk.instance(0)); classIndexRisk = (int) m_treeResiko.classifyInstance(instRisk); classIndexAction = (int) m_treeAksi.classifyInstance(instAction); } catch (Exception ex) { Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex); } strClassRisk = (String) fvRisk.elementAt(classIndexRisk); strClassAction = (String) 
fvAction.elementAt(classIndexAction); System.out.println("[Risk Class Index: " + classIndexRisk + " Class Label: " + strClassRisk + "]"); System.out.println("[Action Class Index: " + classIndexAction + " Class Label: " + strClassAction + "]"); if (strClassRisk != null) { m_txtRisk.setText(strClassRisk); } double[] riskDist = null; double[] actionDist = null; try { riskDist = m_nbResiko.distributionForInstance(dataRisk.instance(0)); actionDist = m_nbAksi.distributionForInstance(dataAction.instance(0)); String strProb; // set up RISK progress bars m_jBarRiskLow.setValue((int) (100 * riskDist[0])); m_jBarRiskLow.setString(String.format("%6.3f%%", 100 * riskDist[0])); m_jBarRiskModerate.setValue((int) (100 * riskDist[1])); m_jBarRiskModerate.setString(String.format("%6.3f%%", 100 * riskDist[1])); m_jBarRiskHigh.setValue((int) (100 * riskDist[2])); m_jBarRiskHigh.setString(String.format("%6.3f%%", 100 * riskDist[2])); m_jBarRiskExtreme.setValue((int) (100 * riskDist[3])); m_jBarRiskExtreme.setString(String.format("%6.3f%%", 100 * riskDist[3])); } catch (Exception ex) { Logger.getLogger(MatResUI.class.getName()).log(Level.SEVERE, null, ex); } double predictedProb = 0.0; String predictedClass = ""; // Loop over all the prediction labels in the distribution. for (int predictionDistributionIndex = 0; predictionDistributionIndex < riskDist.length; predictionDistributionIndex++) { // Get this distribution index's class label. String predictionDistributionIndexAsClassLabel = dataRisk.classAttribute() .value(predictionDistributionIndex); int classIndex = dataRisk.classAttribute().indexOfValue(predictionDistributionIndexAsClassLabel); // Get the probability. 
double predictionProbability = riskDist[predictionDistributionIndex]; if (predictionProbability > predictedProb) { predictedProb = predictionProbability; predictedClass = predictionDistributionIndexAsClassLabel; } System.out.printf("[%2d %10s : %6.3f]", classIndex, predictionDistributionIndexAsClassLabel, predictionProbability); } m_txtRiskNB.setText(predictedClass); }
From source file: model.clasification.klasifikacijaIstanca.java
public static void main(String[] args) throws Exception { // load data// www . j a va2 s .com DataSource loader = new DataSource(fileName); Instances data = loader.getDataSet(); data.setClassIndex(data.numAttributes() - 1); // Create the Naive Bayes Classifier NaiveBayes bayesClsf = new NaiveBayes(); bayesClsf.buildClassifier(data); // output generated model // System.out.println(bayesClsf); // Test the model with the original set Evaluation eval = new Evaluation(data); eval.evaluateModel(bayesClsf, data); // Print the result as in Weka explorer String strSummary = eval.toSummaryString(); // System.out.println("=== Evaluation on training set ==="); // System.out.println("=== Summary ==="); // System.out.println(strSummary); // Get the confusion matrix System.out.println(eval.toMatrixString()); }
From source file: mulan.examples.TrainTestExperiment.java
License: Open Source License
public static void main(String[] args) { String[] methodsToCompare = { "HOMER", "BR", "CLR", "MLkNN", "MC-Copy", "IncludeLabels", "MC-Ignore", "RAkEL", "LP", "MLStacking" }; try {//ww w . j ava 2 s . c o m String path = Utils.getOption("path", args); // e.g. -path dataset/ String filestem = Utils.getOption("filestem", args); // e.g. -filestem emotions String percentage = Utils.getOption("percentage", args); // e.g. -percentage 50 (for 50%) System.out.println("Loading the dataset"); MultiLabelInstances mlDataSet = new MultiLabelInstances(path + filestem + ".arff", path + filestem + ".xml"); //split the data set into train and test Instances dataSet = mlDataSet.getDataSet(); //dataSet.randomize(new Random(1)); RemovePercentage rmvp = new RemovePercentage(); rmvp.setInvertSelection(true); rmvp.setPercentage(Double.parseDouble(percentage)); rmvp.setInputFormat(dataSet); Instances trainDataSet = Filter.useFilter(dataSet, rmvp); rmvp = new RemovePercentage(); rmvp.setPercentage(Double.parseDouble(percentage)); rmvp.setInputFormat(dataSet); Instances testDataSet = Filter.useFilter(dataSet, rmvp); MultiLabelInstances train = new MultiLabelInstances(trainDataSet, path + filestem + ".xml"); MultiLabelInstances test = new MultiLabelInstances(testDataSet, path + filestem + ".xml"); Evaluator eval = new Evaluator(); Evaluation results; for (int i = 0; i < methodsToCompare.length; i++) { if (methodsToCompare[i].equals("BR")) { System.out.println(methodsToCompare[i]); Classifier brClassifier = new NaiveBayes(); BinaryRelevance br = new BinaryRelevance(brClassifier); br.setDebug(true); br.build(train); results = eval.evaluate(br, test); System.out.println(results); } if (methodsToCompare[i].equals("LP")) { System.out.println(methodsToCompare[i]); Classifier lpBaseClassifier = new J48(); LabelPowerset lp = new LabelPowerset(lpBaseClassifier); lp.setDebug(true); lp.build(train); results = eval.evaluate(lp, test); System.out.println(results); } if (methodsToCompare[i].equals("CLR")) 
{ System.out.println(methodsToCompare[i]); Classifier clrClassifier = new J48(); CalibratedLabelRanking clr = new CalibratedLabelRanking(clrClassifier); clr.setDebug(true); clr.build(train); results = eval.evaluate(clr, test); System.out.println(results); } if (methodsToCompare[i].equals("RAkEL")) { System.out.println(methodsToCompare[i]); MultiLabelLearner lp = new LabelPowerset(new J48()); RAkEL rakel = new RAkEL(lp); rakel.setDebug(true); rakel.build(train); results = eval.evaluate(rakel, test); System.out.println(results); } if (methodsToCompare[i].equals("MC-Copy")) { System.out.println(methodsToCompare[i]); Classifier mclClassifier = new J48(); MultiClassTransformation mcTrans = new Copy(); MultiClassLearner mcl = new MultiClassLearner(mclClassifier, mcTrans); mcl.setDebug(true); mcl.build(train); results = eval.evaluate(mcl, test); System.out.println(results); } if (methodsToCompare[i].equals("MC-Ignore")) { System.out.println(methodsToCompare[i]); Classifier mclClassifier = new J48(); MultiClassTransformation mcTrans = new Ignore(); MultiClassLearner mcl = new MultiClassLearner(mclClassifier, mcTrans); mcl.build(train); results = eval.evaluate(mcl, test); System.out.println(results); } if (methodsToCompare[i].equals("IncludeLabels")) { System.out.println(methodsToCompare[i]); Classifier ilClassifier = new J48(); IncludeLabelsClassifier il = new IncludeLabelsClassifier(ilClassifier); il.setDebug(true); il.build(train); results = eval.evaluate(il, test); System.out.println(results); } if (methodsToCompare[i].equals("MLkNN")) { System.out.println(methodsToCompare[i]); int numOfNeighbors = 10; double smooth = 1.0; MLkNN mlknn = new MLkNN(numOfNeighbors, smooth); mlknn.setDebug(true); mlknn.build(train); results = eval.evaluate(mlknn, test); System.out.println(results); } if (methodsToCompare[i].equals("HMC")) { System.out.println(methodsToCompare[i]); Classifier baseClassifier = new J48(); LabelPowerset lp = new LabelPowerset(baseClassifier); RAkEL rakel = new 
RAkEL(lp); HMC hmc = new HMC(rakel); hmc.build(train); results = eval.evaluate(hmc, test); System.out.println(results); } if (methodsToCompare[i].equals("HOMER")) { System.out.println(methodsToCompare[i]); Classifier baseClassifier = new SMO(); CalibratedLabelRanking learner = new CalibratedLabelRanking(baseClassifier); learner.setDebug(true); HOMER homer = new HOMER(learner, 3, HierarchyBuilder.Method.Random); homer.setDebug(true); homer.build(train); results = eval.evaluate(homer, test); System.out.println(results); } if (methodsToCompare[i].equals("MLStacking")) { System.out.println(methodsToCompare[i]); int numOfNeighbors = 10; Classifier baseClassifier = new IBk(numOfNeighbors); Classifier metaClassifier = new Logistic(); MultiLabelStacking mls = new MultiLabelStacking(baseClassifier, metaClassifier); mls.setMetaPercentage(1.0); mls.setDebug(true); mls.build(train); results = eval.evaluate(mls, test); System.out.println(results); } } } catch (Exception e) { e.printStackTrace(); } }
From source file: myclassifier.wekaCode.java
/**
 * Creates a classifier of the requested type and trains it on {@code dataSet}.
 *
 * @param dataSet the training data
 * @param classifierType one of {@code BAYES}, {@code ID3}, {@code J48}, {@code MyID3} or
 *     {@code MyJ48}
 * @param prune whether pruning is enabled; only used for {@code MyJ48}
 * @return the trained classifier (never {@code null})
 * @throws IllegalArgumentException if {@code classifierType} is not one of the known types
 *     (previously {@code null} was returned silently)
 * @throws Exception if training fails
 */
public static Classifier buildClassifier(Instances dataSet, int classifierType, boolean prune) throws Exception {
    Classifier classifier;
    if (classifierType == BAYES) {
        classifier = new NaiveBayes();
    } else if (classifierType == ID3) {
        classifier = new Id3();
    } else if (classifierType == J48) {
        classifier = new J48();
    } else if (classifierType == MyID3) {
        classifier = new MyID3();
    } else if (classifierType == MyJ48) {
        MyJ48 j48 = new MyJ48();
        j48.setPruning(prune);
        classifier = j48;
    } else {
        throw new IllegalArgumentException("Unknown classifier type: " + classifierType);
    }
    // Every branch trains the same way, so do it once here.
    classifier.buildClassifier(dataSet);
    return classifier;
}
From source file: net.sf.jclal.example.HoldOutExample.java
License: Open Source License
/** * @param args the command line arguments *//*from w w w . j a va2 s .c o m*/ public static void main(String[] args) { String fileName = "datasets/mfeat-pixel (10 clases).arff"; // The initial labeled set from the training set is randomly // selected Resample sampling = new Resample(); sampling.setNoReplacement(false); sampling.setInvertSelection(false); sampling.setPercentageInstancesToLabelled(5); // Set the scenario to use PoolBasedSamplingScenario scenario = new PoolBasedSamplingScenario(); QBestBatchMode batchMode = new QBestBatchMode(); batchMode.setBatchSize(1); scenario.setBatchMode(batchMode); //Set the oracle SimulatedOracle oracle = new SimulatedOracle(); scenario.setOracle(oracle); // Set the query strategy to use IQueryStrategy queryStrategy = new EntropySamplingQueryStrategy(); // Set the base classifier to use in the query strategy IClassifier model = new WekaClassifier(); Classifier classifier = new NaiveBayes(); ((WekaClassifier) model).setClassifier(classifier); //Set the model into the query strategy queryStrategy.setClassifier(model); //Set the query strategy into the scenario scenario.setQueryStrategy(queryStrategy); // Set the algorithm's listeners GraphicalReporterListener visual = new GraphicalReporterListener(); visual.setShowPassiveLearning(false); visual.setReportOnFile(false); visual.setShowSeparateWindow(true); visual.setReportFrequency(1); // Construct the AL algorithm ClassicalALAlgorithm algorithm = new ClassicalALAlgorithm(); //Set the listener for the algorithm algorithm.addListener(visual); //Set the maximal iteration algorithm.setMaxIteration(100); // Set the scenario into the algorithm algorithm.setScenario(scenario); //Set the evaluation method to use HoldOut method = new HoldOut(); //Set the sampling strategy into the algorithm method.setSamplingStrategy(sampling); //Set the path of the dataset method.setFileDataset(fileName); // Set the 66% of the total of instances to train the model method.setPercentageToSplit(66); 
RanecuFactory random = new RanecuFactory(); random.setSeed(9871234); method.setRandGenFactory(random); //Set the algorithm into the evaluation method method.setAlgorithm(algorithm); //To evaluate the algorithm method.evaluate(); }
From source file: net.sf.jclal.examples.HoldOutExample.java
License: Open Source License
/** * @param args/*from w ww . j a v a 2s. c om*/ * the command line arguments */ public static void main(String[] args) { String fileName = "datasets/iris/iris.arff"; // The initial labeled set from the training set is randomly // selected Resample sampling = new Resample(); sampling.setNoReplacement(false); sampling.setInvertSelection(false); sampling.setPercentageInstancesToLabelled(5); // Set the scenario to use PoolBasedSamplingScenario scenario = new PoolBasedSamplingScenario(); QBestBatchMode batchMode = new QBestBatchMode(); batchMode.setBatchSize(1); scenario.setBatchMode(batchMode); // Set the oracle SimulatedOracle oracle = new SimulatedOracle(); scenario.setOracle(oracle); // Set the query strategy to use IQueryStrategy queryStrategy = new EntropySamplingQueryStrategy(); // Set the base classifier to use in the query strategy IClassifier model = new WekaClassifier(); Classifier classifier = new NaiveBayes(); ((WekaClassifier) model).setClassifier(classifier); // Set the model into the query strategy queryStrategy.setClassifier(model); // Set the query strategy into the scenario scenario.setQueryStrategy(queryStrategy); // Set the algorithm's listeners GraphicalReporterListener visual = new GraphicalReporterListener(); visual.setReportOnFile(true); visual.setShowSeparateWindow(true); visual.setReportFrequency(1); // Construct the AL algorithm ClassicalALAlgorithm algorithm = new ClassicalALAlgorithm(); // Set the listener for the algorithm algorithm.addListener(visual); // Set the stop criteria MaxIteration stop1 = new MaxIteration(); stop1.setMaxIteration(45); UnlabeledSetEmpty stop2 = new UnlabeledSetEmpty(); algorithm.addStopCriterion(stop1); algorithm.addStopCriterion(stop2); // Set the scenario into the algorithm algorithm.setScenario(scenario); // Set the evaluation method to use HoldOut method = new HoldOut(); // Set the sampling strategy into the algorithm method.setSamplingStrategy(sampling); // Set the path of the dataset 
method.setFileDataset(fileName); // Set the 66% of the total of instances to train the model method.setPercentageToSplit(66); RanecuFactory random = new RanecuFactory(); random.setSeed(9871234); method.setRandGenFactory(random); // Set the algorithm into the evaluation method method.setAlgorithm(algorithm); // To evaluate the algorithm method.evaluate(); }
From source file: newclassifier.NewClassifier.java
public void setClassifierBayes() throws Exception { cls = new NaiveBayes(); data.setClassIndex(data.numAttributes() - 1); //cls.buildClassifier(data); }
From source file: nl.uva.expose.classification.WekaClassification.java
private void classifierTrainer(Instances trainData) throws Exception { trainData.setClassIndex(0);/*w ww . ja v a2s. co m*/ // classifier.setFilter(filter); classifier.setClassifier(new NaiveBayes()); classifier.buildClassifier(trainData); Evaluation eval = new Evaluation(trainData); eval.crossValidateModel(classifier, trainData, 5, new Random(1)); System.out.println(eval.toSummaryString()); System.out.println(eval.toClassDetailsString()); System.out.println("===== Evaluating on filtered (training) dataset done ====="); System.out.println("\n\nClassifier model:\n\n" + classifier); }
From source file: org.dkpro.similarity.algorithms.ml.ClassifierSimilarityMeasure.java
License: Open Source License
public static Classifier getClassifier(WekaClassifier classifier) throws IllegalArgumentException { try {//w w w . j a va2s. co m switch (classifier) { case NAIVE_BAYES: return new NaiveBayes(); case J48: J48 j48 = new J48(); j48.setOptions(new String[] { "-C", "0.25", "-M", "2" }); return j48; case SMO: SMO smo = new SMO(); smo.setOptions(Utils.splitOptions( "-C 1.0 -L 0.001 -P 1.0E-12 -N 0 -V -1 -W 1 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\"")); return smo; case LOGISTIC: Logistic logistic = new Logistic(); logistic.setOptions(Utils.splitOptions("-R 1.0E-8 -M -1")); return logistic; default: throw new IllegalArgumentException("Classifier " + classifier + " not found!"); } } catch (Exception e) { throw new IllegalArgumentException(e); } }
From source file: org.opentox.qsar.processors.trainers.classification.NaiveBayesTrainer.java
License: Open Source License
public QSARModel train(Instances data) throws QSARException { // GET A UUID AND DEFINE THE TEMPORARY FILE WHERE THE TRAINING DATA // ARE STORED IN ARFF FORMAT PRIOR TO TRAINING. final String rand = java.util.UUID.randomUUID().toString(); final String temporaryFilePath = ServerFolders.temp + "/" + rand + ".arff"; final File tempFile = new File(temporaryFilePath); // SAVE THE DATA IN THE TEMPORARY FILE try {/* w ww . j a v a 2 s . co m*/ ArffSaver dataSaver = new ArffSaver(); dataSaver.setInstances(data); dataSaver.setDestination(new FileOutputStream(tempFile)); dataSaver.writeBatch(); if (!tempFile.exists()) { throw new IOException("Temporary File was not created"); } } catch (final IOException ex) {/* * The content of the dataset cannot be * written to the destination file due to * some communication issue. */ tempFile.delete(); throw new RuntimeException( "Unexpected condition while trying to save the " + "dataset in a temporary ARFF file", ex); } NaiveBayes classifier = new NaiveBayes(); String[] generalOptions = { "-c", Integer.toString(data.classIndex() + 1), "-t", temporaryFilePath, /// Save the model in the following directory "-d", ServerFolders.models_weka + "/" + uuid }; try { Evaluation.evaluateModel(classifier, generalOptions); } catch (final Exception ex) { tempFile.delete(); throw new QSARException(Cause.XQReg350, "Unexpected condition while trying to train " + "an SVM model. 
Possible explanation : {" + ex.getMessage() + "}", ex); } QSARModel model = new QSARModel(); model.setParams(getParameters()); model.setCode(uuid.toString()); model.setAlgorithm(YaqpAlgorithms.NAIVE_BAYES); model.setDataset(datasetUri); model.setModelStatus(ModelStatus.UNDER_DEVELOPMENT); ArrayList<Feature> independentFeatures = new ArrayList<Feature>(); for (int i = 0; i < data.numAttributes(); i++) { Feature f = new Feature(data.attribute(i).name()); if (data.classIndex() != i) { independentFeatures.add(f); } } Feature dependentFeature = new Feature(data.classAttribute().name()); Feature predictedFeature = dependentFeature; model.setDependentFeature(dependentFeature); model.setIndependentFeatures(independentFeatures); model.setPredictionFeature(predictedFeature); tempFile.delete(); return model; }