List of usage examples for weka.core.Instances.setClassIndex
public void setClassIndex(int classIndex)
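setClassIndex marks one attribute as the class (target) attribute. Loaders such as CSVLoader and DataSource leave it unset (classIndex() returns -1), so callers typically point it at the last attribute, as every example below does. A minimal sketch, using a placeholder iris.arff path:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class SetClassIndexExample {
    public static void main(String[] args) throws Exception {
        // Load a dataset; the path is a placeholder
        Instances data = new DataSource("iris.arff").getDataSet();
        // Loaders leave the class attribute undefined, so designate the last attribute
        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }
        System.out.println("Class attribute: " + data.classAttribute().name());
    }
}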
From source file: com.sliit.neuralnetwork.RecurrentNN.java
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
    System.out.println("calling trainModel");
    try {
        System.out.println("Neural Network Training start");
        loadSaveNN(modelName, false);
        if (model == null) {
            buildModel();
        }
        File fileGeneral = new File(filePath);
        CSVLoader loader = new CSVLoader();
        loader.setSource(fileGeneral);
        Instances instances = loader.getDataSet();
        // use the last attribute as the class attribute
        instances.setClassIndex(instances.numAttributes() - 1);

        // split the data into stratified train/test folds
        StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
        String[] options = new String[6];
        options[0] = "-N";
        options[1] = Integer.toString(5);
        options[2] = "-F";
        options[3] = Integer.toString(1);
        options[4] = "-S";
        options[5] = Integer.toString(1);
        stratified.setOptions(options);
        stratified.setInputFormat(instances);
        stratified.setInvertSelection(false);
        Instances testInstances = Filter.useFilter(instances, stratified);
        stratified.setInvertSelection(true);
        Instances trainInstances = Filter.useFilter(instances, stratified);

        // write the folds out as CSV files for the DataVec readers
        String directory = fileGeneral.getParent();
        CSVSaver saver = new CSVSaver();
        File trainFile = new File(directory + "/" + "normtrainadded.csv");
        File testFile = new File(directory + "/" + "normtestadded.csv");
        if (trainFile.exists()) {
            trainFile.delete();
        }
        trainFile.createNewFile();
        if (testFile.exists()) {
            testFile.delete();
        }
        testFile.createNewFile();
        saver.setFile(trainFile);
        saver.setInstances(trainInstances);
        saver.writeBatch();
        saver = new CSVSaver();
        saver.setFile(testFile);
        saver.setInstances(testInstances);
        saver.writeBatch();

        SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
        SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
        testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
        DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                recordReader, 2, outputs, inputsTot, false);
        DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(
                testReader, 2, outputs, inputsTot, false);

        roc = new ArrayList<Map<String, Double>>();
        String statMsg = "";
        Evaluation evaluation;
        for (int i = 0; i < 100; i++) {
            // alternate which fold is fitted and which is evaluated
            if (i % 2 == 0) {
                model.fit(iterator);
                evaluation = model.evaluate(testIterator);
            } else {
                model.fit(testIterator);
                evaluation = model.evaluate(iterator);
            }
            Map<String, Double> map = new HashMap<String, Double>();
            Map<Integer, Integer> falsePositives = evaluation.falsePositives();
            Map<Integer, Integer> trueNegatives = evaluation.trueNegatives();
            Map<Integer, Integer> truePositives = evaluation.truePositives();
            Map<Integer, Integer> falseNegatives = evaluation.falseNegatives();
            // cast to double to avoid integer division when computing the rates
            double fpr = (double) falsePositives.get(1) / (falsePositives.get(1) + trueNegatives.get(1));
            double tpr = (double) truePositives.get(1) / (truePositives.get(1) + falseNegatives.get(1));
            map.put("FPR", fpr);
            map.put("TPR", tpr);
            roc.add(map);
            statMsg = evaluation.stats();
            iterator.reset();
            testIterator.reset();
        }
        loadSaveNN(modelName, true);
        System.out.println("ROC " + roc);
        return statMsg;
    } catch (Exception e) {
        e.printStackTrace();
        System.out.println("Error occurred while building neural network: " + e.getMessage());
        throw new NeuralException(e.getLocalizedMessage(), e);
    }
}
From source file: com.sliit.normalize.NormalizeDataset.java
public String normalizeDataset() {
    System.out.println("start normalizing data");
    String filePathOut = "";
    try {
        CSVLoader loader = new CSVLoader();
        if (reducedDiemensionFile != null) {
            loader.setSource(reducedDiemensionFile);
        } else {
            if (tempFIle != null && tempFIle.exists()) {
                loader.setSource(tempFIle);
            } else {
                loader.setSource(csvFile);
            }
        }
        Instances dataInstance = loader.getDataSet();
        Normalize normalize = new Normalize();
        // the last attribute is the class; Normalize leaves the class attribute untouched
        dataInstance.setClassIndex(dataInstance.numAttributes() - 1);
        normalize.setInputFormat(dataInstance);
        String directory = csvFile.getParent();
        outputFile = new File(directory + "/" + "normalized" + csvFile.getName());
        if (!outputFile.exists()) {
            outputFile.createNewFile();
        }
        CSVSaver saver = new CSVSaver();
        saver.setFile(outputFile);
        // batch-filter protocol: feed every instance (index 0 onward), then signal end of batch
        for (int i = 0; i < dataInstance.numInstances(); i++) {
            normalize.input(dataInstance.instance(i));
        }
        normalize.batchFinished();
        Instances outPut = new Instances(dataInstance, 0);
        for (int i = 0; i < dataInstance.numInstances(); i++) {
            outPut.add(normalize.output());
        }
        // binarise the class labels: "normal." -> 0, everything else -> 1
        Attribute attribute = dataInstance.attribute(outPut.numAttributes() - 1);
        for (int j = 0; j < attribute.numValues(); j++) {
            if (attribute.value(j).equals("normal.")) {
                outPut.renameAttributeValue(attribute, attribute.value(j), "0");
            } else {
                outPut.renameAttributeValue(attribute, attribute.value(j), "1");
            }
        }
        saver.setInstances(outPut);
        saver.writeBatch();
        writeToNewFile(directory);
        filePathOut = directory + "norm" + csvFile.getName();
        if (tempFIle != null) {
            tempFIle.delete();
        }
        if (reducedDiemensionFile != null) {
            reducedDiemensionFile.delete();
        }
        outputFile.delete();
    } catch (IOException e) {
        log.error("Error occurred: " + e.getMessage());
    } catch (Exception e) {
        log.error("Error occurred: " + e.getMessage());
    }
    return filePathOut;
}
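The input()/batchFinished()/output() sequence above is Weka's manual batch-filtering protocol; Filter.useFilter wraps the same steps in a single call. A minimal sketch of that shortcut, assuming a hypothetical data.csv path:

import java.io.File;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;

public class NormalizeShortcut {
    public static void main(String[] args) throws Exception {
        CSVLoader loader = new CSVLoader();
        loader.setSource(new File("data.csv")); // placeholder path
        Instances data = loader.getDataSet();
        data.setClassIndex(data.numAttributes() - 1);

        // useFilter() drives input()/batchFinished()/output() internally
        Normalize normalize = new Normalize();
        normalize.setInputFormat(data); // must precede the filtering call
        Instances normalized = Filter.useFilter(data, normalize);
        System.out.println("Normalized " + normalized.numInstances() + " instances");
    }
}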
From source file: com.sliit.normalize.NormalizeDataset.java
public int whiteningData() {
    System.out.println("whiteningData");
    int nums = 0;
    try {
        if (tempFIle != null && tempFIle.exists()) {
            csv.setSource(tempFIle);
        } else {
            csv.setSource(csvFile);
        }
        Instances instances = csv.getDataSet();
        // only reduce dimensionality when there are more than 10 attributes
        if (instances.numAttributes() > 10) {
            instances.setClassIndex(instances.numAttributes() - 1);
            RandomProjection random = new RandomProjection();
            random.setDistribution(
                    new SelectedTag(RandomProjection.GAUSSIAN, RandomProjection.TAGS_DSTRS_TYPE));
            // filter options must be set before setInputFormat() for them to take effect
            random.setNumberOfAttributes(10);
            random.setReplaceMissingValues(true);
            reducedDiemensionFile = new File(csvFile.getParent() + "/tempwhite.csv");
            if (!reducedDiemensionFile.exists()) {
                reducedDiemensionFile.createNewFile();
            }
            random.setInputFormat(instances);
            BufferedWriter writer = new BufferedWriter(new FileWriter(reducedDiemensionFile));
            // project each instance and write the result out line by line
            for (int i = 0; i < instances.numInstances(); i++) {
                random.input(instances.instance(i));
                writer.write(random.output().toString());
                writer.newLine();
            }
            writer.flush();
            writer.close();
            nums = random.getNumberOfAttributes();
        } else {
            nums = instances.numAttributes();
        }
    } catch (IOException e) {
        log.error("Error occurred: " + e.getMessage());
    } catch (Exception e) {
        log.error("Error occurred: " + e.getMessage());
    }
    return nums;
}
From source file: com.sliit.rules.RuleContainer.java
public String predictionResult(String filePath) throws Exception {
    File testPath = new File(filePath);
    CSVLoader loader = new CSVLoader();
    loader.setSource(testPath);
    Instances testInstances = loader.getDataSet();
    testInstances.setClassIndex(testInstances.numAttributes() - 1);
    Evaluation eval = new Evaluation(testInstances);
    eval.evaluateModel(ruleMoldel, testInstances);
    ArrayList<Prediction> predictions = eval.predictions();
    // map the numeric prediction of the first instance back to its class label
    int predictedVal = (int) predictions.get(0).predicted();
    String cdetails = instances.classAttribute().value(predictedVal);
    return cdetails;
}
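Running a full Evaluation pass just to read back one prediction works, but Classifier.classifyInstance returns the predicted class value directly. A minimal sketch of that route, assuming a trained classifier and a test set whose class index is already set:

import weka.classifiers.Classifier;
import weka.core.Instances;

public class DirectPrediction {
    // assumes 'model' is already trained on data compatible with 'test'
    static String predictFirst(Classifier model, Instances test) throws Exception {
        double predicted = model.classifyInstance(test.instance(0));
        return test.classAttribute().value((int) predicted);
    }
}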
From source file: com.sliit.views.DataVisualizerPanel.java
void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new FileReader(datasetPathText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();

        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);

        // add plot
        vmc.addPlot(tempd);

        // display curve
        String plotName = vmc.getName();
        final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
        jf.setSize(500, 400);
        jf.getContentPane().setLayout(new BorderLayout());
        jf.getContentPane().add(vmc, BorderLayout.CENTER);
        jf.addWindowListener(new java.awt.event.WindowAdapter() {
            public void windowClosing(java.awt.event.WindowEvent e) {
                jf.dispose();
            }
        });
        jf.setVisible(true);
    } catch (IOException ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(DataVisualizerPanel.class.getName()).log(Level.SEVERE, null, ex);
    }
}
From source file: com.sliit.views.KNNView.java
void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new java.io.FileReader(PredictorPanel.modalText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();

        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);

        // add plot
        vmc.addPlot(tempd);
        rocPanel.removeAll();
        rocPanel.add(vmc, "vmc", 0);
        rocPanel.revalidate();
    } catch (IOException ex) {
        Logger.getLogger(KNNView.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(KNNView.class.getName()).log(Level.SEVERE, null, ex);
    }
}
From source file: com.sliit.views.SVMView.java
/**
 * draw ROC curve
 */
void getRocCurve() {
    try {
        Instances data;
        data = new Instances(new BufferedReader(new FileReader(PredictorPanel.modalText.getText())));
        data.setClassIndex(data.numAttributes() - 1);

        // train classifier
        Classifier cl = new NaiveBayes();
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cl, data, 10, new Random(1));

        // generate curve
        ThresholdCurve tc = new ThresholdCurve();
        int classIndex = 0;
        Instances result = tc.getCurve(eval.predictions(), classIndex);

        // plot curve
        ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
        vmc.setROCString("(Area under ROC = " + Utils.doubleToString(tc.getROCArea(result), 4) + ")");
        vmc.setName(result.relationName());
        PlotData2D tempd = new PlotData2D(result);
        tempd.setPlotName(result.relationName());
        tempd.addInstanceNumberAttribute();

        // specify which points are connected
        boolean[] cp = new boolean[result.numInstances()];
        for (int n = 1; n < cp.length; n++) {
            cp[n] = true;
        }
        tempd.setConnectPoints(cp);

        // add plot
        vmc.addPlot(tempd);

        // rocPanel.removeAll();
        // rocPanel.add(vmc, "vmc", 0);
        // rocPanel.revalidate();
    } catch (IOException ex) {
        Logger.getLogger(SVMView.class.getName()).log(Level.SEVERE, null, ex);
    } catch (Exception ex) {
        Logger.getLogger(SVMView.class.getName()).log(Level.SEVERE, null, ex);
    }
}
From source file: com.spread.experiment.tempuntilofficialrelease.ClassificationViaClustering108.java
License: Open Source License
/**
 * builds the classifier
 *
 * @param data the training instances
 * @throws Exception if something goes wrong
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // save original header (needed for clusters to classes output)
    m_OriginalHeader = data.stringFreeStructure();

    // remove class attribute for clusterer
    Instances clusterData = new Instances(data);
    clusterData.setClassIndex(-1);
    clusterData.deleteAttributeAt(data.classIndex());
    m_ClusteringHeader = clusterData.stringFreeStructure();

    if (m_ClusteringHeader.numAttributes() == 0) {
        System.err.println("Data contains only class attribute, defaulting to ZeroR model.");
        m_ZeroR = new ZeroR();
        m_ZeroR.buildClassifier(data);
    } else {
        m_ZeroR = null;

        // build clusterer
        m_ActualClusterer = AbstractClusterer.makeCopy(m_Clusterer);
        m_ActualClusterer.buildClusterer(clusterData);

        if (!getLabelAllClusters()) {
            // determine classes-to-clusters mapping
            ClusterEvaluation eval = new ClusterEvaluation();
            eval.setClusterer(m_ActualClusterer);
            eval.evaluateClusterer(clusterData);
            double[] clusterAssignments = eval.getClusterAssignments();
            int[][] counts = new int[eval.getNumClusters()][m_OriginalHeader.numClasses()];
            int[] clusterTotals = new int[eval.getNumClusters()];
            double[] best = new double[eval.getNumClusters() + 1];
            double[] current = new double[eval.getNumClusters() + 1];
            for (int i = 0; i < data.numInstances(); i++) {
                Instance instance = data.instance(i);
                if (!instance.classIsMissing()) {
                    counts[(int) clusterAssignments[i]][(int) instance.classValue()]++;
                    clusterTotals[(int) clusterAssignments[i]]++;
                }
            }
            best[eval.getNumClusters()] = Double.MAX_VALUE;
            ClusterEvaluation.mapClasses(eval.getNumClusters(), 0, counts, clusterTotals, current, best, 0);
            m_ClustersToClasses = new double[best.length];
            System.arraycopy(best, 0, m_ClustersToClasses, 0, best.length);
        } else {
            // accumulate each labelled instance's cluster membership per class
            m_ClusterClassProbs = new double[m_ActualClusterer.numberOfClusters()][data.numClasses()];
            for (int i = 0; i < data.numInstances(); i++) {
                Instance clusterInstance = clusterData.instance(i);
                Instance originalInstance = data.instance(i);
                if (!originalInstance.classIsMissing()) {
                    double[] probs = m_ActualClusterer.distributionForInstance(clusterInstance);
                    for (int j = 0; j < probs.length; j++) {
                        m_ClusterClassProbs[j][(int) originalInstance.classValue()] += probs[j];
                    }
                }
            }
            for (int i = 0; i < m_ClusterClassProbs.length; i++) {
                Utils.normalize(m_ClusterClassProbs[i]);
            }
        }
    }
}
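A hypothetical driver for this classifier, assuming it keeps the setClusterer accessor of Weka's stock ClassificationViaClustering and pairing it with SimpleKMeans at one cluster per class (the path is a placeholder):

import weka.clusterers.SimpleKMeans;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class ClusteringClassifierDemo {
    public static void main(String[] args) throws Exception {
        Instances data = new DataSource("train.arff").getDataSet(); // placeholder path
        data.setClassIndex(data.numAttributes() - 1);

        SimpleKMeans kMeans = new SimpleKMeans();
        kMeans.setNumClusters(data.numClasses()); // one cluster per class

        ClassificationViaClustering108 cvc = new ClassificationViaClustering108();
        cvc.setClusterer(kMeans); // assumed accessor, inherited from the stock implementation
        cvc.buildClassifier(data);
        System.out.println(cvc);
    }
}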
From source file: com.tum.classifiertest.FastRfUtils.java
License: Open Source License
/**
 * Load a dataset into memory.
 *
 * @param location the location of the dataset
 *
 * @return the dataset
 */
public static Instances readInstances(String location) throws Exception {
    Instances data = new weka.core.converters.ConverterUtils.DataSource(location).getDataSet();
    if (data.classIndex() == -1)
        data.setClassIndex(data.numAttributes() - 1);
    return data;
}
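A hypothetical call site for this helper (the path is a placeholder; any ARFF/CSV source that ConverterUtils can read works):

// hypothetical usage of the helper above
Instances train = FastRfUtils.readInstances("/data/train.arff");
System.out.println("Loaded " + train.numInstances() + " instances; class = "
        + train.classAttribute().name());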
From source file: com.yahoo.labs.samoa.instances.SamoaToWekaInstanceConverter.java
License: Apache License
/**
 * Weka instances information.
 *
 * @param instances the instances
 * @return the weka.core.Instances
 */
public weka.core.Instances wekaInstancesInformation(Instances instances) {
    weka.core.Instances wekaInstances;
    ArrayList<weka.core.Attribute> attInfo = new ArrayList<weka.core.Attribute>();
    for (int i = 0; i < instances.numAttributes(); i++) {
        attInfo.add(wekaAttribute(i, instances.attribute(i)));
    }
    wekaInstances = new weka.core.Instances(instances.getRelationName(), attInfo, 0);
    if (instances.instanceInformation.numOutputAttributes() == 1) {
        wekaInstances.setClassIndex(instances.classIndex());
    } else {
        // assign a class index to a multi-label instance for compatibility reasons:
        // use the last output attribute
        wekaInstances.setClassIndex(instances.instanceInformation.numOutputAttributes() - 1);
    }
    return wekaInstances;
}