List of usage examples for weka.classifiers.Classifier.distributionForInstance
public double[] distributionForInstance(Instance instance) throws Exception;
From source file:org.opentox.qsar.processors.predictors.SimplePredictor.java
License:Open Source License
/** * Perform the prediction which is based on the serialized model file on the server. * @param data/*from w w w. j a v a 2 s. co m*/ * Input data for with respect to which the predicitons are calculated * @return * A dataset containing the compounds submitted along with their predicted values. * @throws QSARException * In case the prediction (as a whole) is not feasible. If the prediction is not * feasible for a single instance, the prediction is set to <code>?</code> (unknown/undefined/missing). * If the prediction is not feasible for all instances, an exception (QSARException) is thrown. */ @Override public Instances predict(final Instances data) throws QSARException { Instances dataClone = new Instances(data); /** * IMPORTANT! * String attributes have to be removed from the dataset before * applying the prediciton */ dataClone = new AttributeCleanup(ATTRIBUTE_TYPE.string).filter(dataClone); /** * Set the class attribute of the incoming data to any arbitrary attribute * (Choose the last for instance). */ dataClone.setClass(dataClone.attribute(model.getDependentFeature().getURI())); /** * * Create the Instances that will host the predictions. This object contains * only two attributes: the compound_uri and the target feature of the model. 
*/ Instances predictions = null; FastVector attributes = new FastVector(); final Attribute compoundAttribute = new Attribute("compound_uri", (FastVector) null); final Attribute targetAttribute = dataClone.classAttribute(); attributes.addElement(compoundAttribute); attributes.addElement(targetAttribute); predictions = new Instances("predictions", attributes, 0); predictions.setClassIndex(1); Instance predictionInstance = new Instance(2); try { final Classifier cls = (Classifier) SerializationHelper.read(filePath); for (int i = 0; i < data.numInstances(); i++) { try { String currentCompound = data.instance(i).stringValue(0); predictionInstance.setValue(compoundAttribute, currentCompound); if (targetAttribute.type() == Attribute.NUMERIC) { double clsLabel = cls.classifyInstance(dataClone.instance(i)); predictionInstance.setValue(targetAttribute, clsLabel); } else if (targetAttribute.type() == Attribute.NOMINAL) { double[] clsLable = cls.distributionForInstance(dataClone.instance(i)); int indexForNominalElement = maxInArray(clsLable).getPosition(); Enumeration nominalValues = targetAttribute.enumerateValues(); int counter = 0; String nomValue = ""; while (nominalValues.hasMoreElements()) { if (counter == indexForNominalElement) { nomValue = nominalValues.nextElement().toString(); break; } counter++; } predictionInstance.setValue(targetAttribute, nomValue); predictionInstance.setValue(targetAttribute, cls.classifyInstance(dataClone.instance(i))); } predictions.add(predictionInstance); } catch (Exception ex) { System.out.println(ex); } } } catch (Exception ex) { } return predictions; }
From source file:qa.experiment.ProcessFeatureVector.java
public String trainAndPredict(String[] processNames, String question) throws Exception { FastVector fvWekaAttribute = generateWEKAFeatureVector(processNames); Instances trainingSet = new Instances("Rel", fvWekaAttribute, bowFeature.size() + 1); trainingSet.setClassIndex(bowFeature.size()); int cnt = 0;// ww w.j ava 2s.c o m for (int i = 0; i < arrProcessFeature.size(); i++) { String[] names = arrProcessFeature.get(i).getProcessName().split("\\|"); int sim = isNameFuzzyMatch(processNames, names); if (sim != -1) { // System.out.println("match " + arrProcessFeature.get(i).getProcessName()); ArrayList<String> featureVector = arrProcessFeature.get(i).getFeatureVectors(); for (int j = 0; j < featureVector.size(); j++) { Instance trainInstance = new Instance(bowFeature.size() + 1); String[] attrValues = featureVector.get(j).split("\t"); // System.out.println(trainInstance.numAttributes()); // System.out.println(fvWekaAttribute.size()); for (int k = 0; k < bowFeature.size(); k++) { trainInstance.setValue((Attribute) fvWekaAttribute.elementAt(k), Integer.parseInt(attrValues[k])); } trainInstance.setValue((Attribute) fvWekaAttribute.elementAt(bowFeature.size()), processNames[sim]); trainingSet.add(trainInstance); //System.out.println(cnt); cnt++; } } } Classifier cl = new NaiveBayes(); cl.buildClassifier(trainingSet); Instance inst = new Instance(bowFeature.size() + 1); //String[] tokenArr = tokens.toArray(new String[tokens.size()]); for (int j = 0; j < bowFeature.size(); j++) { List<String> tokens = slem.tokenize(question); String[] tokArr = tokens.toArray(new String[tokens.size()]); int freq = getFrequency(bowFeature.get(j), tokArr); inst.setValue((Attribute) fvWekaAttribute.elementAt(j), freq); } inst.setDataset(trainingSet); int idxMax = ArrUtil.getIdxMax(cl.distributionForInstance(inst)); return processNames[idxMax]; }
From source file:se.de.hu_berlin.informatik.faultlocalizer.machinelearn.WekaFaultLocalizer.java
License:Open Source License
@Override public SBFLRanking<T> localize(final ILocalizerCache<T> localizer, ComputationStrategies strategy) { // == 1. Create Weka training instance final List<INode<T>> nodes = new ArrayList<>(localizer.getNodes()); // nominal true/false values final List<String> tf = new ArrayList<>(); tf.add("t");//from w ww.j a v a 2 s .co m tf.add("f"); // create an attribute for each component final Map<INode<T>, Attribute> attributeMap = new HashMap<>(); final ArrayList<Attribute> attributeList = new ArrayList<>(); // NOCS: Weka needs ArrayList.. for (final INode<T> node : nodes) { final Attribute attribute = new Attribute(node.toString(), tf); attributeList.add(attribute); attributeMap.put(node, attribute); } // create class attribute (trace success) final Attribute successAttribute = new Attribute("success", tf); attributeList.add(successAttribute); // create weka training instance final Instances trainingSet = new Instances("TraceInfoInstances", attributeList, 1); trainingSet.setClassIndex(attributeList.size() - 1); // == 2. add traces to training set // add an instance for each trace for (final ITrace<T> trace : localizer.getTraces()) { final Instance instance = new DenseInstance(nodes.size() + 1); instance.setDataset(trainingSet); for (final INode<T> node : nodes) { instance.setValue(attributeMap.get(node), trace.isInvolved(node) ? "t" : "f"); } instance.setValue(successAttribute, trace.isSuccessful() ? "t" : "f"); trainingSet.add(instance); } // == 3. 
use prediction to localize faults // build classifier try { final Classifier classifier = this.buildClassifier(this.classifierName, this.classifierOptions, trainingSet); final SBFLRanking<T> ranking = new SBFLRanking<>(); Log.out(this, "begin classifying"); int classified = 0; final Instance instance = new DenseInstance(nodes.size() + 1); instance.setDataset(trainingSet); for (final INode<T> node : nodes) { instance.setValue(attributeMap.get(node), "f"); } instance.setValue(successAttribute, "f"); for (final INode<T> node : nodes) { classified++; if (classified % 1000 == 0) { Log.out(this, String.format("Classified %d nodes.", classified)); } // contain only the current node in the network instance.setValue(attributeMap.get(node), "t"); // predict with which probability this setup leads to a failing network final double[] distribution = classifier.distributionForInstance(instance); ranking.add(node, distribution[1]); // reset involvment for node instance.setValue(attributeMap.get(node), "f"); } return ranking; } catch (final Exception e) { // NOCS: Weka throws only raw exceptions throw new RuntimeException(e); } }
From source file:se.de.hu_berlin.informatik.stardust.localizer.machinelearn.WekaFaultLocalizer.java
License:Open Source License
@Override public SBFLRanking<T> localize(final ISpectra<T> spectra) { // == 1. Create Weka training instance final List<INode<T>> nodes = new ArrayList<>(spectra.getNodes()); // nominal true/false values final List<String> tf = new ArrayList<String>(); tf.add("t");// w w w .jav a 2 s .c o m tf.add("f"); // create an attribute for each component final Map<INode<T>, Attribute> attributeMap = new HashMap<INode<T>, Attribute>(); final ArrayList<Attribute> attributeList = new ArrayList<Attribute>(); // NOCS: Weka needs ArrayList.. for (final INode<T> node : nodes) { final Attribute attribute = new Attribute(node.toString(), tf); attributeList.add(attribute); attributeMap.put(node, attribute); } // create class attribute (trace success) final Attribute successAttribute = new Attribute("success", tf); attributeList.add(successAttribute); // create weka training instance final Instances trainingSet = new Instances("TraceInfoInstances", attributeList, 1); trainingSet.setClassIndex(attributeList.size() - 1); // == 2. add traces to training set // add an instance for each trace for (final ITrace<T> trace : spectra.getTraces()) { final Instance instance = new DenseInstance(nodes.size() + 1); instance.setDataset(trainingSet); for (final INode<T> node : nodes) { instance.setValue(attributeMap.get(node), trace.isInvolved(node) ? "t" : "f"); } instance.setValue(successAttribute, trace.isSuccessful() ? "t" : "f"); trainingSet.add(instance); } // == 3. 
use prediction to localize faults // build classifier try { final Classifier classifier = this.buildClassifier(this.classifierName, this.classifierOptions, trainingSet); final SBFLRanking<T> ranking = new SBFLRanking<>(); Log.out(this, "begin classifying"); int classified = 0; final Instance instance = new DenseInstance(nodes.size() + 1); instance.setDataset(trainingSet); for (final INode<T> node : nodes) { instance.setValue(attributeMap.get(node), "f"); } instance.setValue(successAttribute, "f"); for (final INode<T> node : nodes) { classified++; if (classified % 1000 == 0) { Log.out(this, String.format("Classified %d nodes.", classified)); } // contain only the current node in the network instance.setValue(attributeMap.get(node), "t"); // predict with which probability this setup leads to a failing network final double[] distribution = classifier.distributionForInstance(instance); ranking.add(node, distribution[1]); // reset involvment for node instance.setValue(attributeMap.get(node), "f"); } return ranking; } catch (final Exception e) { // NOCS: Weka throws only raw exceptions throw new RuntimeException(e); } }
From source file:sg.edu.nus.comp.nlp.ims.classifiers.CWekaEvaluator.java
License:Open Source License
@Override public Object evaluate(Object p_Lexelt) throws Exception { ILexelt lexelt = (ILexelt) p_Lexelt; String lexeltID = lexelt.getID(); IStatistic stat = (IStatistic) this.getStatistic(lexeltID); int type = 2; String firstSense = this.m_UnknownSense; if (stat == null) { type = 1;// w w w . j a va 2 s. c om if (this.m_SenseIndex != null) { String first = this.m_SenseIndex.getFirstSense(lexeltID); if (first != null) { firstSense = first; } } } else { if (stat.getTags().size() == 1) { type = 1; firstSense = stat.getTags().iterator().next(); } else { type = stat.getTags().size(); } } int classIdx = this.m_ClassIndex; CResultInfo retVal = new CResultInfo(); switch (type) { case 0: throw new Exception("no tag for lexelt " + lexeltID + "."); case 1: retVal.lexelt = lexelt.getID(); retVal.docs = new String[lexelt.size()]; retVal.ids = new String[lexelt.size()]; retVal.classes = new String[] { firstSense }; retVal.probabilities = new double[lexelt.size()][1]; for (int i = 0; i < retVal.probabilities.length; i++) { retVal.probabilities[i][0] = 1; retVal.docs[i] = lexelt.getInstanceDocID(i); retVal.ids[i] = lexelt.getInstanceID(i); } break; default: lexelt.setStatistic(stat); Classifier classifier = (Classifier) this.getModel(lexeltID); ILexeltWriter lexeltWriter = new CWekaSparseLexeltWriter(); Instances instances = (Instances) lexeltWriter.getInstances(lexelt); if (classIdx < 0) { classIdx = instances.numAttributes() - 1; } instances.setClassIndex(classIdx); retVal.lexelt = lexelt.getID(); retVal.docs = new String[lexelt.size()]; retVal.ids = new String[lexelt.size()]; retVal.probabilities = new double[instances.numInstances()][]; retVal.classes = new String[instances.classAttribute().numValues()]; for (int i = 0; i < instances.classAttribute().numValues(); i++) { retVal.classes[i] = instances.classAttribute().value(i); } if (instances.classAttribute().isNumeric()) { for (int i = 0; i < instances.numInstances(); i++) { Instance instance = instances.instance(i); 
retVal.docs[i] = lexelt.getInstanceDocID(i); retVal.ids[i] = lexelt.getInstanceID(i); retVal.probabilities[i] = new double[retVal.classes.length]; retVal.probabilities[i][(int) classifier.classifyInstance(instance)] = 1; } } else { for (int i = 0; i < instances.numInstances(); i++) { Instance instance = instances.instance(i); retVal.docs[i] = lexelt.getInstanceDocID(i); retVal.ids[i] = lexelt.getInstanceID(i); retVal.probabilities[i] = classifier.distributionForInstance(instance); } } } return retVal; }
From source file:sirius.clustering.main.ClustererClassificationPane.java
License:Open Source License
private void start() { //Run Classifier if (this.inputDirectoryTextField.getText().length() == 0) { JOptionPane.showMessageDialog(parent, "Please set Input Directory to where the clusterer output are!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return;//w ww . j a v a 2s .c o m } if (m_ClassifierEditor.getValue() == null) { JOptionPane.showMessageDialog(parent, "Please choose Classifier!", "Evaluate Classifier", JOptionPane.ERROR_MESSAGE); return; } if (validateStatsSettings(1) == false) { return; } if (this.clusteringClassificationThread == null) { startButton.setEnabled(false); stopButton.setEnabled(true); tabbedClassifierPane.setSelectedIndex(0); this.clusteringClassificationThread = (new Thread() { public void run() { //Clear the output text area levelOneClassifierOutputTextArea.setText(""); resultsTableModel.reset(); //double threshold = Double.parseDouble(classifierOneThresholdTextField.getText()); //cross-validation int numFolds; if (jackKnifeRadioButton.isSelected()) numFolds = -1; else numFolds = Integer.parseInt(foldsField.getText()); StringTokenizer st = new StringTokenizer(inputDirectoryTextField.getText(), File.separator); String filename = ""; while (st.hasMoreTokens()) { filename = st.nextToken(); } StringTokenizer st2 = new StringTokenizer(filename, "_."); numOfCluster = 0; if (st2.countTokens() >= 2) { st2.nextToken(); String numOfClusterString = st2.nextToken().replaceAll("cluster", ""); try { numOfCluster = Integer.parseInt(numOfClusterString); } catch (NumberFormatException e) { JOptionPane.showMessageDialog(parent, "Please choose the correct file! 
(Output from Utilize Clusterer)", "ERROR", JOptionPane.ERROR_MESSAGE); } } Classifier template = (Classifier) m_ClassifierEditor.getValue(); for (int x = 0; x <= numOfCluster && clusteringClassificationThread != null; x++) {//Test each cluster try { long totalTimeStart = 0, totalTimeElapsed = 0; totalTimeStart = System.currentTimeMillis(); statusLabel.setText("Reading in cluster" + x + " file.."); String inputFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".arff"); String outputScoreFilename = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); BufferedWriter output = new BufferedWriter(new FileWriter(outputScoreFilename)); Instances inst = new Instances(new FileReader(inputFilename)); //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files inst.setClassIndex(inst.numAttributes() - 1); Random random = new Random(1);//Simply set to 1, shall implement the random seed option later inst.randomize(random); if (inst.attribute(inst.classIndex()).isNominal()) inst.stratify(numFolds); // for timing ClassifierResults classifierResults = new ClassifierResults(false, 0); String classifierName = m_ClassifierEditor.getValue().getClass().getName(); classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName); classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ", inputFilename); classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ", "NA"); //ArrayList<Double> resultList = new ArrayList<Double>(); if (jackKnifeRadioButton.isSelected() || numFolds > inst.numInstances() - 1) numFolds = inst.numInstances() - 1; for (int fold = 0; fold < numFolds && clusteringClassificationThread != null; fold++) {//Doing cross-validation statusLabel.setText("Cluster: " + x + " - Training Fold " + (fold + 1) + ".."); Instances 
train = inst.trainCV(numFolds, fold, random); Classifier current = null; try { current = Classifier.makeCopy(template); current.buildClassifier(train); Instances test = inst.testCV(numFolds, fold); statusLabel.setText("Cluster: " + x + " - Testing Fold " + (fold + 1) + ".."); for (int jj = 0; jj < test.numInstances(); jj++) { double[] result = current.distributionForInstance(test.instance(jj)); output.write("Cluster: " + x); output.newLine(); output.newLine(); output.write(test.instance(jj).stringValue(test.classAttribute()) + ",0=" + result[0]); output.newLine(); } } catch (Exception ex) { ex.printStackTrace(); statusLabel.setText("Error in cross-validation!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } output.close(); totalTimeElapsed = System.currentTimeMillis() - totalTimeStart; classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ", Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes " + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds"); double threshold = validateFieldAsThreshold(classifierOneThresholdTextField.getText(), "Threshold Field", classifierOneThresholdTextField); String filename2 = inputDirectoryTextField.getText() .replaceAll("_cluster" + numOfCluster + ".arff", "_cluster" + x + ".score"); PredictionStats classifierStats = new PredictionStats(filename2, 0, threshold); resultsTableModel.add("Cluster " + x, classifierResults, classifierStats); resultsTable.setRowSelectionInterval(x, x); computeStats(numFolds);//compute and display the results } catch (Exception e) { e.printStackTrace(); statusLabel.setText("Error in reading file!"); startButton.setEnabled(true); stopButton.setEnabled(false); } } //end of cluster for loop resultsTableModel.add("Summary - Equal Weightage", null, null); resultsTable.setRowSelectionInterval(numOfCluster + 1, numOfCluster + 1); computeStats(numFolds); resultsTableModel.add("Summary - Weighted Average", null, null); 
resultsTable.setRowSelectionInterval(numOfCluster + 2, numOfCluster + 2); computeStats(numFolds); if (clusteringClassificationThread != null) statusLabel.setText("Done!"); else statusLabel.setText("Interrupted.."); startButton.setEnabled(true); stopButton.setEnabled(false); if (classifierOne != null) { levelOneClassifierOutputScrollPane.getVerticalScrollBar() .setValue(levelOneClassifierOutputScrollPane.getVerticalScrollBar().getMaximum()); } clusteringClassificationThread = null; } }); this.clusteringClassificationThread.setPriority(Thread.MIN_PRIORITY); this.clusteringClassificationThread.start(); } else { JOptionPane.showMessageDialog(parent, "Cannot start new job as previous job still running. Click stop to terminate previous job", "ERROR", JOptionPane.ERROR_MESSAGE); } }
From source file:sirius.predictor.main.PredictorFrame.java
License:Open Source License
private void runType3Classifier(ClassifierData classifierData) { /*//from ww w .j a v a 2 s . c o m * This is for type3 classifier * Note that all position and motif list only does not apply to this classifier as * it will only give one score for each sequence */ if (sequenceNameTableModel.getRowCount() < 1) { JOptionPane.showMessageDialog(this, "Please load File first!", "No Sequence", JOptionPane.INFORMATION_MESSAGE); return; } if (loadFastaFileMenuItem.getState() == false) { JOptionPane.showMessageDialog(this, "Please load Fasta File! Currently, you have score file!", "Wrong File Format", JOptionPane.INFORMATION_MESSAGE); return; } if (onAllPositionsMenuItem.getState() == false) { JOptionPane.showMessageDialog(this, "For type 3 classifier, it make only one prediction a sequence", "Information", JOptionPane.INFORMATION_MESSAGE); } try { BufferedWriter output = new BufferedWriter(new FileWriter( outputDirectory + File.separator + "classifierone_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores")); Classifier classifierOne = classifierData.getClassifierOne(); //Reading and Storing the featureList Instances inst = classifierData.getInstances(); ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } //Going through each and every sequence for (int x = 0; x < sequenceNameTableModel.getRowCount(); x++) { if (stopClassifier == true) { statusPane.setText("Running of Classifier Stopped!"); stopClassifier = false; output.close(); return; } //if(x%100 == 0) statusPane.setText("Running " + classifierData.getClassifierName() + " - ClassifierOne @ " + x + " / " + sequenceNameTableModel.getRowCount()); //Header output.write(sequenceNameTableModel.getHeader(x)); output.newLine(); 
output.write(sequenceNameTableModel.getSequence(x)); output.newLine(); //Sequence Score -> index-score, index-score String sequence = sequenceNameTableModel.getSequence(x); Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int z = 0; z < inst.numAttributes() - 1; z++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount("+1_Index(-1)", sequence, featureDataArrayList.get(z), classifierData.getScoringMatrixIndex(), classifierData.getCountingStyleIndex(), classifierData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(z, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(z, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(z, (String) obj); else { output.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } //note that pos or neg does not matter as this is not used tempInst.setValue(inst.numAttributes() - 1, "pos"); try { double[] results = classifierOne.distributionForInstance(tempInst); output.write("0=" + results[0]); } catch (Exception e) { //this is to ensure that the run will continue output.write("0=-0.0"); //change throw error to screen output if i want the run to continue System.err .println("Exception has Occurred for classifierOne.distributionForInstance(tempInst);"); } output.newLine(); output.flush(); } output.flush(); output.close(); statusPane.setText("ClassifierOne finished running..."); loadScoreFile(outputDirectory + File.separator + "classifierone_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores"); } catch (Exception e) { JOptionPane.showMessageDialog(null, "Exception Occured", "Error", JOptionPane.ERROR_MESSAGE); e.printStackTrace(); } }
From source file:sirius.predictor.main.PredictorFrame.java
License:Open Source License
private void runClassifier(ClassifierData classifierData, boolean allPositions) { //this method is for type 1 classifier with all positions and motif list //and type 2 classifier with all positions if (sequenceNameTableModel.getRowCount() < 1) { JOptionPane.showMessageDialog(this, "Please load File first!", "No Sequence", JOptionPane.INFORMATION_MESSAGE); return;// w ww.j a v a 2s. co m } if (loadFastaFileMenuItem.getState() == false) { JOptionPane.showMessageDialog(this, "Please load Fasta File! Currently, you have score file!", "Wrong File Format", JOptionPane.INFORMATION_MESSAGE); return; } if (onAllPositionsMenuItem.getState() == false && motifListTableModel.getSize() == 0) { JOptionPane.showMessageDialog(this, "There are no Motifs chosen in Motif List!", "No Motifs", JOptionPane.INFORMATION_MESSAGE); MotifListDialog dialog = new MotifListDialog(motifListTableModel); dialog.setLocationRelativeTo(this); dialog.setVisible(true); return; } while (outputDirectory == null) { JOptionPane.showMessageDialog(this, "Please set output directory first!", "Output Directory not set", JOptionPane.INFORMATION_MESSAGE); setOutputDirectory(); //return; } try { BufferedWriter output = new BufferedWriter(new FileWriter( outputDirectory + File.separator + "classifierone_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores")); Classifier classifierOne = classifierData.getClassifierOne(); int leftMostPosition = classifierData.getLeftMostPosition(); int rightMostPosition = classifierData.getRightMostPosition(); //Reading and Storing the featureList Instances inst = classifierData.getInstances(); ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } for (int x = 0; x < sequenceNameTableModel.getRowCount(); x++) { if 
(stopClassifier == true) { statusPane.setText("Running of Classifier Stopped!"); stopClassifier = false; output.close(); return; } //if(x%100 == 0) statusPane.setText("Running " + classifierData.getClassifierName() + " - ClassifierOne @ " + x + " / " + sequenceNameTableModel.getRowCount()); //Header output.write(sequenceNameTableModel.getHeader(x)); output.newLine(); output.write(sequenceNameTableModel.getSequence(x)); output.newLine(); //Sequence Score -> index-score, index-score String sequence = sequenceNameTableModel.getSequence(x); int minSequenceLengthRequired; int targetLocationIndex; if (leftMostPosition < 0 && rightMostPosition > 0) {// -ve and +ve minSequenceLengthRequired = (leftMostPosition * -1) + rightMostPosition; targetLocationIndex = (leftMostPosition * -1); } else if (leftMostPosition < 0 && rightMostPosition < 0) {//-ve and -ve minSequenceLengthRequired = rightMostPosition - leftMostPosition + 1; targetLocationIndex = (leftMostPosition * -1); } else {//+ve and +ve minSequenceLengthRequired = rightMostPosition - leftMostPosition + 1; targetLocationIndex = (leftMostPosition * -1); } boolean firstEntryForClassifierOne = true; for (int y = 0; y + (minSequenceLengthRequired - 1) < sequence.length(); y++) { //Check if targetLocation match any motif in motif List if (allPositions == false && motifListTableModel .gotMotifMatch(sequence.substring(y + 0, y + targetLocationIndex)) == false) continue; String line2 = sequence.substring(y + 0, y + minSequenceLengthRequired); Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int z = 0; z < inst.numAttributes() - 1; z++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount("+1_Index(" + targetLocationIndex + ")", line2, featureDataArrayList.get(z), classifierData.getScoringMatrixIndex(), classifierData.getCountingStyleIndex(), 
classifierData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(z, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(z, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(z, (String) obj); else { output.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } //note that pos or neg does not matter as this is not used tempInst.setValue(inst.numAttributes() - 1, "neg"); double[] results = classifierOne.distributionForInstance(tempInst); if (firstEntryForClassifierOne) firstEntryForClassifierOne = false; else output.write(","); output.write(y + targetLocationIndex + "=" + results[0]); } output.newLine(); output.flush(); } output.flush(); output.close(); statusPane.setText("ClassifierOne finished running..."); //Run classifier Two if it is type 2 if (classifierData.getClassifierType() == 2) { BufferedWriter output2 = new BufferedWriter(new FileWriter( outputDirectory + File.separator + "classifiertwo_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores")); BufferedReader input2 = new BufferedReader(new FileReader( outputDirectory + File.separator + "classifierone_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores")); Classifier classifierTwo = classifierData.getClassifierTwo(); Instances inst2 = classifierData.getInstances2(); int setUpstream = classifierData.getSetUpstream(); int setDownstream = classifierData.getSetDownstream(); int minScoreWindowRequired; if (setUpstream < 0 && setDownstream < 0) {//-ve and -ve minScoreWindowRequired = setDownstream - setUpstream + 1; } else if (setUpstream < 0 && setDownstream > 0) {//-ve and +ve minScoreWindowRequired = (setUpstream * -1) + setDownstream; } else {//+ve and +ve minScoreWindowRequired = setDownstream - setUpstream 
+ 1; } String lineHeader; String lineSequence; int lineCounter2 = 0; while ((lineHeader = input2.readLine()) != null) { if (stopClassifier == true) { statusPane.setText("Running of Classifier Stopped!"); stopClassifier = false; output2.close(); input2.close(); return; } //if(lineCounter2%100 == 0) statusPane.setText("Running " + classifierData.getClassifierName() + " - ClassifierTwo @ " + lineCounter2 + " / " + sequenceNameTableModel.getRowCount()); lineSequence = input2.readLine(); output2.write(lineHeader); output2.newLine(); output2.write(lineSequence); output2.newLine(); StringTokenizer locationScore = new StringTokenizer(input2.readLine(), ","); int totalTokens = locationScore.countTokens(); String[][] scores = new String[totalTokens][2]; int scoreIndex = 0; while (locationScore.hasMoreTokens()) { StringTokenizer locationScoreToken = new StringTokenizer(locationScore.nextToken(), "="); scores[scoreIndex][0] = locationScoreToken.nextToken();//location scores[scoreIndex][1] = locationScoreToken.nextToken();//score scoreIndex++; } int targetLocationIndex2; if (setUpstream == 0 || setDownstream == 0) { output2.close(); input2.close(); throw new Exception("setUpstream == 0 || setDownstream == 0"); } if (setUpstream < 0) { targetLocationIndex2 = Integer.parseInt(scores[0][0]) + (-setUpstream); } else {//setUpstream > 0 targetLocationIndex2 = Integer.parseInt(scores[0][0]); //first location } for (int x = 0; x + minScoreWindowRequired - 1 < totalTokens; x++) { //+1 is for the class index if (x != 0) output2.write(","); Instance tempInst2 = new Instance(minScoreWindowRequired + 1); tempInst2.setDataset(inst2); for (int y = 0; y < minScoreWindowRequired; y++) { tempInst2.setValue(y, Double.parseDouble(scores[x + y][1])); } tempInst2.setValue(tempInst2.numAttributes() - 1, "pos"); double[] results = classifierTwo.distributionForInstance(tempInst2); output2.write(targetLocationIndex2 + "=" + results[0]); targetLocationIndex2++; } lineCounter2++; output2.newLine(); } 
input2.close(); output2.close(); statusPane.setText("ClassifierTwo finished running..."); } if (classifierData.getClassifierType() == 1) loadScoreFile( outputDirectory + File.separator + "classifierone_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores"); else loadScoreFile( outputDirectory + File.separator + "classifiertwo_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores"); } catch (Exception e) { JOptionPane.showMessageDialog(null, "Exception Occured", "Error", JOptionPane.ERROR_MESSAGE); e.printStackTrace(); } }
From source file:sirius.predictor.main.PredictorFrame.java
License:Open Source License
private void runType2ClassifierWithMotifList(ClassifierData classifierData) { //Checking.. if (sequenceNameTableModel.getRowCount() < 1) { JOptionPane.showMessageDialog(this, "Please load File first!", "No Sequence", JOptionPane.INFORMATION_MESSAGE); return;/*w w w .j a v a 2s . c om*/ } if (loadFastaFileMenuItem.getState() == false) { JOptionPane.showMessageDialog(this, "Please load Fasta File! Currently, you have score file!", "Wrong File Format", JOptionPane.INFORMATION_MESSAGE); return; } if (motifListTableModel.getSize() == 0) { JOptionPane.showMessageDialog(this, "There are no Motifs chosen in Motif List!", "No Motifs", JOptionPane.INFORMATION_MESSAGE); MotifListDialog dialog = new MotifListDialog(motifListTableModel); dialog.setLocationRelativeTo(this); dialog.setVisible(true); return; } //Proper running start try { //classifierOne score output BufferedWriter output = new BufferedWriter(new FileWriter( outputDirectory + File.separator + "classifierone_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores")); Classifier classifierOne = classifierData.getClassifierOne(); int leftMostPosition = classifierData.getLeftMostPosition(); int rightMostPosition = classifierData.getRightMostPosition(); //Reading and Storing the featureList Instances inst = classifierData.getInstances(); ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } //initialization for type 2 classifier BufferedWriter output2 = new BufferedWriter(new FileWriter( outputDirectory + File.separator + "classifiertwo_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores")); int setUpstream = classifierData.getSetUpstream(); int setDownstream = classifierData.getSetDownstream(); int 
minScoreWindowRequired; if (setUpstream < 0 && setDownstream < 0) {//-ve and -ve minScoreWindowRequired = setDownstream - setUpstream + 1; } else if (setUpstream < 0 && setDownstream > 0) {//-ve and +ve minScoreWindowRequired = (setUpstream * -1) + setDownstream; } else {//+ve and +ve minScoreWindowRequired = setDownstream - setUpstream + 1; } Classifier classifierTwo = classifierData.getClassifierTwo(); Instances inst2 = classifierData.getInstances2(); if (setUpstream == 0 || setDownstream == 0) { output.close(); output2.close(); throw new Exception("setUpstream == 0 || setDownstream == 0"); } //for each sequence for (int x = 0; x < sequenceNameTableModel.getRowCount(); x++) { if (stopClassifier == true) { statusPane.setText("Running of Classifier Stopped!"); stopClassifier = false; output.close(); output2.close(); return; } //if(x%100 == 0) statusPane.setText("Running " + classifierData.getClassifierName() + " - ClassifierOne @ " + x + " / " + sequenceNameTableModel.getRowCount()); //Header output.write(sequenceNameTableModel.getHeader(x)); output.newLine(); output.write(sequenceNameTableModel.getSequence(x)); output.newLine(); output2.write(sequenceNameTableModel.getHeader(x)); output2.newLine(); output2.write(sequenceNameTableModel.getSequence(x)); output2.newLine(); //Sequence Score -> index-score, index-score String sequence = sequenceNameTableModel.getSequence(x); int minSequenceLengthRequired; int targetLocationIndex; //set the targetLocationIndex and minSequenceLengthRequired if (leftMostPosition < 0 && rightMostPosition > 0) {// -ve and +ve minSequenceLengthRequired = (leftMostPosition * -1) + rightMostPosition; targetLocationIndex = (leftMostPosition * -1); } else if (leftMostPosition < 0 && rightMostPosition < 0) {//-ve and -ve minSequenceLengthRequired = rightMostPosition - leftMostPosition + 1; targetLocationIndex = (leftMostPosition * -1); } else {//+ve and +ve minSequenceLengthRequired = rightMostPosition - leftMostPosition + 1; targetLocationIndex 
= (leftMostPosition * -1); } //This hashtable is used to ensure that on positions where predictions are already made, //we just skip. This will happen only if it is a type 2 classifier Hashtable<Integer, Double> scoreTable = new Hashtable<Integer, Double>(); boolean firstEntryForClassifierOne = true; boolean firstEntryForClassifierTwo = true; for (int y = 0; y + (minSequenceLengthRequired - 1) < sequence.length(); y++) { int endPoint = y;//endPoint should be the exact position int currentY = y; int startPoint = y; //run only on Motifs? if (onMotifsOnlyMenuItem.getState()) { //Check if targetLocation match any motif in motif List if (motifListTableModel .gotMotifMatch(sequence.substring(y + 0, y + targetLocationIndex)) == false) continue; //position not found in motif list else //rollback to upstream and make prediction all the way till downstream //needed for type 2 classifier currentY += setUpstream; if (setUpstream > 0) currentY--; startPoint = currentY; //note that y starts from 0 so y is surely >= 0 endPoint += setDownstream; if (setDownstream > 0) endPoint--; //check still within bound of the sequence if (startPoint < 0 || endPoint >= sequence.length() - (minSequenceLengthRequired - 1)) continue;//out of bounds } while (currentY <= endPoint) { if (scoreTable.get(currentY + targetLocationIndex) != null) { currentY++; continue; } String line2 = sequence.substring(currentY + 0, currentY + minSequenceLengthRequired); Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int z = 0; z < inst.numAttributes() - 1; z++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount("+1_Index(" + targetLocationIndex + ")", line2, featureDataArrayList.get(z), classifierData.getScoringMatrixIndex(), classifierData.getCountingStyleIndex(), classifierData.getScoringMatrix()); if 
(obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(z, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(z, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(z, (String) obj); else { output.close(); output2.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } //note that pos or neg does not matter as this is not used tempInst.setValue(inst.numAttributes() - 1, "neg"); double[] results = classifierOne.distributionForInstance(tempInst); if (firstEntryForClassifierOne) firstEntryForClassifierOne = false; else output.write(","); output.write(currentY + targetLocationIndex + "=" + results[0]); scoreTable.put(currentY + targetLocationIndex, results[0]); currentY++; } Instance tempInst2 = new Instance(minScoreWindowRequired + 1);//+1 for class attribute tempInst2.setDataset(inst2); int indexForClassifier2Inst = 0; for (int z = startPoint; z <= endPoint; z++) { tempInst2.setValue(indexForClassifier2Inst, scoreTable.get(targetLocationIndex + z)); indexForClassifier2Inst++; } //note that pos or neg does not matter as this is not used tempInst2.setValue(tempInst2.numAttributes() - 1, "pos"); double[] results = classifierTwo.distributionForInstance(tempInst2); if (firstEntryForClassifierTwo == true) firstEntryForClassifierTwo = false; else output2.write(","); output2.write(y + targetLocationIndex + "=" + results[0]); } //end of for loop output2.newLine(); output2.flush(); output.newLine(); output.flush(); } output.close(); output2.close(); statusPane.setText("Classifier Finished running..."); loadScoreFile(outputDirectory + File.separator + "classifiertwo_" + classifierData.getClassifierName() + "_" + classifierData.getClassifierType() + "_" + fastaFilename + ".scores"); } catch (Exception e) { JOptionPane.showMessageDialog(null, "Exception Occured", "Error", JOptionPane.ERROR_MESSAGE); e.printStackTrace(); } }
From source file: sirius.trainer.step4.DatasetGenerator.java
License: Open Source License
public static boolean generateDataset2(JInternalFrame parent, ApplicationData applicationData, int classifierTwoUpstream, int classifierTwoDownstream, Classifier classifierOne) { try {//from w w w. j a va 2 s .c o m StatusPane statusPane = applicationData.getStatusPane(); int positiveDataset2FromInt = applicationData.getPositiveDataset2FromField(); int positiveDataset2ToInt = applicationData.getPositiveDataset2ToField(); int negativeDataset2FromInt = applicationData.getNegativeDataset2FromField(); int negativeDataset2ToInt = applicationData.getNegativeDataset2ToField(); int totalDataset2PositiveInstances = positiveDataset2ToInt - positiveDataset2FromInt + 1; int totalDataset2NegativeInstances = negativeDataset2ToInt - negativeDataset2FromInt + 1; int totalDataset2Instances = totalDataset2PositiveInstances + totalDataset2NegativeInstances; int scoringMatrixIndex = applicationData.getScoringMatrixIndex(); int countingStyleIndex = applicationData.getCountingStyleIndex(); //Generate the header for Dataset2.arff BufferedWriter dataset2OutputFile = new BufferedWriter( new FileWriter(applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff")); dataset2OutputFile.write("@relation 'Dataset2.arff' "); dataset2OutputFile.newLine(); dataset2OutputFile.newLine(); dataset2OutputFile.flush(); for (int x = classifierTwoUpstream; x <= classifierTwoDownstream; x++) { if (x != 0) {//This statment is used because in sequence position only -1,+1 dun have 0 dataset2OutputFile.write("@attribute (" + x + ") numeric"); dataset2OutputFile.newLine(); dataset2OutputFile.flush(); } } if (positiveDataset2FromInt > 0 && negativeDataset2FromInt > 0) dataset2OutputFile.write("@attribute Class {pos,neg}"); else if (positiveDataset2FromInt > 0 && negativeDataset2FromInt == 0) dataset2OutputFile.write("@attribute Class {pos}"); else if (positiveDataset2FromInt == 0 && negativeDataset2FromInt > 0) dataset2OutputFile.write("@attribute Class {neg}"); dataset2OutputFile.newLine(); 
dataset2OutputFile.newLine(); dataset2OutputFile.write("@data"); dataset2OutputFile.newLine(); dataset2OutputFile.newLine(); dataset2OutputFile.flush(); //Generating an Instance given a sequence with the current attributes //for dataset2.arff //Need this for parameter setting for tempInst Instances inst = applicationData.getDataset1Instances(); inst.deleteAttributeType(Attribute.STRING); FastaFileManipulation fastaFile = new FastaFileManipulation( applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(), positiveDataset2FromInt, positiveDataset2ToInt, negativeDataset2FromInt, negativeDataset2ToInt, applicationData.getWorkingDirectory()); //Reading and Storing the featureList ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>(); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute must be ignored featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name())); } //Reading the fastaFile int lineCounter = 0; String _class = "pos"; FastaFormat fastaFormat; while ((fastaFormat = fastaFile.nextSequence(_class)) != null) { if (applicationData.terminateThread == true) { statusPane.setText("Interrupted - Classifier Two Training Not Complete"); dataset2OutputFile.close(); return false; } lineCounter++;//Putting it here will mean if lineCounter is x then line == sequence x //if((lineCounter % 100) == 0){ dataset2OutputFile.flush(); statusPane.setText("Generating Dataset2.arff.. 
@ " + lineCounter + " / " + totalDataset2Instances + " Sequences"); //} //For each sequence, you want to shift from upstream till downstream //ie changing the +1 location //to get the scores given by classifier one so that you can use it to train classifier two later //Doing shift from upstream till downstream SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), classifierTwoUpstream, classifierTwoDownstream); String line2; while ((line2 = seq.nextShift()) != null) { Instance tempInst; tempInst = new Instance(inst.numAttributes()); tempInst.setDataset(inst); for (int x = 0; x < inst.numAttributes() - 1; x++) { //-1 because class attribute can be ignored //Give the sequence and the featureList to get the feature freqs on the sequence Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2, featureDataArrayList.get(x), scoringMatrixIndex, countingStyleIndex, applicationData.getScoringMatrix()); if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer")) tempInst.setValue(x, (Integer) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double")) tempInst.setValue(x, (Double) obj); else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String")) tempInst.setValue(x, (String) obj); else { dataset2OutputFile.close(); throw new Error("Unknown: " + obj.getClass().getName()); } } tempInst.setValue(inst.numAttributes() - 1, _class); double[] results = classifierOne.distributionForInstance(tempInst); dataset2OutputFile.write("" + results[0] + ","); } dataset2OutputFile.write(_class); dataset2OutputFile.newLine(); if (lineCounter == totalDataset2PositiveInstances) _class = "neg"; } dataset2OutputFile.close(); fastaFile.cleanUp(); } catch (Exception e) { e.printStackTrace(); JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE); applicationData.getStatusPane().setText("Error - Classifier Two Training Not Complete"); return false; } return true; }