Usage examples for weka.core.Instances.add(Instance)
@Override public boolean add(Instance instance)
From source file:shawn.gcbi.com.kea.main.KEAKeyphraseExtractor.java
License:Open Source License
/** * Builds the model from the files//w w w.j ava 2 s . c o m */ public void extractKeyphrases(Hashtable stems) throws Exception { Vector stats = new Vector(); // Check whether there is actually any data // = if there any files in the directory if (stems.size() == 0) { throw new Exception("Couldn't find any data!"); } m_KEAFilter.setNumPhrases(m_numPhrases); m_KEAFilter.setVocabulary(m_vocabulary); m_KEAFilter.setVocabularyFormat(m_vocabularyFormat); m_KEAFilter.setDocumentLanguage(getDocumentLanguage()); m_KEAFilter.setStemmer(m_Stemmer); m_KEAFilter.setStopwords(m_Stopwords); if (getVocabulary().equals("none")) { m_KEAFilter.m_NODEfeature = false; } else { m_KEAFilter.loadThesaurus(m_Stemmer, m_Stopwords); } FastVector atts = new FastVector(3); atts.addElement(new Attribute("doc", (FastVector) null)); atts.addElement(new Attribute("keyphrases", (FastVector) null)); atts.addElement(new Attribute("filename", (String) null)); Instances data = new Instances("keyphrase_training_data", atts, 0); if (m_KEAFilter.m_Dictionary == null) { buildGlobalDictionaries(stems); } System.err.println("-- Extracting keyphrases... 
"); // Extract keyphrases Enumeration elem = stems.keys(); // Enumeration over all files in the directory (now in the hash): while (elem.hasMoreElements()) { String str = (String) elem.nextElement(); double[] newInst = new double[2]; try { File txt = new File(m_dirName + "/" + str + ".txt"); InputStreamReader is; if (!m_encoding.equals("default")) { is = new InputStreamReader(new FileInputStream(txt), m_encoding); } else { is = new InputStreamReader(new FileInputStream(txt)); } StringBuffer txtStr = new StringBuffer(); int c; while ((c = is.read()) != -1) { txtStr.append((char) c); } newInst[0] = (double) data.attribute(0).addStringValue(txtStr.toString()); } catch (Exception e) { if (m_debug) { System.err.println("Can't read document " + str + ".txt"); } newInst[0] = Instance.missingValue(); } try { File key = new File(m_dirName + "/" + str + ".key"); InputStreamReader is; if (!m_encoding.equals("default")) { is = new InputStreamReader(new FileInputStream(key), m_encoding); } else { is = new InputStreamReader(new FileInputStream(key)); } StringBuffer keyStr = new StringBuffer(); int c; // keyStr = keyphrases in the str.key file // Kea assumes, that these keyphrases were assigned by the author // and evaluates extracted keyphrases againse these while ((c = is.read()) != -1) { keyStr.append((char) c); } newInst[1] = (double) data.attribute(1).addStringValue(keyStr.toString()); } catch (Exception e) { if (m_debug) { System.err.println("No existing keyphrases for stem " + str + "."); } newInst[1] = Instance.missingValue(); } data.add(new Instance(1.0, newInst)); m_KEAFilter.input(data.instance(0)); data = data.stringFreeStructure(); if (m_debug) { System.err.println("-- Document: " + str); } Instance[] topRankedInstances = new Instance[m_numPhrases]; Instance inst; // Iterating over all extracted keyphrases (inst) while ((inst = m_KEAFilter.output()) != null) { int index = (int) inst.value(m_KEAFilter.getRankIndex()) - 1; if (index < m_numPhrases) { 
topRankedInstances[index] = inst; } } if (m_debug) { System.err.println("-- Keyphrases and feature values:"); } FileOutputStream out = null; PrintWriter printer = null; File key = new File(m_dirName + "/" + str + ".key"); if (!key.exists()) { out = new FileOutputStream(m_dirName + "/" + str + ".key"); if (!m_encoding.equals("default")) { printer = new PrintWriter(new OutputStreamWriter(out, m_encoding)); } else { printer = new PrintWriter(out); } } double numExtracted = 0, numCorrect = 0; for (int i = 0; i < m_numPhrases; i++) { if (topRankedInstances[i] != null) { if (!topRankedInstances[i].isMissing(topRankedInstances[i].numAttributes() - 1)) { numExtracted += 1.0; } if ((int) topRankedInstances[i].value(topRankedInstances[i].numAttributes() - 1) == 1) { numCorrect += 1.0; } if (printer != null) { printer.print(topRankedInstances[i].stringValue(m_KEAFilter.getUnstemmedPhraseIndex())); System.out.print(topRankedInstances[i].stringValue(m_KEAFilter.getUnstemmedPhraseIndex())); System.out.println("\t" + Utils .doubleToString(topRankedInstances[i].value(m_KEAFilter.getProbabilityIndex()), 4)); if (m_AdditionalInfo) { printer.print("\t"); printer.print(topRankedInstances[i].stringValue(m_KEAFilter.getStemmedPhraseIndex())); printer.print("\t"); printer.print(Utils.doubleToString( topRankedInstances[i].value(m_KEAFilter.getProbabilityIndex()), 4)); } printer.println(); } if (m_debug) { System.err.println(topRankedInstances[i]); } } } if (numExtracted > 0) { if (m_debug) { System.err.println("-- " + numCorrect + " correct"); } stats.addElement(new Double(numCorrect)); } if (printer != null) { printer.flush(); printer.close(); out.close(); } } double[] st = new double[stats.size()]; for (int i = 0; i < stats.size(); i++) { st[i] = ((Double) stats.elementAt(i)).doubleValue(); } double avg = Utils.mean(st); double stdDev = Math.sqrt(Utils.variance(st)); System.err.println("Avg. 
number of matching keyphrases compared to existing ones : " + Utils.doubleToString(avg, 2) + " +/- " + Utils.doubleToString(stdDev, 2)); System.err.println("Based on " + stats.size() + " documents"); // m_KEAFilter.batchFinished(); }
From source file:sim.app.ubik.behaviors.sharedservices.EMClustering.java
License:Open Source License
/**
 * Training data: creates one instance, holding that person's preferences, for every
 * person who is currently using a service.
 *
 * @return the data set named "usersProfile" with one instance per user of each service
 */
private Instances generateTrainingData() {
    final Instances trainingSet = new Instances("usersProfile", attributes, 1000);
    for (SharedService service : slist) {
        for (UserInterface user : service.getUsers()) {
            trainingSet.add(getInstance(user));
        }
    }
    return trainingSet;
}
From source file:sirius.misc.zscore.ZscoreTableModel.java
License:Open Source License
public void siriusCorrelationFiltering(final double stdDevDist, final double maxOverlapPercent, final boolean includeNegatives) { Thread thread = new Thread() { public void run() { Instances instances = ZscoreTableModel.this.posInstances; if (includeNegatives) for (int x = 0; x < ZscoreTableModel.this.negInstances.numInstances(); x++) instances.add(ZscoreTableModel.this.negInstances.instance(x)); //for now, i will ignore the sign: as in, i would care only about the absolute change of stddev (ie. |stddev|) //use an O(a*a*n) algorithm where n = num of instances and a = num of attributes MessageDialog m = new MessageDialog(null, "Progress", "0%"); for (int a = 0; a < instances.numAttributes(); a++) { int indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index(); if (instances.attribute(indexA).isNumeric() == false) continue; //for each attribute pair, check for the num of overlap percent double attibuteAStddev = instances.attributeStats(indexA).numericStats.stdDev; for (int b = a + 1; b < instances.numAttributes();) { m.update(a + "/" + instances.numAttributes()); int indexB = instances.attribute(ZscoreTableModel.this.scoreList.get(b).getName()).index(); if (instances.attribute(indexB).isNumeric() == false) { b++;//from w w w . ja va2 s . c om continue; } int numOfOverlap = 0; double attibuteBStddev = instances.attributeStats(indexB).numericStats.stdDev; for (int x = 0; x < instances.numInstances() - 1; x++) { //how do i consider an overlap? 
//absolute difference from the previous instance is same in stddev double attributeADifference = Math.abs( ((instances.instance(x).value(indexA) - instances.instance(x + 1).value(indexA)) / attibuteAStddev)); double attributeBDifference = Math.abs( ((instances.instance(x).value(indexB) - instances.instance(x + 1).value(indexB)) / attibuteBStddev)); if (Math.abs(attributeADifference - attributeBDifference) < stdDevDist) numOfOverlap++; } double overlapPercent = (numOfOverlap * 100) / (instances.numInstances() - 1); if (overlapPercent > maxOverlapPercent) { ZscoreTableModel.this.posInstances.deleteAttributeAt(indexB); ZscoreTableModel.this.negInstances.deleteAttributeAt(indexB); ZscoreTableModel.this.scoreList.remove(b); indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index(); } else b++; } } m.dispose(); ZscoreTableModel.this.label.setText("" + instances.numAttributes()); //compute(ZscoreTableModel.this.posInstances,ZscoreTableModel.this.negInstances); ZscoreTableModel.this.fireTableDataChanged(); } }; thread.setPriority(Thread.MIN_PRIORITY); // UI has most priority thread.start(); }
From source file:sirius.misc.zscore.ZscoreTableModel.java
License:Open Source License
public void pearsonCorrelationFiltering(final double score, final boolean includeNegatives) { Thread thread = new Thread() { public void run() { Instances instances = ZscoreTableModel.this.posInstances; if (includeNegatives) for (int x = 0; x < ZscoreTableModel.this.negInstances.numInstances(); x++) instances.add(ZscoreTableModel.this.negInstances.instance(x)); //for now, i will ignore the sign: as in, i would care only about the absolute change of stddev (ie. |stddev|) //use an O(a*a*n) algorithm where n = num of instances and a = num of attributes MessageDialog m = new MessageDialog(null, "Progress", "0%"); for (int a = 0; a < instances.numAttributes(); a++) { int indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index(); if (instances.attribute(indexA).isNumeric() == false) continue; //for each attribute pair, check for the num of overlap percent double attributeAStddev = instances.attributeStats(indexA).numericStats.stdDev; double attributeAMean = instances.attributeStats(indexA).numericStats.mean; for (int b = a + 1; b < instances.numAttributes();) { m.update(a + "/" + instances.numAttributes()); int indexB = instances.attribute(ZscoreTableModel.this.scoreList.get(b).getName()).index(); if (instances.attribute(indexB).isNumeric() == false) { b++;//from w w w .j ava2 s. 
c om continue; } double attributeBStddev = instances.attributeStats(indexB).numericStats.stdDev; double attributeBMean = instances.attributeStats(indexB).numericStats.mean; double nominator = 0.0; for (int x = 0; x < instances.numInstances(); x++) { nominator += ((instances.instance(x).value(indexA) - attributeAMean) * (instances.instance(x).value(indexB) - attributeBMean)); } double pScore = Math.abs( nominator / ((instances.numInstances() - 1) * attributeAStddev * attributeBStddev)); if (pScore > score) { ZscoreTableModel.this.posInstances.deleteAttributeAt(indexB); ZscoreTableModel.this.negInstances.deleteAttributeAt(indexB); ZscoreTableModel.this.scoreList.remove(b); indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index(); } else b++; } } m.dispose(); ZscoreTableModel.this.label.setText("" + instances.numAttributes()); //compute(ZscoreTableModel.this.posInstances,ZscoreTableModel.this.negInstances); ZscoreTableModel.this.fireTableDataChanged(); } }; thread.setPriority(Thread.MIN_PRIORITY); // UI has most priority thread.start(); }
From source file:sirius.misc.zscore.ZscoreTableModel.java
License:Open Source License
public void compute(final Instances posInstances, final Instances negInstances) { if (posInstances == null || negInstances == null) { JOptionPane.showMessageDialog(null, "Please load file before computing.", "Error", JOptionPane.ERROR_MESSAGE); return;// w w w. j av a 2s. c o m } if (posInstances.numAttributes() != negInstances.numAttributes()) { JOptionPane.showMessageDialog(null, "Number of attributes between the two files does not tally.", "Error", JOptionPane.ERROR_MESSAGE); return; } this.scoreList = new ArrayList<Scores>(); this.posInstances = posInstances; this.negInstances = negInstances; Thread thread = new Thread() { public void run() { MessageDialog m = new MessageDialog(null, "Progress", "0%"); int percentCount = posInstances.numAttributes() / 100; if (percentCount == 0) percentCount = 1; for (int x = 0; x < posInstances.numAttributes(); x++) { if (x % percentCount == 0) m.update(x / percentCount + "%"); if (posInstances.attribute(x).isNumeric() == false) { ZscoreTableModel.this.scoreList.add(new Scores(posInstances.attribute(x).name())); continue; } String name = posInstances.attribute(x).name(); double posMean = posInstances.attributeStats(x).numericStats.mean; double posStdDev = posInstances.attributeStats(x).numericStats.stdDev; double negMean = negInstances.attributeStats(x).numericStats.mean; double negStdDev = negInstances.attributeStats(x).numericStats.stdDev; if (negStdDev == 0) negStdDev = 0.01; double totalZScore = 0.0; int numGTZScore0_5 = 0; int numGTZScore1 = 0; int numGTZScore2 = 0; int numGTZScore3 = 0; for (int y = 0; y < posInstances.numInstances(); y++) { double zScore = Math.abs(((posInstances.instance(y).value(x) - negMean) / negStdDev)); totalZScore += zScore; if (zScore > 0.5) numGTZScore0_5++; if (zScore > 1) numGTZScore1++; if (zScore > 2) numGTZScore2++; if (zScore > 3) numGTZScore3++; } double meanZScore = totalZScore / posInstances.numInstances(); double percentGTZScore0_5 = (numGTZScore0_5 * 100) / 
posInstances.numInstances(); double percentGTZScore1 = (numGTZScore1 * 100) / posInstances.numInstances(); double percentGTZScore2 = (numGTZScore2 * 100) / posInstances.numInstances(); double percentGTZScore3 = (numGTZScore3 * 100) / posInstances.numInstances(); ZscoreTableModel.this.scoreList .add(new Scores(name, posMean, posStdDev, negMean, negStdDev, meanZScore, percentGTZScore0_5, percentGTZScore1, percentGTZScore2, percentGTZScore3, -1)); } try { Instances instances = new Instances(posInstances); for (int x = 0; x < negInstances.numInstances(); x++) instances.add(negInstances.instance(x)); instances.setClassIndex(instances.numAttributes() - 1); //Evaluate the attributes individually and obtain the gainRatio GainRatioAttributeEval gainRatio = new GainRatioAttributeEval(); if (instances.numAttributes() > 0) { gainRatio.buildEvaluator(instances); } for (int x = 0; x < (instances.numAttributes() - 1); x++) { ZscoreTableModel.this.scoreList.get(x).setGainRatio(gainRatio.evaluateAttribute(x)); } } catch (Exception e) { e.printStackTrace(); } Collections.sort(ZscoreTableModel.this.scoreList, new SortByMeanZScore()); fireTableDataChanged(); m.dispose(); ZscoreTableModel.this.label.setText("" + ZscoreTableModel.this.scoreList.size()); } }; thread.setPriority(Thread.MIN_PRIORITY); // UI has most priority thread.start(); }
From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java
License:Open Source License
public static Classifier xValidateClassifierOneWithNoLocationIndex(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, String classifierName, String[] classifierOptions, int folds, GraphPane myGraph, ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier, GeneticAlgorithmDialog gaDialog, GASettingsInterface gaSettings, int randomNumberForClassifier) { try {// w ww.jav a 2 s . c o m StatusPane statusPane = applicationData.getStatusPane(); if (statusPane == null) System.out.println("Null"); //else // stats long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; Classifier tempClassifier = (Classifier) Classifier.forName(classifierName, classifierOptions); Instances inst = null; if (applicationData.getDataset1Instances() != null) { inst = new Instances(applicationData.getDataset1Instances()); inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); } //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy long trainTimeStart = 0, trainTimeElapsed = 0; Classifier classifierOne = (Classifier) Classifier.forName(classifierName, classifierOptions); if (statusPane != null) statusPane.setText("Training Classifier One... May take a while... 
Please wait..."); //Record Start Time trainTimeStart = System.currentTimeMillis(); if (outputClassifier && gaSettings == null) classifierOne.buildClassifier(inst); //Record Total Time used to build classifier one trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; //Training Done ] if (classifierResults != null) { classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", folds + " fold cross-validation"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); } int startRandomNumber; if (gaSettings != null) startRandomNumber = gaSettings.getRandomNumber(); else startRandomNumber = 1; String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_" + randomNumberForClassifier + "_" + startRandomNumber + ".scores"; BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename)); Instances foldTrainingInstance = null; Instances foldTestingInstance = null; int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField(); int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField(); int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField(); int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField(); Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory()); FastaFormat fastaFormat; String header[] = null; String data[] = null; if (inst != null) { header = 
new String[inst.numInstances()]; data = new String[inst.numInstances()]; } List<FastaFormat> allPosList = new ArrayList<FastaFormat>(); List<FastaFormat> allNegList = new ArrayList<FastaFormat>(); int counter = 0; while ((fastaFormat = fastaFile.nextSequence("pos")) != null) { if (inst != null) { header[counter] = fastaFormat.getHeader(); data[counter] = fastaFormat.getSequence(); counter++; } allPosList.add(fastaFormat); } while ((fastaFormat = fastaFile.nextSequence("neg")) != null) { if (inst != null) { header[counter] = fastaFormat.getHeader(); data[counter] = fastaFormat.getSequence(); counter++; } allNegList.add(fastaFormat); } //run x folds for (int x = 0; x < folds; x++) { if (applicationData.terminateThread == true) { if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); outputCrossValidation.close(); return classifierOne; } if (statusPane != null) statusPane.setPrefix("Running Fold " + (x + 1) + ": "); if (inst != null) { foldTrainingInstance = new Instances(inst, 0); foldTestingInstance = new Instances(inst, 0); } List<FastaFormat> trainPosList = new ArrayList<FastaFormat>(); List<FastaFormat> trainNegList = new ArrayList<FastaFormat>(); List<FastaFormat> testPosList = new ArrayList<FastaFormat>(); List<FastaFormat> testNegList = new ArrayList<FastaFormat>(); //split data into training and testing //This is for normal run int testInstanceIndex[] = null; if (inst != null) testInstanceIndex = new int[(inst.numInstances() / folds) + 1]; if (gaSettings == null) { int testIndexCounter = 0; for (int y = 0; y < inst.numInstances(); y++) { if ((y % folds) == x) {//this instance is for testing foldTestingInstance.add(inst.instance(y)); testInstanceIndex[testIndexCounter] = y; testIndexCounter++; } else {//this instance is for training foldTrainingInstance.add(inst.instance(y)); } } } else { //This is for GA run for (int y = 0; y < allPosList.size(); y++) { if ((y % folds) == x) {//this instance is for testing 
testPosList.add(allPosList.get(y)); } else {//this instance is for training trainPosList.add(allPosList.get(y)); } } for (int y = 0; y < allNegList.size(); y++) { if ((y % folds) == x) {//this instance is for testing testNegList.add(allNegList.get(y)); } else {//this instance is for training trainNegList.add(allNegList.get(y)); } } if (gaDialog != null) foldTrainingInstance = runDAandLoadResult(applicationData, gaDialog, trainPosList, trainNegList, x + 1, startRandomNumber); else foldTrainingInstance = runDAandLoadResult(applicationData, gaSettings, trainPosList, trainNegList, x + 1, startRandomNumber); foldTrainingInstance.setClassIndex(foldTrainingInstance.numAttributes() - 1); //Reading and Storing the featureList ArrayList<Feature> featureList = new ArrayList<Feature>(); for (int y = 0; y < foldTrainingInstance.numAttributes() - 1; y++) { //-1 because class attribute must be ignored featureList.add(Feature.levelOneClassifierPane(foldTrainingInstance.attribute(y).name())); } String outputFilename; if (gaDialog != null) outputFilename = gaDialog.getOutputLocation().getText() + File.separator + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1) + ".arff"; else outputFilename = gaSettings.getOutputLocation() + File.separator + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1) + ".arff"; new GenerateFeatures(applicationData, featureList, testPosList, testNegList, outputFilename); foldTestingInstance = new Instances(new FileReader(outputFilename)); foldTestingInstance.setClassIndex(foldTestingInstance.numAttributes() - 1); } Classifier foldClassifier = tempClassifier; foldClassifier.buildClassifier(foldTrainingInstance); for (int y = 0; y < foldTestingInstance.numInstances(); y++) { if (applicationData.terminateThread == true) { if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); outputCrossValidation.close(); return classifierOne; } double[] results = 
foldClassifier.distributionForInstance(foldTestingInstance.instance(y)); int classIndex = foldTestingInstance.instance(y).classIndex(); String classValue = foldTestingInstance.instance(y).toString(classIndex); if (inst != null) { outputCrossValidation.write(header[testInstanceIndex[y]]); outputCrossValidation.newLine(); outputCrossValidation.write(data[testInstanceIndex[y]]); outputCrossValidation.newLine(); } else { if (y < testPosList.size()) { outputCrossValidation.write(testPosList.get(y).getHeader()); outputCrossValidation.newLine(); outputCrossValidation.write(testPosList.get(y).getSequence()); outputCrossValidation.newLine(); } else { outputCrossValidation.write(testNegList.get(y - testPosList.size()).getHeader()); outputCrossValidation.newLine(); outputCrossValidation.write(testNegList.get(y - testPosList.size()).getSequence()); outputCrossValidation.newLine(); } } if (classValue.equals("pos")) outputCrossValidation.write("pos,0=" + results[0]); else if (classValue.equals("neg")) outputCrossValidation.write("neg,0=" + results[0]); else { outputCrossValidation.close(); throw new Error("Invalid Class Type!"); } outputCrossValidation.newLine(); outputCrossValidation.flush(); } } outputCrossValidation.close(); PredictionStats classifierOneStatsOnXValidation = new PredictionStats(classifierOneFilename, range, threshold); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; if (classifierResults != null) { classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierOneStatsOnXValidation.updateDisplay(classifierResults, classifierOneDisplayTextArea, true); } applicationData.setClassifierOneStats(classifierOneStatsOnXValidation); if (myGraph != null) myGraph.setMyStats(classifierOneStatsOnXValidation); if (statusPane != null) statusPane.setText("Done!"); //Note that this will be 
null if GA is run though maybe it is better if i run all sequence with GA and then build the classifier but this would be a waste of time return classifierOne; } catch (Exception e) { e.printStackTrace(); JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE); return null; } }
From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java
License:Open Source License
public static Object jackKnifeClassifierOneWithNoLocationIndex(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, GenericObjectEditor m_ClassifierEditor, double ratio, GraphPane myGraph, ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier, String classifierName, String[] classifierOptions, boolean returnClassifier, int randomNumberForClassifier) { try {// ww w.j a v a 2s .c om StatusPane statusPane = applicationData.getStatusPane(); long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; Classifier tempClassifier; if (m_ClassifierEditor != null) tempClassifier = (Classifier) m_ClassifierEditor.getValue(); else tempClassifier = Classifier.forName(classifierName, classifierOptions); //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files //split the instances into positive and negative Instances posInst = new Instances(applicationData.getDataset1Instances()); posInst.setClassIndex(posInst.numAttributes() - 1); for (int x = 0; x < posInst.numInstances();) if (posInst.instance(x).stringValue(posInst.numAttributes() - 1).equalsIgnoreCase("pos")) x++; else posInst.delete(x); posInst.deleteAttributeType(Attribute.STRING); Instances negInst = new Instances(applicationData.getDataset1Instances()); negInst.setClassIndex(negInst.numAttributes() - 1); for (int x = 0; x < negInst.numInstances();) if (negInst.instance(x).stringValue(negInst.numAttributes() - 1).equalsIgnoreCase("neg")) x++; else negInst.delete(x); negInst.deleteAttributeType(Attribute.STRING); //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy long trainTimeStart = 0, trainTimeElapsed = 0; if (statusPane != null) statusPane.setText("Training Classifier One... May take a while... 
Please wait..."); //Record Start Time trainTimeStart = System.currentTimeMillis(); Instances fullInst = new Instances(applicationData.getDataset1Instances()); fullInst.setClassIndex(fullInst.numAttributes() - 1); Classifier classifierOne; if (m_ClassifierEditor != null) classifierOne = (Classifier) m_ClassifierEditor.getValue(); else classifierOne = Classifier.forName(classifierName, classifierOptions); if (outputClassifier) classifierOne.buildClassifier(fullInst); //Record Total Time used to build classifier one trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; //Training Done String tclassifierName; if (m_ClassifierEditor != null) tclassifierName = m_ClassifierEditor.getValue().getClass().getName(); else tclassifierName = classifierName; if (classifierResults != null) { classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", tclassifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", " Jack Knife Validation"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); } String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_" + randomNumberForClassifier + ".scores"; BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename)); //Instances foldTrainingInstance; //Instances foldTestingInstance; int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField(); int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField(); int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField(); int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField(); Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); FastaFileManipulation fastaFile = new 
FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory()); FastaFormat fastaFormat; String header[] = new String[fullInst.numInstances()]; String data[] = new String[fullInst.numInstances()]; int counter = 0; while ((fastaFormat = fastaFile.nextSequence("pos")) != null) { header[counter] = fastaFormat.getHeader(); data[counter] = fastaFormat.getSequence(); counter++; } while ((fastaFormat = fastaFile.nextSequence("neg")) != null) { header[counter] = fastaFormat.getHeader(); data[counter] = fastaFormat.getSequence(); counter++; } //run jack knife validation for (int x = 0; x < fullInst.numInstances(); x++) { if (applicationData.terminateThread == true) { if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); outputCrossValidation.close(); return classifierOne; } if (statusPane != null) statusPane.setText("Running " + (x + 1) + " / " + fullInst.numInstances()); Instances trainPosInst = new Instances(posInst); Instances trainNegInst = new Instances(negInst); Instance testInst; //split data into training and testing if (x < trainPosInst.numInstances()) { testInst = posInst.instance(x); trainPosInst.delete(x); } else { testInst = negInst.instance(x - posInst.numInstances()); trainNegInst.delete(x - posInst.numInstances()); } Instances trainInstances; if (trainPosInst.numInstances() < trainNegInst.numInstances()) { trainInstances = new Instances(trainPosInst); int max = (int) (ratio * trainPosInst.numInstances()); if (ratio == -1) max = trainNegInst.numInstances(); Random rand = new Random(1); for (int y = 0; y < trainNegInst.numInstances() && y < max; y++) { int index = rand.nextInt(trainNegInst.numInstances()); trainInstances.add(trainNegInst.instance(index)); trainNegInst.delete(index); } } else { trainInstances = new Instances(trainNegInst); int max = (int) (ratio * 
trainNegInst.numInstances()); if (ratio == -1) max = trainPosInst.numInstances(); Random rand = new Random(1); for (int y = 0; y < trainPosInst.numInstances() && y < max; y++) { int index = rand.nextInt(trainPosInst.numInstances()); trainInstances.add(trainPosInst.instance(index)); trainPosInst.delete(index); } } Classifier foldClassifier = tempClassifier; foldClassifier.buildClassifier(trainInstances); double[] results = foldClassifier.distributionForInstance(testInst); int classIndex = testInst.classIndex(); String classValue = testInst.toString(classIndex); outputCrossValidation.write(header[x]); outputCrossValidation.newLine(); outputCrossValidation.write(data[x]); outputCrossValidation.newLine(); if (classValue.equals("pos")) outputCrossValidation.write("pos,0=" + results[0]); else if (classValue.equals("neg")) outputCrossValidation.write("neg,0=" + results[0]); else { outputCrossValidation.close(); throw new Error("Invalid Class Type!"); } outputCrossValidation.newLine(); outputCrossValidation.flush(); } outputCrossValidation.close(); PredictionStats classifierOneStatsOnJackKnife = new PredictionStats(classifierOneFilename, range, threshold); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; if (classifierResults != null) classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); //if(classifierOneDisplayTextArea != null) classifierOneStatsOnJackKnife.updateDisplay(classifierResults, classifierOneDisplayTextArea, true); applicationData.setClassifierOneStats(classifierOneStatsOnJackKnife); if (myGraph != null) myGraph.setMyStats(classifierOneStatsOnJackKnife); if (statusPane != null) statusPane.setText("Done!"); if (returnClassifier) return classifierOne; else return classifierOneStatsOnJackKnife; } catch (Exception e) { e.printStackTrace(); JOptionPane.showMessageDialog(parent, 
e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE); return null; } }
From source file:smo2.SMO.java
License:Open Source License
/** * Method for building the classifier. Implements a one-against-one wrapper * for multi-class problems./* www .ja v a2s. c o m*/ * * @param insts * the set of training instances * @exception Exception * if the classifier can't be built successfully */ public void buildClassifier(Instances insts) throws Exception { if (!m_checksTurnedOff) { if (insts.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } if (insts.classAttribute().isNumeric()) { throw new UnsupportedClassTypeException( "mySMO can't handle a numeric class! Use" + "SMOreg for performing regression."); } insts = new Instances(insts); insts.deleteWithMissingClass(); if (insts.numInstances() == 0) { throw new Exception("No training instances without a missing class!"); } /* * Removes all the instances with weight equal to 0. MUST be done * since condition (8) of Keerthi's paper is made with the assertion * Ci > 0 (See equation (3a). */ Instances data = new Instances(insts, insts.numInstances()); for (int i = 0; i < insts.numInstances(); i++) { if (insts.instance(i).weight() > 0) data.add(insts.instance(i)); } if (data.numInstances() == 0) { throw new Exception("No training instances left after removing " + "instance with either a weight null or a missing class!"); } insts = data; } m_onlyNumeric = true; if (!m_checksTurnedOff) { for (int i = 0; i < insts.numAttributes(); i++) { if (i != insts.classIndex()) { if (!insts.attribute(i).isNumeric()) { m_onlyNumeric = false; break; } } } } if (!m_checksTurnedOff) { m_Missing = new ReplaceMissingValues(); m_Missing.setInputFormat(insts); insts = Filter.useFilter(insts, m_Missing); } else { m_Missing = null; } if (!m_onlyNumeric) { m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(insts); insts = Filter.useFilter(insts, m_NominalToBinary); } else { m_NominalToBinary = null; } if (m_filterType == FILTER_STANDARDIZE) { m_Filter = new Standardize(); 
m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else if (m_filterType == FILTER_NORMALIZE) { m_Filter = new Normalize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } else { m_Filter = null; } m_classIndex = insts.classIndex(); m_classAttribute = insts.classAttribute(); // Generate subsets representing each class Instances[] subsets = new Instances[insts.numClasses()]; for (int i = 0; i < insts.numClasses(); i++) { subsets[i] = new Instances(insts, insts.numInstances()); } for (int j = 0; j < insts.numInstances(); j++) { Instance inst = insts.instance(j); subsets[(int) inst.classValue()].add(inst); } for (int i = 0; i < insts.numClasses(); i++) { subsets[i].compactify(); } // Build the binary classifiers Random rand = new Random(m_randomSeed); m_classifiers = new BinarymySMO[insts.numClasses()][insts.numClasses()]; for (int i = 0; i < insts.numClasses(); i++) { for (int j = i + 1; j < insts.numClasses(); j++) { m_classifiers[i][j] = new BinarymySMO(); Instances data = new Instances(insts, insts.numInstances()); for (int k = 0; k < subsets[i].numInstances(); k++) { data.add(subsets[i].instance(k)); } for (int k = 0; k < subsets[j].numInstances(); k++) { data.add(subsets[j].instance(k)); } data.compactify(); data.randomize(rand); m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels, m_numFolds, m_randomSeed); } } }
From source file:svmal.SVMStrategy.java
public static Instances InstancesToInstances2(Instances insts) { Instances result = new Instances(insts, 0, 0); for (int i = 0; i < insts.numInstances(); i++) { Instance orig = insts.get(i);//from w w w . j a va 2 s .c o m Instance2 inst2 = new Instance2(orig.weight(), orig.toDoubleArray()); inst2.setDataset(result); result.add(inst2); } return result; }
From source file:svmal.SVMStrategy.java
/**
 * Builds a new dataset from an array of patterns. The header is taken from
 * the dataset of the first pattern, and every pattern is re-created as an
 * {@code Instance2} that keeps its weight, attribute values and id.
 *
 * @param patts the patterns to convert; must contain at least one element because
 *        {@code patts[0]} is read unconditionally
 * @return a new dataset holding one {@code Instance2} per pattern, in array order
 */
public static Instances PatternsToInstances2(Pattern[] patts) {
    // Header-only copy of the first pattern's dataset.
    Instances converted = new Instances(patts[0].dataset(), 0, 0);
    for (int idx = 0; idx < patts.length; idx++) {
        Pattern source = patts[idx];
        Instance2 wrapped = new Instance2(source.weight(), source.toDoubleArray());
        wrapped.setIndex(source.id());
        wrapped.setDataset(converted);
        converted.add(wrapped);
    }
    return converted;
}