List of usage examples for weka.classifiers Classifier buildClassifier
public abstract void buildClassifier(Instances data) throws Exception;
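Before the project-specific examples below, here is a minimal, self-contained sketch of the typical call sequence around buildClassifier, assuming stock Weka 3.x, a hypothetical iris.arff file on disk, and J48 chosen purely as an example classifier:

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class BuildClassifierSketch {
    public static void main(String[] args) throws Exception {
        // Load a dataset; the file name is only an illustration.
        Instances data = DataSource.read("iris.arff");
        // buildClassifier() requires the class attribute to be set beforehand.
        data.setClassIndex(data.numAttributes() - 1);

        Classifier cls = new J48();   // any weka.classifiers.Classifier works here
        cls.buildClassifier(data);    // the abstract method documented above

        // Once trained, the model can score unseen instances.
        double predicted = cls.classifyInstance(data.instance(0));
        System.out.println("Predicted: " + data.classAttribute().value((int) predicted));
    }
}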
From source file:qa.experiment.ProcessFeatureVector.java
public String trainAndPredict(String[] processNames, String question) throws Exception {
    FastVector fvWekaAttribute = generateWEKAFeatureVector(processNames);
    Instances trainingSet = new Instances("Rel", fvWekaAttribute, bowFeature.size() + 1);
    trainingSet.setClassIndex(bowFeature.size());
    // Collect training instances from every stored process whose name fuzzily matches.
    for (int i = 0; i < arrProcessFeature.size(); i++) {
        String[] names = arrProcessFeature.get(i).getProcessName().split("\\|");
        int sim = isNameFuzzyMatch(processNames, names);
        if (sim != -1) {
            ArrayList<String> featureVector = arrProcessFeature.get(i).getFeatureVectors();
            for (int j = 0; j < featureVector.size(); j++) {
                Instance trainInstance = new Instance(bowFeature.size() + 1);
                String[] attrValues = featureVector.get(j).split("\t");
                for (int k = 0; k < bowFeature.size(); k++) {
                    trainInstance.setValue((Attribute) fvWekaAttribute.elementAt(k),
                            Integer.parseInt(attrValues[k]));
                }
                trainInstance.setValue((Attribute) fvWekaAttribute.elementAt(bowFeature.size()),
                        processNames[sim]);
                trainingSet.add(trainInstance);
            }
        }
    }
    // Train a Naive Bayes model on the bag-of-words training set.
    Classifier cl = new NaiveBayes();
    cl.buildClassifier(trainingSet);
    // Build one instance for the incoming question and classify it.
    Instance inst = new Instance(bowFeature.size() + 1);
    for (int j = 0; j < bowFeature.size(); j++) {
        List<String> tokens = slem.tokenize(question);
        String[] tokArr = tokens.toArray(new String[tokens.size()]);
        int freq = getFrequency(bowFeature.get(j), tokArr);
        inst.setValue((Attribute) fvWekaAttribute.elementAt(j), freq);
    }
    inst.setDataset(trainingSet);
    int idxMax = ArrUtil.getIdxMax(cl.distributionForInstance(inst));
    return processNames[idxMax];
}
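The method above picks the answer by taking the argmax of distributionForInstance and indexing back into processNames, which only works while the class attribute's values are stored in that same order. As a hedged alternative, a hypothetical helper that reads the label straight from the class attribute, using the same pre-3.7 Instance API and the usual weka.core/weka.classifiers imports:

// Hypothetical helper: most probable class label for 'inst' according to 'cl'.
static String predictLabel(Classifier cl, Instances header, Instance inst) throws Exception {
    inst.setDataset(header);                          // the instance must know its attributes
    double[] dist = cl.distributionForInstance(inst); // one probability per class value
    int idxMax = 0;
    for (int i = 1; i < dist.length; i++) {
        if (dist[i] > dist[idxMax]) {
            idxMax = i;
        }
    }
    return header.classAttribute().value(idxMax);     // label as declared in the header
}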
From source file:se.de.hu_berlin.informatik.faultlocalizer.machinelearn.WekaFaultLocalizer.java
License:Open Source License
/**
 * Builds and trains a classifier.
 *
 * @param name
 *            FQCN of the classifier
 * @param options
 *            options to pass to the classifier
 * @param trainingSet
 *            training set to build the classifier with
 * @return trained classifier
 */
public Classifier buildClassifier(final String name, final String[] options, final Instances trainingSet) {
    try {
        final Classifier classifier = AbstractClassifier.forName(this.classifierName, options);
        classifier.buildClassifier(trainingSet);
        return classifier;
    } catch (final Exception e1) { // NOCS: Weka throws only raw exceptions
        Log.err(this, "Unable to create classifier " + this.classifierName);
        throw new RuntimeException(e1);
    }
}
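A short usage sketch for this name-plus-options style of construction; the class name, option string, and wrapper class are illustrative only, and weka.core.Utils.splitOptions is used to turn a single option string into the String[] that forName expects:

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.Utils;

public class ForNameSketch {
    // Builds a classifier from a fully-qualified class name and an option string.
    public static Classifier build(String className, String optionString, Instances trainingSet)
            throws Exception {
        String[] options = Utils.splitOptions(optionString);      // e.g. "-C 0.25 -M 2" for J48
        Classifier classifier = AbstractClassifier.forName(className, options);
        classifier.buildClassifier(trainingSet);                  // class index must already be set
        return classifier;
    }
}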
From source file:sg.edu.nus.comp.nlp.ims.classifiers.CWekaModelTrainer.java
License:Open Source License
@Override
public Object train(Object p_Lexelt) throws Exception {
    ILexelt lexelt = (ILexelt) p_Lexelt;
    CModelInfo retVal = new CModelInfo();
    retVal.lexelt = lexelt.getID();
    retVal.statistic = lexelt.getStatistic();
    if (((IStatistic) retVal.statistic).getTags().size() <= 1) {
        retVal.model = null;
    } else {
        String classifierName = this.m_ClassifierName;
        String[] args = this.m_Argvs.clone();
        ILexeltWriter lexeltWriter = new CWekaSparseLexeltWriter();
        Instances instances = (Instances) lexeltWriter.getInstances(lexelt);
        Classifier model = null;
        int classIdx = this.m_ClassIndex;
        if (classIdx < 0) {
            classIdx = instances.numAttributes() - 1;
        }
        instances.setClassIndex(classIdx);
        model = Classifier.forName(classifierName, args);
        model.buildClassifier(instances);
        retVal.model = model;
    }
    return retVal;
}
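The train() method above hands the built model back inside a CModelInfo; a common follow-up (not shown in the IMS code itself) is to persist the model to disk. A hedged sketch using Weka's SerializationHelper, with an illustrative file path:

import weka.classifiers.Classifier;
import weka.core.SerializationHelper;

public class ModelIoSketch {
    // Persist a trained classifier and read it back; the path is illustrative only.
    public static Classifier saveAndReload(Classifier model, String path) throws Exception {
        SerializationHelper.write(path, model);
        return (Classifier) SerializationHelper.read(path);
    }
}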
From source file:sirius.clustering.main.ClustererClassificationPane.java
License:Open Source License
private void start() { //Run Classifier if (this.inputDirectoryTextField.getText().length() == 0) { JOptionPane.showMessageDialog(parent, "Please set Input Directory to where the clusterer output are!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return;// w w w. j a v a 2 s.c om } if (m_ClassifierEditor.getValue() == null) { JOptionPane.showMessageDialog(parent, "Please choose Classifier!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return; } if (validateStatsSettings(1) == false) { return; } if (this.clusteringClassificationThread == null) { startButton.setEnabled(false); stopButton.setEnabled(true); tabbedClassifierPane.setSelectedIndex(0); this.clusteringClassificationThread = (new Thread() { public void run() { //Clear the output text area levelOneClassifierOutputTextArea.setText(""); resultsTableModel.reset(); //double threshold = Double.parseDouble(classifierOneThresholdTextField.getText()); //cross-validation int numFolds; if (jackKnifeRadioButton.isSelected()) numFolds = -1; else numFolds = Integer.parseInt(foldsField.getText()); StringTokenizer st = new StringTokenizer(inputDirectoryTextField.getText(), File.separator); String filename = ""; while (st.hasMoreTokens()) { filename = st.nextToken(); } StringTokenizer st2 = new StringTokenizer(filename, "_."); numOfCluster = 0; if (st2.countTokens() >= 2) { st2.nextToken(); String numOfClusterString = st2.nextToken().replaceAll("cluster", ""); try { numOfCluster = Integer.parseInt(numOfClusterString); } catch (NumberFormatException e) { JOptionPane.showMessageDialog(parent, "Please choose the correct file! (Output from Utilize Clusterer)", "ERROR", JOptionPane.ERROR_MESSAGE); } } Classifier template = (Classifier) m_ClassifierEditor.getValue(); for (int x = 0; x <= numOfCluster && clusteringClassificationThread != null; x++) {//Test each cluster try { long totalTimeStart = 0, totalTimeElapsed = 0; totalTimeStart = System.currentTimeMillis(); statusLabel.setText("Reading in cluster" + x + " file.."); String inputFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".arff"); String outputScoreFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); BufferedWriter output = new BufferedWriter(new FileWriter(outputScoreFilename)); Instances inst = new Instances(new FileReader(inputFilename)); //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files inst.setClassIndex(inst.numAttributes() - 1); Random random = new Random(1);//Simply set to 1, shall implement the random seed option later inst.randomize(random); if (inst.attribute(inst.classIndex()).isNominal()) inst.stratify(numFolds); // for timing ClassifierResults classifierResults = new ClassifierResults(false, 0); String classifierName = m_ClassifierEditor.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", inputFilename); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", "NA"); //ArrayList<Double> resultList = new ArrayList<Double>(); if (jackKnifeRadioButton.isSelected() || numFolds > inst.numInstances() - 1) numFolds = inst.numInstances() - 1; for (int fold = 0; fold < numFolds && clusteringClassificationThread != null; fold++) {//Doing cross-validation statusLabel.setText("Cluster: " + x + " - 
Training Fold " + (fold + 1) + ".."); Instances train = inst.trainCV(numFolds, fold, random); Classifier current = null; try { current = Classifier.makeCopy(template); current.buildClassifier(train); Instances test = inst.testCV(numFolds, fold); statusLabel.setText("Cluster: " + x + " - Testing Fold " + (fold + 1) + ".."); for (int jj = 0; jj < test.numInstances(); jj++) { double[] result = current.distributionForInstance(test.instance(jj)); output.write("Cluster: " + x); output.newLine(); output.newLine(); output.write(test.instance(jj).stringValue(test.classAttribute()) + ",0=" + result[0]); output.newLine(); } } catch (Exception ex) { ex.printStackTrace(); statusLabel.setText("Error in cross-validation!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } output.close(); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); double threshold = validateFieldAsThreshold(classifierOneThresholdTextField.getText(), "Threshold Field", classifierOneThresholdTextField); String filename2 = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); PredictionStats classifierStats = new PredictionStats(filename2, 0, threshold); resultsTableModel.add("Cluster " + x, classifierResults, classifierStats); resultsTable.setRowSelectionInterval(x, x); computeStats(numFolds);//compute and display the results } catch (Exception e) { e.printStackTrace(); statusLabel.setText("Error in reading file!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } //end of cluster for loop resultsTableModel.add("Summary - Equal Weightage", null, null); resultsTable.setRowSelectionInterval(numOfCluster + 1, numOfCluster + 1); computeStats(numFolds); resultsTableModel.add("Summary - Weighted Average", null, null); resultsTable.setRowSelectionInterval(numOfCluster + 2, numOfCluster + 2); computeStats(numFolds); if (clusteringClassificationThread != null) statusLabel.setText("Done!"); else statusLabel.setText("Interrupted.."); startButton.setEnabled(true); stopButton.setEnabled(false); if (classifierOne != null) { levelOneClassifierOutputScrollPane.getVerticalScrollBar() .setValue(levelOneClassifierOutputScrollPane.getVerticalScrollBar().getMaximum()); } clusteringClassificationThread = null; } }); this.clusteringClassificationThread.setPriority(Thread.MIN_PRIORITY); this.clusteringClassificationThread.start(); } else { JOptionPane.showMessageDialog(parent, "Cannot start new job as previous job still running. Click stop to terminate previous job", "ERROR", JOptionPane.ERROR_MESSAGE); } }
From source file:sirius.trainer.step4.RunClassifier.java
License:Open Source License
public static Classifier startClassifierOne(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, GenericObjectEditor m_ClassifierEditor, GraphPane myGraph, boolean test, ClassifierResults classifierResults, int range, double threshold) { try {/* w w w .ja v a 2s. c o m*/ StatusPane statusPane = applicationData.getStatusPane(); long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; //Setting up training dataset 1 for classifier one statusPane.setText("Setting up..."); //Load Dataset1 Instances Instances inst = new Instances(applicationData.getDataset1Instances()); inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); applicationData.getDataset1Instances() .setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); // for timing long trainTimeStart = 0, trainTimeElapsed = 0; Classifier classifierOne = (Classifier) m_ClassifierEditor.getValue(); statusPane.setText("Training Classifier One... May take a while... Please wait..."); trainTimeStart = System.currentTimeMillis(); inst.deleteAttributeType(Attribute.STRING); classifierOne.buildClassifier(inst); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; String classifierName = m_ClassifierEditor.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", applicationData.getWorkingDirectory() + File.separator + "Dataset1.arff"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); if (test == false) { statusPane.setText("Classifier One Training Completed...Done..."); return classifierOne; } if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier One Training Completed"); return classifierOne; } //Running classifier one on dataset3 if (statusPane != null) statusPane.setText("Running ClassifierOne on Dataset 3.."); //Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); //Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); int positiveDataset3FromInt = applicationData.getPositiveDataset3FromField(); int positiveDataset3ToInt = applicationData.getPositiveDataset3ToField(); int negativeDataset3FromInt = applicationData.getNegativeDataset3FromField(); int negativeDataset3ToInt = applicationData.getNegativeDataset3ToField(); //Generate the header for ClassifierOne.scores on Dataset3 BufferedWriter dataset3OutputFile = new BufferedWriter(new FileWriter( applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores")); if (m_ClassifierEditor.getValue() instanceof OptionHandler) classifierName += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions()); FastaFileManipulation fastaFile = new FastaFileManipulation( applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(), positiveDataset3FromInt, positiveDataset3ToInt, negativeDataset3FromInt, negativeDataset3ToInt, applicationData.getWorkingDirectory()); //Reading and Storing the featureList ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } //Reading the 
fastaFile int lineCounter = 0; String _class = "pos"; int totalDataset3PositiveInstances = positiveDataset3ToInt - positiveDataset3FromInt + 1; FastaFormat fastaFormat; while ((fastaFormat = fastaFile.nextSequence(_class)) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier One Training Completed"); dataset3OutputFile.close(); return classifierOne; } lineCounter++;//Putting it here will mean if lineCounter is x then line == sequence x dataset3OutputFile.write(fastaFormat.getHeader()); dataset3OutputFile.newLine(); dataset3OutputFile.write(fastaFormat.getSequence()); dataset3OutputFile.newLine(); //if((lineCounter % 100) == 0){ statusPane.setText("Running Classifier One on Dataset 3.. @ " + lineCounter + " / " + applicationData.getTotalSequences(3) + " Sequences"); //} // for +1 index being -1, only make one prediction for the whole sequence if (fastaFormat.getIndexLocation() == -1) { //Should not have reached here... dataset3OutputFile.close(); throw new Exception("SHOULD NOT HAVE REACHED HERE!!"); } else {// for +1 index being non -1, make prediction on every possible position //For each sequence, you want to shift from predictPositionFrom till predictPositionTo //ie changing the +1 location //to get the scores given by classifier one so that //you can use it to train classifier two later //Doing shift from predictPositionFrom till predictPositionTo int predictPosition[]; predictPosition = fastaFormat.getPredictPositionForClassifierOne( applicationData.getLeftMostPosition(), applicationData.getRightMostPosition()); SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), predictPosition[0], predictPosition[1]); String line2; int currentPosition = predictPosition[0]; dataset3OutputFile.write(_class); while ((line2 = seq.nextShift()) != null) { Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(x), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(x, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(x, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(x, (String) obj); else { dataset3OutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(inst.numAttributes() - 1, _class); double[] results = classifierOne.distributionForInstance(tempInst); dataset3OutputFile.write("," + currentPosition + "=" + results[0]); //AHFU_DEBUG /*if(currentPosition >= setClassifierTwoUpstreamInt && currentPosition <= setClassifierTwoDownstreamInt) testClassifierTwoArff.write(results[0] + ",");*/ //AHFU_DEBUG_END currentPosition++; if (currentPosition == 0) currentPosition++; } // end of while((line2 = seq.nextShift())!=null) //AHFU_DEBUG /*testClassifierTwoArff.write(_class); testClassifierTwoArff.newLine(); testClassifierTwoArff.flush();*/ //AHFU_DEBUG_END dataset3OutputFile.newLine(); dataset3OutputFile.flush(); if (lineCounter == totalDataset3PositiveInstances) _class = "neg"; } //end of inside non -1 } // end of 
while((fastaFormat = fastaFile.nextSequence(_class))!=null) dataset3OutputFile.close(); PredictionStats classifierOneStatsOnBlindTest = new PredictionStats( applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores", range, threshold); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierOneStatsOnBlindTest.updateDisplay(classifierResults, classifierOneDisplayTextArea, true); applicationData.setClassifierOneStats(classifierOneStatsOnBlindTest); myGraph.setMyStats(classifierOneStatsOnBlindTest); statusPane.setText("Done!"); fastaFile.cleanUp(); return classifierOne; } catch (Exception ex) { ex.printStackTrace(); JOptionPane.showMessageDialog(parent, ex.getMessage() + "Classifier One on Blind Test Set", "Evaluate classifier", JOptionPane.ERROR_MESSAGE); return null; } }
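startClassifierOne strips string attributes with inst.deleteAttributeType(Attribute.STRING) before calling buildClassifier, since most Weka classifiers cannot handle string attributes. A hedged alternative using the stock RemoveType filter, which leaves the original Instances object untouched (not what the Sirius code does, just an equivalent route):

import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveType;

public class DropStringAttributesSketch {
    // Returns a copy of 'data' with all string attributes removed.
    public static Instances withoutStringAttributes(Instances data) throws Exception {
        RemoveType filter = new RemoveType();
        filter.setOptions(new String[] { "-T", "string" }); // delete string-typed attributes
        filter.setInputFormat(data);
        return Filter.useFilter(data, filter);
    }
}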
From source file:sirius.trainer.step4.RunClassifier.java
License:Open Source License
public static Classifier startClassifierTwo(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierTwoDisplayTextArea, GenericObjectEditor m_ClassifierEditor2, Classifier classifierOne, GraphPane myGraph, boolean test, ClassifierResults classifierResults, int range, double threshold) { int arraySize = 0; int lineCount = 0; try {//from w w w .j a va 2 s. c om StatusPane statusPane = applicationData.getStatusPane(); //Initialising long totalTimeStart = System.currentTimeMillis(); Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); int positiveDataset3FromInt = applicationData.getPositiveDataset3FromField(); int positiveDataset3ToInt = applicationData.getPositiveDataset3ToField(); int negativeDataset3FromInt = applicationData.getNegativeDataset3FromField(); int negativeDataset3ToInt = applicationData.getNegativeDataset3ToField(); //Preparing Dataset2.arff to train Classifier Two statusPane.setText("Preparing Dataset2.arff..."); //This step generates Dataset2.arff if (DatasetGenerator.generateDataset2(parent, applicationData, applicationData.getSetUpstream(), applicationData.getSetDownstream(), classifierOne) == false) { //Interrupted or Error occurred return null; } //Training Classifier Two statusPane.setText("Training Classifier Two... May take a while... Please wait..."); Instances inst2 = new Instances(new BufferedReader( new FileReader(applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff"))); inst2.setClassIndex(inst2.numAttributes() - 1); long trainTimeStart = 0; long trainTimeElapsed = 0; Classifier classifierTwo = (Classifier) m_ClassifierEditor2.getValue(); trainTimeStart = System.currentTimeMillis(); applicationData.setDataset2Instances(inst2); classifierTwo.buildClassifier(inst2); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; //Running Classifier Two String classifierName = m_ClassifierEditor2.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); if (test == false) { statusPane.setText("Classifier Two Trained...Done..."); return classifierTwo; } if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier One Training Completed"); return classifierTwo; } statusPane.setText("Running Classifier Two on Dataset 3..."); //Generate the header for ClassifierTwo.scores on Dataset3 BufferedWriter classifierTwoOutput = new BufferedWriter(new FileWriter( applicationData.getWorkingDirectory() + File.separator + "ClassifierTwo.scores")); if (m_ClassifierEditor2.getValue() instanceof OptionHandler) classifierName += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor2.getValue()).getOptions()); //Generating an Instance given a sequence with the current attributes int setClassifierTwoUpstreamInt = applicationData.getSetUpstream(); int setClassifierTwoDownstreamInt = applicationData.getSetDownstream(); int classifierTwoWindowSize; if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt > 0) classifierTwoWindowSize = (setClassifierTwoUpstreamInt * -1) + setClassifierTwoDownstreamInt; 
else if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt < 0) classifierTwoWindowSize = (setClassifierTwoUpstreamInt - setClassifierTwoDownstreamInt - 1) * -1; else//both +ve classifierTwoWindowSize = (setClassifierTwoDownstreamInt - setClassifierTwoUpstreamInt + 1); Instances inst = applicationData.getDataset1Instances(); //NOTE: need to take care of this function; FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel, positiveDataset3FromInt, positiveDataset3ToInt, negativeDataset3FromInt, negativeDataset3ToInt, applicationData.getWorkingDirectory()); //loading in all the features.. ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } //Reading the fastaFile String _class = "pos"; lineCount = 0; int totalPosSequences = positiveDataset3ToInt - positiveDataset3FromInt + 1; FastaFormat fastaFormat; while ((fastaFormat = fastaFile.nextSequence(_class)) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier Two Trained"); classifierTwoOutput.close(); return classifierTwo; } lineCount++; classifierTwoOutput.write(fastaFormat.getHeader()); classifierTwoOutput.newLine(); classifierTwoOutput.write(fastaFormat.getSequence()); classifierTwoOutput.newLine(); //if((lineCount % 100) == 0){ statusPane.setText("Running ClassifierTwo on Dataset 3...@ " + lineCount + " / " + applicationData.getTotalSequences(3) + " Sequences"); //} arraySize = fastaFormat.getArraySize(applicationData.getLeftMostPosition(), applicationData.getRightMostPosition()); //This area always generate -ve arraySize~! WHY?? 
Exception always occur here double scores[] = new double[arraySize]; int predictPosition[] = fastaFormat.getPredictPositionForClassifierOne( applicationData.getLeftMostPosition(), applicationData.getRightMostPosition()); //Doing shift from upstream till downstream SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), predictPosition[0], predictPosition[1]); int scoreCount = 0; String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); //-1 because class attribute can be ignored for (int x = 0; x < inst.numAttributes() - 1; x++) { Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(x), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(x, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(x, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(x, (String) obj); else { classifierTwoOutput.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(inst.numAttributes() - 1, _class); //Run classifierOne double[] results = classifierOne.distributionForInstance(tempInst); scores[scoreCount++] = results[0]; } //Run classifierTwo int currentPosition = fastaFormat.getPredictionFromForClassifierTwo( applicationData.getLeftMostPosition(), applicationData.getRightMostPosition(), applicationData.getSetUpstream()); classifierTwoOutput.write(_class); for (int y = 0; y < arraySize - classifierTwoWindowSize + 1; y++) { //+1 is for the class index Instance tempInst2 = new Instance(classifierTwoWindowSize + 1); tempInst2.setDataset(inst2); for (int x = 0; x < classifierTwoWindowSize; x++) { tempInst2.setValue(x, scores[x + y]); } tempInst2.setValue(tempInst2.numAttributes() - 1, _class); double[] results = classifierTwo.distributionForInstance(tempInst2); classifierTwoOutput.write("," + currentPosition + "=" + results[0]); currentPosition++; if (currentPosition == 0) currentPosition++; } classifierTwoOutput.newLine(); classifierTwoOutput.flush(); if (lineCount == totalPosSequences) _class = "neg"; } classifierTwoOutput.close(); statusPane.setText("Done!"); PredictionStats classifierTwoStatsOnBlindTest = new PredictionStats( applicationData.getWorkingDirectory() + File.separator + "ClassifierTwo.scores", range, threshold); //display(double range) long totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierTwoStatsOnBlindTest.updateDisplay(classifierResults, classifierTwoDisplayTextArea, true); applicationData.setClassifierTwoStats(classifierTwoStatsOnBlindTest); myGraph.setMyStats(classifierTwoStatsOnBlindTest); fastaFile.cleanUp(); return classifierTwo; } catch (Exception ex) { ex.printStackTrace(); JOptionPane.showMessageDialog(parent, ex.getMessage() + "Classifier Two On Blind Test Set - Check Console Output", "Evaluate classifier two", JOptionPane.ERROR_MESSAGE); System.err.println("applicationData.getLeftMostPosition(): " + applicationData.getLeftMostPosition()); System.err.println("applicationData.getRightMostPosition(): " + 
applicationData.getRightMostPosition()); System.err.println("arraySize: " + arraySize); System.err.println("lineCount: " + lineCount); return null; } }
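The cascade in startClassifierTwo boils down to one step: a window of classifier-one scores is wrapped in a new instance that shares classifier two's header (Dataset2.arff), and classifier two is asked for its class distribution. A condensed sketch of that step, using the same pre-3.7 Instance API as the example; the parameter names are hypothetical:

// Scores one window position with the second-stage classifier.
static double secondStageScore(Classifier classifierTwo, Instances dataset2Header,
        double[] scores, int offset, int windowSize, String classLabel) throws Exception {
    Instance inst = new Instance(windowSize + 1);           // +1 for the class attribute
    inst.setDataset(dataset2Header);                        // header read from Dataset2.arff
    for (int i = 0; i < windowSize; i++) {
        inst.setValue(i, scores[offset + i]);               // classifier-one score per position
    }
    inst.setValue(inst.numAttributes() - 1, classLabel);    // "pos" or "neg"
    return classifierTwo.distributionForInstance(inst)[0];  // probability of the first class
}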
From source file:sirius.trainer.step4.RunClassifier.java
License:Open Source License
public static Classifier xValidateClassifierOne(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, GenericObjectEditor m_ClassifierEditor, int folds, GraphPane myGraph, ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier) { try {/*w w w . j a va 2s . c o m*/ StatusPane statusPane = applicationData.getStatusPane(); long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; //Classifier tempClassifier = (Classifier) m_ClassifierEditor.getValue(); int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField(); int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField(); int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField(); int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField(); Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); Instances inst = new Instances(applicationData.getDataset1Instances()); inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy long trainTimeStart = 0, trainTimeElapsed = 0; Classifier classifierOne = (Classifier) m_ClassifierEditor.getValue(); statusPane.setText("Training Classifier One... May take a while... Please wait..."); //Record Start Time trainTimeStart = System.currentTimeMillis(); inst.deleteAttributeType(Attribute.STRING); if (outputClassifier) classifierOne.buildClassifier(inst); //Record Total Time used to build classifier one trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; //Training Done String classifierName = m_ClassifierEditor.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", folds + " fold cross-validation on Dataset1.arff"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); //Reading and Storing the featureList ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int y = 0; y < inst.numAttributes() - 1; y++) { featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(y).name())); } BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter( applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores")); for (int x = 0; x < folds; x++) { File trainFile = new File(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_" + (x + 1) + ".arff"); File testFile = new File(applicationData.getWorkingDirectory() + File.separator + "testingDataset1_" + (x + 1) + ".fasta"); //AHFU_DEBUG //Generate also the training file in fasta format for debugging purpose File trainFileFasta = new File(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_" + (x + 1) + ".fasta"); //AHFU_DEBUG_END //AHFU_DEBUG - This part is to generate the TestClassifierTwo.arff for use in WEKA to test classifierTwo //TestClassifierTwo.arff - predictions scores from Set Upstream Field to Set Downstream Field //Now first generate the header for TestClassifierTwo.arff BufferedWriter testClassifierTwoArff = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + 
"TestClassifierTwo_" + (x + 1) + ".arff")); int setClassifierTwoUpstreamInt = -40; int setClassifierTwoDownstreamInt = 41; testClassifierTwoArff.write("@relation \'Used to Test Classifier Two\'"); testClassifierTwoArff.newLine(); for (int d = setClassifierTwoUpstreamInt; d <= setClassifierTwoDownstreamInt; d++) { if (d == 0) continue; testClassifierTwoArff.write("@attribute (" + d + ") numeric"); testClassifierTwoArff.newLine(); } if (positiveDataset1FromInt > 0 && negativeDataset1FromInt > 0) testClassifierTwoArff.write("@attribute Class {pos,neg}"); else if (positiveDataset1FromInt > 0 && negativeDataset1FromInt == 0) testClassifierTwoArff.write("@attribute Class {pos}"); else if (positiveDataset1FromInt == 0 && negativeDataset1FromInt > 0) testClassifierTwoArff.write("@attribute Class {neg}"); testClassifierTwoArff.newLine(); testClassifierTwoArff.newLine(); testClassifierTwoArff.write("@data"); testClassifierTwoArff.newLine(); testClassifierTwoArff.newLine(); //END of AHFU_DEBUG statusPane.setText("Building Fold " + (x + 1) + "..."); FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory()); //1) generate trainingDatasetX.arff headings BufferedWriter trainingOutputFile = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_" + (x + 1) + ".arff")); trainingOutputFile.write("@relation 'A temp file for X-validation purpose' "); trainingOutputFile.newLine(); trainingOutputFile.newLine(); trainingOutputFile.flush(); for (int y = 0; y < inst.numAttributes() - 1; y++) { if (inst.attribute(y).type() == Attribute.NUMERIC) trainingOutputFile.write("@attribute " + inst.attribute(y).name() + " numeric"); else if (inst.attribute(y).type() == Attribute.STRING) trainingOutputFile.write("@attribute " + inst.attribute(y).name() + " String"); else { testClassifierTwoArff.close(); outputCrossValidation.close(); trainingOutputFile.close(); throw new Error("Unknown type: " + inst.attribute(y).name()); } trainingOutputFile.newLine(); trainingOutputFile.flush(); } if (positiveDataset1FromInt > 0 && negativeDataset1FromInt > 0) trainingOutputFile.write("@attribute Class {pos,neg}"); else if (positiveDataset1FromInt > 0 && negativeDataset1FromInt == 0) trainingOutputFile.write("@attribute Class {pos}"); else if (positiveDataset1FromInt == 0 && negativeDataset1FromInt > 0) trainingOutputFile.write("@attribute Class {neg}"); trainingOutputFile.newLine(); trainingOutputFile.newLine(); trainingOutputFile.write("@data"); trainingOutputFile.newLine(); trainingOutputFile.newLine(); trainingOutputFile.flush(); //2) generate testingDataset1.fasta BufferedWriter testingOutputFile = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "testingDataset1_" + (x + 1) + ".fasta")); //AHFU_DEBUG //Open the IOStream for training file (fasta format) BufferedWriter trainingOutputFileFasta = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_" + (x + 1) + ".fasta")); //AHFU_DEBUG_END //Now, populating data for both the training and testing files int fastaFileLineCounter = 0; int posTestSequenceCounter = 0; int totalTestSequenceCounter = 0; //For pos sequences FastaFormat fastaFormat; while ((fastaFormat = fastaFile.nextSequence("pos")) != null) { if ((fastaFileLineCounter % folds) == x) {//This 
sequence for testing testingOutputFile.write(fastaFormat.getHeader()); testingOutputFile.newLine(); testingOutputFile.write(fastaFormat.getSequence()); testingOutputFile.newLine(); testingOutputFile.flush(); posTestSequenceCounter++; totalTestSequenceCounter++; } else {//for training for (int z = 0; z < inst.numAttributes() - 1; z++) { trainingOutputFile.write(GenerateArff.getMatchCount(fastaFormat, featureDataArrayList.get(z), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()) + ","); } trainingOutputFile.write("pos"); trainingOutputFile.newLine(); trainingOutputFile.flush(); //AHFU_DEBUG //Write the datas into the training file in fasta format trainingOutputFileFasta.write(fastaFormat.getHeader()); trainingOutputFileFasta.newLine(); trainingOutputFileFasta.write(fastaFormat.getSequence()); trainingOutputFileFasta.newLine(); trainingOutputFileFasta.flush(); //AHFU_DEBUG_END } fastaFileLineCounter++; } //For neg sequences fastaFileLineCounter = 0; while ((fastaFormat = fastaFile.nextSequence("neg")) != null) { if ((fastaFileLineCounter % folds) == x) {//This sequence for testing testingOutputFile.write(fastaFormat.getHeader()); testingOutputFile.newLine(); testingOutputFile.write(fastaFormat.getSequence()); testingOutputFile.newLine(); testingOutputFile.flush(); totalTestSequenceCounter++; } else {//for training for (int z = 0; z < inst.numAttributes() - 1; z++) { trainingOutputFile.write(GenerateArff.getMatchCount(fastaFormat, featureDataArrayList.get(z), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()) + ","); } trainingOutputFile.write("neg"); trainingOutputFile.newLine(); trainingOutputFile.flush(); //AHFU_DEBUG //Write the datas into the training file in fasta format trainingOutputFileFasta.write(fastaFormat.getHeader()); trainingOutputFileFasta.newLine(); trainingOutputFileFasta.write(fastaFormat.getSequence()); trainingOutputFileFasta.newLine(); trainingOutputFileFasta.flush(); //AHFU_DEBUG_END } fastaFileLineCounter++; } trainingOutputFileFasta.close(); trainingOutputFile.close(); testingOutputFile.close(); //3) train and test the classifier then store the statistics Classifier foldClassifier = (Classifier) m_ClassifierEditor.getValue(); Instances instFoldTrain = new Instances( new BufferedReader(new FileReader(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_" + (x + 1) + ".arff"))); instFoldTrain.setClassIndex(instFoldTrain.numAttributes() - 1); foldClassifier.buildClassifier(instFoldTrain); //Reading the test file statusPane.setText("Evaluating fold " + (x + 1) + ".."); BufferedReader testingInput = new BufferedReader( new FileReader(applicationData.getWorkingDirectory() + File.separator + "testingDataset1_" + (x + 1) + ".fasta")); int lineCounter = 0; String lineHeader; String lineSequence; while ((lineHeader = testingInput.readLine()) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier One Training Completed"); testingInput.close(); testClassifierTwoArff.close(); return classifierOne; } lineSequence = testingInput.readLine(); outputCrossValidation.write(lineHeader); outputCrossValidation.newLine(); outputCrossValidation.write(lineSequence); outputCrossValidation.newLine(); lineCounter++; //For each sequence, you want to shift from upstream till downstream //ie changing the +1 location //to get the scores by classifier one so that can use it to train classifier two 
later //Doing shift from upstream till downstream //if(lineCounter % 100 == 0) statusPane.setText("Evaluating fold " + (x + 1) + ".. @ " + lineCounter + " / " + totalTestSequenceCounter); fastaFormat = new FastaFormat(lineHeader, lineSequence); int predictPosition[] = fastaFormat.getPredictPositionForClassifierOne( applicationData.getLeftMostPosition(), applicationData.getRightMostPosition()); SequenceManipulation seq = new SequenceManipulation(lineSequence, predictPosition[0], predictPosition[1]); int currentPosition = predictPosition[0]; String line2; if (lineCounter > posTestSequenceCounter) outputCrossValidation.write("neg"); else outputCrossValidation.write("pos"); while ((line2 = seq.nextShift()) != null) { Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int i = 0; i < inst.numAttributes() - 1; i++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount(lineHeader, line2, featureDataArrayList.get(i), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(x, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(x, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(x, (String) obj); else { testingInput.close(); testClassifierTwoArff.close(); outputCrossValidation.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } if (lineCounter > posTestSequenceCounter) tempInst.setValue(inst.numAttributes() - 1, "neg"); else tempInst.setValue(inst.numAttributes() - 1, "pos"); double[] results = foldClassifier.distributionForInstance(tempInst); outputCrossValidation.write("," + currentPosition + "=" + results[0]); //AHFU_DEBUG double[] resultsDebug = classifierOne.distributionForInstance(tempInst); if (currentPosition >= setClassifierTwoUpstreamInt && currentPosition <= setClassifierTwoDownstreamInt) testClassifierTwoArff.write(resultsDebug[0] + ","); //AHFU_DEBUG_END currentPosition++; if (currentPosition == 0) currentPosition++; } //end of sequence shift outputCrossValidation.newLine(); outputCrossValidation.flush(); //AHFU_DEBUG if (lineCounter > posTestSequenceCounter) testClassifierTwoArff.write("neg"); else testClassifierTwoArff.write("pos"); testClassifierTwoArff.newLine(); testClassifierTwoArff.flush(); //AHFU_DEBUG_END } //end of reading test file outputCrossValidation.close(); testingInput.close(); testClassifierTwoArff.close(); fastaFile.cleanUp(); //NORMAL MODE //trainFile.delete(); //testFile.delete(); //NORMAL MODE END //AHFU_DEBUG MODE //testClassifierTwoArff.close(); trainFile.deleteOnExit(); testFile.deleteOnExit(); trainFileFasta.deleteOnExit(); //AHFU_DEBUG_MODE_END } //end of for loop for xvalidation PredictionStats classifierOneStatsOnXValidation = new PredictionStats( applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores", range, threshold); //display(double range) totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierOneStatsOnXValidation.updateDisplay(classifierResults, 
classifierOneDisplayTextArea, true); applicationData.setClassifierOneStats(classifierOneStatsOnXValidation); myGraph.setMyStats(classifierOneStatsOnXValidation); statusPane.setText("Done!"); return classifierOne; } catch (Exception e) { e.printStackTrace(); JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE); return null; } }
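xValidateClassifierOne writes each fold's training ARFF by emitting the @relation, @attribute and @data lines by hand. When a fold is already available as an Instances object (for example from trainCV), Weka's ArffSaver produces the same file format directly; the sketch below uses the stock converter and an illustrative path, and is not part of the Sirius code:

import java.io.File;
import weka.core.Instances;
import weka.core.converters.ArffSaver;

public class ArffWriteSketch {
    // Writes an Instances object to disk, generating the ARFF header automatically.
    public static void save(Instances fold, String path) throws Exception {
        ArffSaver saver = new ArffSaver();
        saver.setInstances(fold);
        saver.setFile(new File(path));   // e.g. "trainingDataset1_1.arff" (illustrative)
        saver.writeBatch();
    }
}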
From source file:sirius.trainer.step4.RunClassifier.java
License:Open Source License
public static Classifier xValidateClassifierTwo(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierTwoDisplayTextArea, GenericObjectEditor m_ClassifierEditor2, Classifier classifierOne, int folds, GraphPane myGraph, ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier) { try {/* ww w . j a v a 2 s . co m*/ StatusPane statusPane = applicationData.getStatusPane(); long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; //Classifier tempClassifier = (Classifier) m_ClassifierEditor2.getValue(); final int positiveDataset2FromInt = applicationData.getPositiveDataset2FromField(); final int positiveDataset2ToInt = applicationData.getPositiveDataset2ToField(); final int negativeDataset2FromInt = applicationData.getNegativeDataset2FromField(); final int negativeDataset2ToInt = applicationData.getNegativeDataset2ToField(); final int totalDataset2Sequences = (positiveDataset2ToInt - positiveDataset2FromInt + 1) + (negativeDataset2ToInt - negativeDataset2FromInt + 1); final int classifierTwoUpstream = applicationData.getSetUpstream(); final int classifierTwoDownstream = applicationData.getSetDownstream(); Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); //Train classifier two with the full dataset first then do cross-validation to gauge its accuracy //Preparing Dataset2.arff to train Classifier Two long trainTimeStart = 0, trainTimeElapsed = 0; statusPane.setText("Preparing Dataset2.arff..."); //This step generates Dataset2.arff if (DatasetGenerator.generateDataset2(parent, applicationData, applicationData.getSetUpstream(), applicationData.getSetDownstream(), classifierOne) == false) { //Interrupted or Error occurred return null; } Instances instOfDataset2 = new Instances(new BufferedReader( new FileReader(applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff"))); instOfDataset2.setClassIndex(instOfDataset2.numAttributes() - 1); applicationData.setDataset2Instances(instOfDataset2); Classifier classifierTwo = (Classifier) m_ClassifierEditor2.getValue(); statusPane.setText("Training Classifier Two... May take a while... 
Please wait..."); //Record Start Time trainTimeStart = System.currentTimeMillis(); if (outputClassifier) classifierTwo.buildClassifier(instOfDataset2); //Record Total Time used to build classifier one trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; //Training Done String classifierName = m_ClassifierEditor2.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", folds + " fold cross-validation on Dataset2.arff"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); Instances instOfDataset1 = new Instances(applicationData.getDataset1Instances()); instOfDataset1.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); //Reading and Storing the featureList ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int y = 0; y < instOfDataset1.numAttributes() - 1; y++) { featureDataArrayList.add(Feature.levelOneClassifierPane(instOfDataset1.attribute(y).name())); } //Generating an Instance given a sequence with the current attributes int setClassifierTwoUpstreamInt = applicationData.getSetUpstream(); int setClassifierTwoDownstreamInt = applicationData.getSetDownstream(); int classifierTwoWindowSize; if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt > 0) classifierTwoWindowSize = (setClassifierTwoUpstreamInt * -1) + setClassifierTwoDownstreamInt; else if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt < 0) classifierTwoWindowSize = (setClassifierTwoUpstreamInt - setClassifierTwoDownstreamInt - 1) * -1; else//both +ve classifierTwoWindowSize = (setClassifierTwoDownstreamInt - setClassifierTwoUpstreamInt + 1); int posTestSequenceCounter = 0; BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter( applicationData.getWorkingDirectory() + File.separator + "classifierTwo.scores")); for (int x = 0; x < folds; x++) { File trainFile = new File(applicationData.getWorkingDirectory() + File.separator + "trainingDataset2_" + (x + 1) + ".arff"); File testFile = new File(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_" + (x + 1) + ".fasta"); statusPane.setText("Preparing Training Data for Fold " + (x + 1) + ".."); FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel, positiveDataset2FromInt, positiveDataset2ToInt, negativeDataset2FromInt, negativeDataset2ToInt, applicationData.getWorkingDirectory()); //1) generate trainingDataset2.arff headings BufferedWriter trainingOutputFile = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "trainingDataset2_" + (x + 1) + ".arff")); trainingOutputFile.write("@relation 'A temp file for X-validation purpose' "); trainingOutputFile.newLine(); trainingOutputFile.newLine(); trainingOutputFile.flush(); for (int y = classifierTwoUpstream; y <= classifierTwoDownstream; y++) { if (y != 0) { trainingOutputFile.write("@attribute (" + y + ") numeric"); trainingOutputFile.newLine(); trainingOutputFile.flush(); } } if (positiveDataset2FromInt > 0 && negativeDataset2FromInt > 0) trainingOutputFile.write("@attribute Class {pos,neg}"); else if (positiveDataset2FromInt > 0 && negativeDataset2FromInt == 0) trainingOutputFile.write("@attribute Class {pos}"); else if (positiveDataset2FromInt == 0 && 
negativeDataset2FromInt > 0) trainingOutputFile.write("@attribute Class {neg}"); trainingOutputFile.newLine(); trainingOutputFile.newLine(); trainingOutputFile.write("@data"); trainingOutputFile.newLine(); trainingOutputFile.newLine(); trainingOutputFile.flush(); //AHFU_DEBUG BufferedWriter testingOutputFileArff = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_" + (x + 1) + ".arff")); testingOutputFileArff.write("@relation 'A temp file for X-validation purpose' "); testingOutputFileArff.newLine(); testingOutputFileArff.newLine(); testingOutputFileArff.flush(); for (int y = classifierTwoUpstream; y <= classifierTwoDownstream; y++) { if (y != 0) { testingOutputFileArff.write("@attribute (" + y + ") numeric"); testingOutputFileArff.newLine(); testingOutputFileArff.flush(); } } if (positiveDataset2FromInt > 0 && negativeDataset2FromInt > 0) testingOutputFileArff.write("@attribute Class {pos,neg}"); else if (positiveDataset2FromInt > 0 && negativeDataset2FromInt == 0) testingOutputFileArff.write("@attribute Class {pos}"); else if (positiveDataset2FromInt == 0 && negativeDataset2FromInt > 0) testingOutputFileArff.write("@attribute Class {neg}"); testingOutputFileArff.newLine(); testingOutputFileArff.newLine(); testingOutputFileArff.write("@data"); testingOutputFileArff.newLine(); testingOutputFileArff.newLine(); testingOutputFileArff.flush(); //AHFU_DEBUG END //2) generate testingDataset2.fasta BufferedWriter testingOutputFile = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_" + (x + 1) + ".fasta")); //Now, populating datas for both the training and testing files int fastaFileLineCounter = 0; posTestSequenceCounter = 0; int totalTestSequenceCounter = 0; int totalTrainTestSequenceCounter = 0; FastaFormat fastaFormat; //For pos sequences while ((fastaFormat = fastaFile.nextSequence("pos")) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier Two Trained"); outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); return classifierTwo; } totalTrainTestSequenceCounter++; //if(totalTrainTestSequenceCounter%100 == 0) statusPane.setText("Preparing Training Data for Fold " + (x + 1) + ".. 
@ " + totalTrainTestSequenceCounter + " / " + totalDataset2Sequences); if ((fastaFileLineCounter % folds) == x) {//This sequence is for testing testingOutputFile.write(fastaFormat.getHeader()); testingOutputFile.newLine(); testingOutputFile.write(fastaFormat.getSequence()); testingOutputFile.newLine(); testingOutputFile.flush(); posTestSequenceCounter++; totalTestSequenceCounter++; //AHFU DEBUG SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), classifierTwoUpstream, classifierTwoDownstream); String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst = new Instance(instOfDataset1.numAttributes()); tempInst.setDataset(instOfDataset1); //-1 because class attribute can be ignored for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) { Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(w, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(w, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(w, (String) obj); else { outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(tempInst.numAttributes() - 1, "pos"); double[] results = classifierOne.distributionForInstance(tempInst); testingOutputFileArff.write(results[0] + ","); } testingOutputFileArff.write("pos"); testingOutputFileArff.newLine(); testingOutputFileArff.flush(); //AHFU DEBUG END } else {//This sequence is for training SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), classifierTwoUpstream, classifierTwoDownstream); String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst = new Instance(instOfDataset1.numAttributes()); tempInst.setDataset(instOfDataset1); //-1 because class attribute can be ignored for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) { Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(w, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(w, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(w, (String) obj); else { outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(tempInst.numAttributes() - 1, "pos"); double[] results = classifierOne.distributionForInstance(tempInst); trainingOutputFile.write(results[0] + ","); } trainingOutputFile.write("pos"); trainingOutputFile.newLine(); trainingOutputFile.flush(); } fastaFileLineCounter++; } //For neg sequences fastaFileLineCounter = 0; while ((fastaFormat = fastaFile.nextSequence("neg")) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier Two Trained"); outputCrossValidation.close(); testingOutputFileArff.close(); 
testingOutputFile.close(); trainingOutputFile.close(); return classifierTwo; } totalTrainTestSequenceCounter++; //if(totalTrainTestSequenceCounter%100 == 0) statusPane.setText("Preparing Training Data for Fold " + (x + 1) + ".. @ " + totalTrainTestSequenceCounter + " / " + totalDataset2Sequences); if ((fastaFileLineCounter % folds) == x) {//This sequence is for testing testingOutputFile.write(fastaFormat.getHeader()); testingOutputFile.newLine(); testingOutputFile.write(fastaFormat.getSequence()); testingOutputFile.newLine(); testingOutputFile.flush(); totalTestSequenceCounter++; //AHFU DEBUG SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), classifierTwoUpstream, classifierTwoDownstream); String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst = new Instance(instOfDataset1.numAttributes()); tempInst.setDataset(instOfDataset1); //-1 because class attribute can be ignored for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) { Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(w, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(w, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(w, (String) obj); else { outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(tempInst.numAttributes() - 1, "pos");//pos or neg does not matter here - not used double[] results = classifierOne.distributionForInstance(tempInst); testingOutputFileArff.write(results[0] + ","); } testingOutputFileArff.write("neg"); testingOutputFileArff.newLine(); testingOutputFileArff.flush(); //AHFU DEBUG END } else {//This sequence is for training SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), classifierTwoUpstream, classifierTwoDownstream); String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst = new Instance(instOfDataset1.numAttributes()); tempInst.setDataset(instOfDataset1); //-1 because class attribute can be ignored for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) { Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(w, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(w, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(w, (String) obj); else { outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(tempInst.numAttributes() - 1, "pos");//pos or neg does not matter here - not used double[] results = classifierOne.distributionForInstance(tempInst); trainingOutputFile.write(results[0] + ","); } trainingOutputFile.write("neg"); trainingOutputFile.newLine(); trainingOutputFile.flush(); } fastaFileLineCounter++; } 
trainingOutputFile.close(); testingOutputFile.close(); //AHFU_DEBUG testingOutputFileArff.close(); //AHFU DEBUG END //3) train and test classifier two then store the statistics statusPane.setText("Building Fold " + (x + 1) + ".."); //open an input stream to the arff file BufferedReader trainingInput = new BufferedReader( new FileReader(applicationData.getWorkingDirectory() + File.separator + "trainingDataset2_" + (x + 1) + ".arff")); //getting ready to train a foldClassifier using arff file Instances instOfTrainingDataset2 = new Instances( new BufferedReader(new FileReader(applicationData.getWorkingDirectory() + File.separator + "trainingDataset2_" + (x + 1) + ".arff"))); instOfTrainingDataset2.setClassIndex(instOfTrainingDataset2.numAttributes() - 1); Classifier foldClassifier = (Classifier) m_ClassifierEditor2.getValue(); foldClassifier.buildClassifier(instOfTrainingDataset2); trainingInput.close(); //Reading the test file statusPane.setText("Evaluating fold " + (x + 1) + ".."); BufferedReader testingInput = new BufferedReader( new FileReader(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_" + (x + 1) + ".fasta")); int lineCounter = 0; String lineHeader; String lineSequence; while ((lineHeader = testingInput.readLine()) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier Two Not Trained"); outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); testingInput.close(); return classifierTwo; } lineSequence = testingInput.readLine(); outputCrossValidation.write(lineHeader); outputCrossValidation.newLine(); outputCrossValidation.write(lineSequence); outputCrossValidation.newLine(); lineCounter++; fastaFormat = new FastaFormat(lineHeader, lineSequence); int arraySize = fastaFormat.getArraySize(applicationData.getLeftMostPosition(), applicationData.getRightMostPosition()); double scores[] = new double[arraySize]; int predictPosition[] = fastaFormat.getPredictPositionForClassifierOne( applicationData.getLeftMostPosition(), applicationData.getRightMostPosition()); //For each sequence, you want to shift from upstream till downstream //ie changing the +1 location //to get the scores by classifier one so that can use it to train classifier two later //Doing shift from upstream till downstream //if(lineCounter % 100 == 0) statusPane.setText("Evaluating fold " + (x + 1) + ".. 
@ " + lineCounter + " / " + totalTestSequenceCounter); SequenceManipulation seq = new SequenceManipulation(lineSequence, predictPosition[0], predictPosition[1]); int scoreCount = 0; String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst = new Instance(instOfDataset1.numAttributes()); tempInst.setDataset(instOfDataset1); for (int i = 0; i < instOfDataset1.numAttributes() - 1; i++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount(lineHeader, line2, featureDataArrayList.get(i), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(i, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(i, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(i, (String) obj); else { outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); testingInput.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } if (lineCounter > posTestSequenceCounter) {//for neg tempInst.setValue(tempInst.numAttributes() - 1, "neg"); } else { tempInst.setValue(tempInst.numAttributes() - 1, "pos"); } double[] results = classifierOne.distributionForInstance(tempInst); scores[scoreCount++] = results[0]; } //end of sequence shift //Run classifierTwo int currentPosition = fastaFormat.getPredictionFromForClassifierTwo( applicationData.getLeftMostPosition(), applicationData.getRightMostPosition(), applicationData.getSetUpstream()); if (lineCounter > posTestSequenceCounter)//neg outputCrossValidation.write("neg"); else outputCrossValidation.write("pos"); for (int y = 0; y < arraySize - classifierTwoWindowSize + 1; y++) { //+1 is for the class index Instance tempInst2 = new Instance(classifierTwoWindowSize + 1); tempInst2.setDataset(instOfTrainingDataset2); for (int l = 0; l < classifierTwoWindowSize; l++) { tempInst2.setValue(l, scores[l + y]); } if (lineCounter > posTestSequenceCounter)//for neg tempInst2.setValue(tempInst2.numAttributes() - 1, "neg"); else//for pos tempInst2.setValue(tempInst2.numAttributes() - 1, "pos"); double[] results = foldClassifier.distributionForInstance(tempInst2); outputCrossValidation.write("," + currentPosition + "=" + results[0]); currentPosition++; if (currentPosition == 0) currentPosition++; } outputCrossValidation.newLine(); outputCrossValidation.flush(); } //end of reading test file outputCrossValidation.close(); testingOutputFileArff.close(); testingOutputFile.close(); trainingOutputFile.close(); testingInput.close(); fastaFile.cleanUp(); //AHFU_DEBUG trainFile.deleteOnExit(); testFile.deleteOnExit(); //NORMAL MODE //trainFile.delete(); //testFile.delete(); } //end of for loop for xvalidation PredictionStats classifierTwoStatsOnXValidation = new PredictionStats( applicationData.getWorkingDirectory() + File.separator + "classifierTwo.scores", range, threshold); //display(double range) totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierTwoStatsOnXValidation.updateDisplay(classifierResults, 
classifierTwoDisplayTextArea, true); applicationData.setClassifierTwoStats(classifierTwoStatsOnXValidation); myGraph.setMyStats(classifierTwoStatsOnXValidation); statusPane.setText("Done!"); return classifierTwo; } catch (Exception e) { e.printStackTrace(); JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE); return null; } }
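The fold loop above trains a second-stage classifier on an ARFF of classifier-one scores (trainingDataset2_<fold>.arff) and then queries it with distributionForInstance over a sliding window of those scores. The sketch below isolates that core pattern under the same pre-3.7 Weka API used throughout this listing (weka.core.Instance as a concrete class); NaiveBayes stands in for the classifier chosen via m_ClassifierEditor2, the relative file name is illustrative, and the 0.5 values are placeholders for real classifier-one outputs.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instance;
import weka.core.Instances;

public class FoldClassifierSketch {
    public static void main(String[] args) throws Exception {
        // ARFF of classifier-one scores with a nominal {pos, neg} class as the last attribute.
        Instances train = new Instances(
                new BufferedReader(new FileReader("trainingDataset2_1.arff")));
        train.setClassIndex(train.numAttributes() - 1);

        Classifier foldClassifier = new NaiveBayes(); // stand-in for the user-selected classifier
        foldClassifier.buildClassifier(train);

        // One sliding-window instance of classifier-one scores, mirroring the loop above.
        Instance window = new Instance(train.numAttributes());
        window.setDataset(train);
        for (int i = 0; i < train.numAttributes() - 1; i++) {
            window.setValue(i, 0.5); // placeholder score from classifier one
        }
        window.setValue(train.numAttributes() - 1, "pos"); // not used by distributionForInstance

        double[] dist = foldClassifier.distributionForInstance(window);
        System.out.println("Score for first class value: " + dist[0]);
    }
}

As in the original, the class value set on the window instance only completes the instance; distributionForInstance ignores it, which is why the code above can reuse "pos" for negative sequences as well.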
From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java
License:Open Source License
public static Object startClassifierOneWithNoLocationIndex(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, GraphPane myGraph, boolean test, ClassifierResults classifierResults, int range, double threshold, String classifierName, String[] classifierOptions, boolean returnClassifier, GeneticAlgorithmDialog gaDialog, int randomNumberForClassifier) { try {//from w w w.j a va 2s . co m if (gaDialog != null) { //Run GA then load the result maxMCCFeatures into applicationData->Dataset1Instances int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField(); int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField(); int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField(); int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField(); FastaFileManipulation fastaFile = new FastaFileManipulation( applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(), positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory()); FastaFormat fastaFormat; List<FastaFormat> posFastaList = new ArrayList<FastaFormat>(); List<FastaFormat> negFastaList = new ArrayList<FastaFormat>(); while ((fastaFormat = fastaFile.nextSequence("pos")) != null) { posFastaList.add(fastaFormat); } while ((fastaFormat = fastaFile.nextSequence("neg")) != null) { negFastaList.add(fastaFormat); } applicationData.setDataset1Instances( runDAandLoadResult(applicationData, gaDialog, posFastaList, negFastaList)); } StatusPane statusPane = applicationData.getStatusPane(); long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; //Setting up training data set 1 for classifier one if (statusPane != null) statusPane.setText("Setting up..."); //Load Dataset1 Instances Instances inst = new Instances(applicationData.getDataset1Instances()); inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); applicationData.getDataset1Instances() .setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); // for recording of time long trainTimeStart = 0, trainTimeElapsed = 0; Classifier classifierOne = Classifier.forName(classifierName, classifierOptions); /*//Used to show the classifierName and options so that I can use them for qsub System.out.println(classifierName); String[] optionString = classifierOne.getOptions(); for(int x = 0; x < optionString.length; x++) System.out.println(optionString[x]);*/ if (statusPane != null) statusPane.setText("Training Classifier One... May take a while... 
Please wait..."); //Record Start Time trainTimeStart = System.currentTimeMillis(); //Train Classifier One inst.deleteAttributeType(Attribute.STRING); classifierOne.buildClassifier(inst); //Record Total Time used to build classifier one trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; if (classifierResults != null) { classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", applicationData.getWorkingDirectory() + File.separator + "Dataset1.arff"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); } if (test == false) { //If Need Not Test option is selected if (statusPane != null) statusPane.setText("Done!"); return classifierOne; } if (applicationData.terminateThread == true) { //If Stop button is pressed if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); return classifierOne; } //Running classifier one on dataset3 if (statusPane != null) statusPane.setText("Running ClassifierOne on Dataset 3.."); int positiveDataset3FromInt = applicationData.getPositiveDataset3FromField(); int positiveDataset3ToInt = applicationData.getPositiveDataset3ToField(); int negativeDataset3FromInt = applicationData.getNegativeDataset3FromField(); int negativeDataset3ToInt = applicationData.getNegativeDataset3ToField(); //Generate the header for ClassifierOne.scores on Dataset3 String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_" + randomNumberForClassifier + ".scores"; BufferedWriter dataset3OutputFile = new BufferedWriter(new FileWriter(classifierOneFilename)); FastaFileManipulation fastaFile = new FastaFileManipulation( applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(), positiveDataset3FromInt, positiveDataset3ToInt, negativeDataset3FromInt, negativeDataset3ToInt, applicationData.getWorkingDirectory()); //Reading and Storing the featureList ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } //Reading the fastaFile int lineCounter = 0; String _class = "pos"; int totalDataset3PositiveInstances = positiveDataset3ToInt - positiveDataset3FromInt + 1; FastaFormat fastaFormat; while ((fastaFormat = fastaFile.nextSequence(_class)) != null) { if (applicationData.terminateThread == true) { if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); dataset3OutputFile.close(); return classifierOne; } dataset3OutputFile.write(fastaFormat.getHeader()); dataset3OutputFile.newLine(); dataset3OutputFile.write(fastaFormat.getSequence()); dataset3OutputFile.newLine(); lineCounter++;//Putting it here will mean if lineCounter is x then line == sequence x dataset3OutputFile.flush(); if (statusPane != null) statusPane.setText("Running Classifier One on Dataset 3.. 
@ " + lineCounter + " / " + applicationData.getTotalSequences(3) + " Sequences"); Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount(fastaFormat, featureDataArrayList.get(x), applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(x, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(x, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(x, (String) obj); else { dataset3OutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(inst.numAttributes() - 1, _class); double[] results = classifierOne.distributionForInstance(tempInst); dataset3OutputFile.write(_class + ",0=" + results[0]); dataset3OutputFile.newLine(); dataset3OutputFile.flush(); if (lineCounter == totalDataset3PositiveInstances) _class = "neg"; } dataset3OutputFile.close(); //Display Statistics by reading the ClassifierOne.scores PredictionStats classifierOneStatsOnBlindTest = new PredictionStats(classifierOneFilename, range, threshold); //display(double range) totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; if (classifierResults != null) { classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierOneStatsOnBlindTest.updateDisplay(classifierResults, classifierOneDisplayTextArea, true); } else classifierOneStatsOnBlindTest.updateDisplay(classifierResults, classifierOneDisplayTextArea, true); applicationData.setClassifierOneStats(classifierOneStatsOnBlindTest); if (myGraph != null) myGraph.setMyStats(classifierOneStatsOnBlindTest); if (statusPane != null) statusPane.setText("Done!"); fastaFile.cleanUp(); if (returnClassifier) return classifierOne; else return classifierOneStatsOnBlindTest; } catch (Exception ex) { ex.printStackTrace(); JOptionPane.showMessageDialog(parent, ex.getMessage(), "Evaluate classifier", JOptionPane.ERROR_MESSAGE); return null; } }
From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java
License:Open Source License
public static Classifier xValidateClassifierOneWithNoLocationIndex(JInternalFrame parent, ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, String classifierName, String[] classifierOptions, int folds, GraphPane myGraph, ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier, GeneticAlgorithmDialog gaDialog, GASettingsInterface gaSettings, int randomNumberForClassifier) { try {//w ww .ja v a 2s . c o m StatusPane statusPane = applicationData.getStatusPane(); if (statusPane == null) System.out.println("Null"); //else // stats long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed; Classifier tempClassifier = (Classifier) Classifier.forName(classifierName, classifierOptions); Instances inst = null; if (applicationData.getDataset1Instances() != null) { inst = new Instances(applicationData.getDataset1Instances()); inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1); } //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy long trainTimeStart = 0, trainTimeElapsed = 0; Classifier classifierOne = (Classifier) Classifier.forName(classifierName, classifierOptions); if (statusPane != null) statusPane.setText("Training Classifier One... May take a while... Please wait..."); //Record Start Time trainTimeStart = System.currentTimeMillis(); if (outputClassifier && gaSettings == null) classifierOne.buildClassifier(inst); //Record Total Time used to build classifier one trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; //Training Done ] if (classifierResults != null) { classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", folds + " fold cross-validation"); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds"); } int startRandomNumber; if (gaSettings != null) startRandomNumber = gaSettings.getRandomNumber(); else startRandomNumber = 1; String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_" + randomNumberForClassifier + "_" + startRandomNumber + ".scores"; BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename)); Instances foldTrainingInstance = null; Instances foldTestingInstance = null; int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField(); int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField(); int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField(); int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField(); Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel(); Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel(); FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel, negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory()); FastaFormat fastaFormat; String header[] = null; String data[] = null; if (inst != null) { header = new String[inst.numInstances()]; data = new String[inst.numInstances()]; } List<FastaFormat> allPosList = new ArrayList<FastaFormat>(); List<FastaFormat> allNegList = new ArrayList<FastaFormat>(); int counter = 0; while ((fastaFormat = 
fastaFile.nextSequence("pos")) != null) { if (inst != null) { header[counter] = fastaFormat.getHeader(); data[counter] = fastaFormat.getSequence(); counter++; } allPosList.add(fastaFormat); } while ((fastaFormat = fastaFile.nextSequence("neg")) != null) { if (inst != null) { header[counter] = fastaFormat.getHeader(); data[counter] = fastaFormat.getSequence(); counter++; } allNegList.add(fastaFormat); } //run x folds for (int x = 0; x < folds; x++) { if (applicationData.terminateThread == true) { if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); outputCrossValidation.close(); return classifierOne; } if (statusPane != null) statusPane.setPrefix("Running Fold " + (x + 1) + ": "); if (inst != null) { foldTrainingInstance = new Instances(inst, 0); foldTestingInstance = new Instances(inst, 0); } List<FastaFormat> trainPosList = new ArrayList<FastaFormat>(); List<FastaFormat> trainNegList = new ArrayList<FastaFormat>(); List<FastaFormat> testPosList = new ArrayList<FastaFormat>(); List<FastaFormat> testNegList = new ArrayList<FastaFormat>(); //split data into training and testing //This is for normal run int testInstanceIndex[] = null; if (inst != null) testInstanceIndex = new int[(inst.numInstances() / folds) + 1]; if (gaSettings == null) { int testIndexCounter = 0; for (int y = 0; y < inst.numInstances(); y++) { if ((y % folds) == x) {//this instance is for testing foldTestingInstance.add(inst.instance(y)); testInstanceIndex[testIndexCounter] = y; testIndexCounter++; } else {//this instance is for training foldTrainingInstance.add(inst.instance(y)); } } } else { //This is for GA run for (int y = 0; y < allPosList.size(); y++) { if ((y % folds) == x) {//this instance is for testing testPosList.add(allPosList.get(y)); } else {//this instance is for training trainPosList.add(allPosList.get(y)); } } for (int y = 0; y < allNegList.size(); y++) { if ((y % folds) == x) {//this instance is for testing testNegList.add(allNegList.get(y)); } else {//this instance is for training trainNegList.add(allNegList.get(y)); } } if (gaDialog != null) foldTrainingInstance = runDAandLoadResult(applicationData, gaDialog, trainPosList, trainNegList, x + 1, startRandomNumber); else foldTrainingInstance = runDAandLoadResult(applicationData, gaSettings, trainPosList, trainNegList, x + 1, startRandomNumber); foldTrainingInstance.setClassIndex(foldTrainingInstance.numAttributes() - 1); //Reading and Storing the featureList ArrayList<Feature> featureList = new ArrayList<Feature>(); for (int y = 0; y < foldTrainingInstance.numAttributes() - 1; y++) { //-1 because class attribute must be ignored featureList.add(Feature.levelOneClassifierPane(foldTrainingInstance.attribute(y).name())); } String outputFilename; if (gaDialog != null) outputFilename = gaDialog.getOutputLocation().getText() + File.separator + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1) + ".arff"; else outputFilename = gaSettings.getOutputLocation() + File.separator + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1) + ".arff"; new GenerateFeatures(applicationData, featureList, testPosList, testNegList, outputFilename); foldTestingInstance = new Instances(new FileReader(outputFilename)); foldTestingInstance.setClassIndex(foldTestingInstance.numAttributes() - 1); } Classifier foldClassifier = tempClassifier; foldClassifier.buildClassifier(foldTrainingInstance); for (int y = 0; y < foldTestingInstance.numInstances(); y++) { if 
(applicationData.terminateThread == true) { if (statusPane != null) statusPane.setText("Interrupted - Classifier One Training Completed"); outputCrossValidation.close(); return classifierOne; } double[] results = foldClassifier.distributionForInstance(foldTestingInstance.instance(y)); int classIndex = foldTestingInstance.instance(y).classIndex(); String classValue = foldTestingInstance.instance(y).toString(classIndex); if (inst != null) { outputCrossValidation.write(header[testInstanceIndex[y]]); outputCrossValidation.newLine(); outputCrossValidation.write(data[testInstanceIndex[y]]); outputCrossValidation.newLine(); } else { if (y < testPosList.size()) { outputCrossValidation.write(testPosList.get(y).getHeader()); outputCrossValidation.newLine(); outputCrossValidation.write(testPosList.get(y).getSequence()); outputCrossValidation.newLine(); } else { outputCrossValidation.write(testNegList.get(y - testPosList.size()).getHeader()); outputCrossValidation.newLine(); outputCrossValidation.write(testNegList.get(y - testPosList.size()).getSequence()); outputCrossValidation.newLine(); } } if (classValue.equals("pos")) outputCrossValidation.write("pos,0=" + results[0]); else if (classValue.equals("neg")) outputCrossValidation.write("neg,0=" + results[0]); else { outputCrossValidation.close(); throw new Error("Invalid Class Type!"); } outputCrossValidation.newLine(); outputCrossValidation.flush(); } } outputCrossValidation.close(); PredictionStats classifierOneStatsOnXValidation = new PredictionStats(classifierOneFilename, range, threshold); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; if (classifierResults != null) { classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); classifierOneStatsOnXValidation.updateDisplay(classifierResults, classifierOneDisplayTextArea, true); } applicationData.setClassifierOneStats(classifierOneStatsOnXValidation); if (myGraph != null) myGraph.setMyStats(classifierOneStatsOnXValidation); if (statusPane != null) statusPane.setText("Done!"); //Note that this will be null if GA is run though maybe it is better if i run all sequence with GA and then build the classifier but this would be a waste of time return classifierOne; } catch (Exception e) { e.printStackTrace(); JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE); return null; } }
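xValidateClassifierOneWithNoLocationIndex splits the data with a round-robin rule, (y % folds) == x selecting fold x's test instances, rather than Weka's built-in trainCV/testCV, and it assigns foldClassifier = tempClassifier, so a single classifier object is rebuilt in every fold. The sketch below, under the same pre-3.7 Weka API, isolates that fold loop with a hypothetical method signature; it builds a fresh classifier per fold so the folds stay independent (Classifier.makeCopy of an already-configured template would serve the same purpose). It assumes inst already has its class index set, as in the method above.

import weka.classifiers.Classifier;
import weka.core.Instances;

public class RoundRobinCrossValidation {

    // Returns, per fold, the score of the first class value for each test instance.
    public static double[][] scoreFolds(Instances inst, String classifierName,
            String[] classifierOptions, int folds) throws Exception {
        double[][] posScores = new double[folds][];
        for (int x = 0; x < folds; x++) {
            Instances foldTrain = new Instances(inst, 0); // header copy, keeps class index
            Instances foldTest = new Instances(inst, 0);
            for (int y = 0; y < inst.numInstances(); y++) {
                if (y % folds == x) {
                    foldTest.add(inst.instance(y));   // this instance is for testing
                } else {
                    foldTrain.add(inst.instance(y));  // this instance is for training
                }
            }

            // forName hands the options array to setOptions, which typically blanks out
            // recognised options, so pass a copy on every fold.
            Classifier foldClassifier = Classifier.forName(classifierName,
                    classifierOptions.clone());
            foldClassifier.buildClassifier(foldTrain);

            posScores[x] = new double[foldTest.numInstances()];
            for (int y = 0; y < foldTest.numInstances(); y++) {
                posScores[x][y] = foldClassifier
                        .distributionForInstance(foldTest.instance(y))[0];
            }
        }
        return posScores;
    }
}

Building (or copying) the classifier inside the loop avoids carrying model state from one fold into the next; the per-instance scores can then be written out in the same "pos,0=" / "neg,0=" format the surrounding code feeds to PredictionStats.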