Example usage for weka.classifiers Classifier distributionForInstance

List of usage examples for weka.classifiers Classifier distributionForInstance

Introduction

On this page you can find example usage for weka.classifiers Classifier distributionForInstance.

Prototype

public double[] distributionForInstance(Instance instance) throws Exception;

Document

Predicts the class memberships for a given instance.
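
A minimal, self-contained sketch of calling distributionForInstance is shown below. It is only an illustration: "weather.arff" is a placeholder path for any ARFF file whose last attribute is a nominal class, and J48 is chosen arbitrarily as the classifier. The returned array contains one probability per class value, in the order declared by the class attribute.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;

public class DistributionForInstanceExample {
    public static void main(String[] args) throws Exception {
        //Load a dataset and mark its last attribute as the class
        //("weather.arff" is only a placeholder for this sketch)
        Instances data = new Instances(new BufferedReader(new FileReader("weather.arff")));
        data.setClassIndex(data.numAttributes() - 1);

        //Any Weka classifier works here; J48 is used purely as an example
        Classifier classifier = new J48();
        classifier.buildClassifier(data);

        //Predict the class membership probabilities for the first instance
        Instance instance = data.instance(0);
        double[] distribution = classifier.distributionForInstance(instance);
        for (int i = 0; i < distribution.length; i++) {
            //distribution[i] is the probability of the i-th class value
            System.out.println(data.classAttribute().value(i) + " = " + distribution[i]);
        }
    }
}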

Usage

From source file: sirius.trainer.step4.RunClassifier.java

License: Open Source License

public static Classifier startClassifierOne(JInternalFrame parent, ApplicationData applicationData,
        JTextArea classifierOneDisplayTextArea, GenericObjectEditor m_ClassifierEditor, GraphPane myGraph,
        boolean test, ClassifierResults classifierResults, int range, double threshold) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        //Setting up training dataset 1 for classifier one
        statusPane.setText("Setting up...");
        //Load Dataset1 Instances
        Instances inst = new Instances(applicationData.getDataset1Instances());
        inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        applicationData.getDataset1Instances()
                .setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        // for timing
        long trainTimeStart = 0, trainTimeElapsed = 0;
        Classifier classifierOne = (Classifier) m_ClassifierEditor.getValue();
        statusPane.setText("Training Classifier One... May take a while... Please wait...");
        trainTimeStart = System.currentTimeMillis();
        inst.deleteAttributeType(Attribute.STRING);
        classifierOne.buildClassifier(inst);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;

        String classifierName = m_ClassifierEditor.getValue().getClass().getName();
        classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
        classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                applicationData.getWorkingDirectory() + File.separator + "Dataset1.arff");
        classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");

        if (test == false) {
            statusPane.setText("Classifier One Training Completed...Done...");
            return classifierOne;
        }
        if (applicationData.terminateThread == true) {
            statusPane.setText("Interrupted - Classifier One Training Completed");
            return classifierOne;
        }
        //Running classifier one on dataset3
        if (statusPane != null)
            statusPane.setText("Running ClassifierOne on Dataset 3..");
        //Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        //Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();   
        int positiveDataset3FromInt = applicationData.getPositiveDataset3FromField();
        int positiveDataset3ToInt = applicationData.getPositiveDataset3ToField();
        int negativeDataset3FromInt = applicationData.getNegativeDataset3FromField();
        int negativeDataset3ToInt = applicationData.getNegativeDataset3ToField();

        //Generate the header for ClassifierOne.scores on Dataset3                
        BufferedWriter dataset3OutputFile = new BufferedWriter(new FileWriter(
                applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores"));
        if (m_ClassifierEditor.getValue() instanceof OptionHandler)
            classifierName += " "
                    + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions());

        FastaFileManipulation fastaFile = new FastaFileManipulation(
                applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(),
                positiveDataset3FromInt, positiveDataset3ToInt, negativeDataset3FromInt, negativeDataset3ToInt,
                applicationData.getWorkingDirectory());

        //Reading and Storing the featureList
        ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>();
        for (int x = 0; x < inst.numAttributes() - 1; x++) {
            //-1 because class attribute must be ignored
            featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name()));
        }

        //Reading the fastaFile         
        int lineCounter = 0;
        String _class = "pos";
        int totalDataset3PositiveInstances = positiveDataset3ToInt - positiveDataset3FromInt + 1;
        FastaFormat fastaFormat;
        while ((fastaFormat = fastaFile.nextSequence(_class)) != null) {
            if (applicationData.terminateThread == true) {
                statusPane.setText("Interrupted - Classifier One Training Completed");
                dataset3OutputFile.close();
                return classifierOne;
            }
            lineCounter++;//Putting it here will mean if lineCounter is x then line == sequence x
            dataset3OutputFile.write(fastaFormat.getHeader());
            dataset3OutputFile.newLine();
            dataset3OutputFile.write(fastaFormat.getSequence());
            dataset3OutputFile.newLine();
            //if((lineCounter % 100) == 0){                                 
            statusPane.setText("Running Classifier One on Dataset 3.. @ " + lineCounter + " / "
                    + applicationData.getTotalSequences(3) + " Sequences");
            //}

            // for +1 index being -1, only make one prediction for the whole sequence             
            if (fastaFormat.getIndexLocation() == -1) {
                //Should not have reached here...
                dataset3OutputFile.close();
                throw new Exception("SHOULD NOT HAVE REACHED HERE!!");
            } else {// for +1 index being non -1, make prediction on every possible position
                    //For each sequence, you want to shift from predictPositionFrom till predictPositionTo
                    //ie changing the +1 location
                    //to get the scores given by classifier one so that 
                    //you can use it to train classifier two later
                    //Doing shift from predictPositionFrom till predictPositionTo                
                int predictPosition[];
                predictPosition = fastaFormat.getPredictPositionForClassifierOne(
                        applicationData.getLeftMostPosition(), applicationData.getRightMostPosition());

                SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(),
                        predictPosition[0], predictPosition[1]);
                String line2;
                int currentPosition = predictPosition[0];
                dataset3OutputFile.write(_class);
                while ((line2 = seq.nextShift()) != null) {
                    Instance tempInst;
                    tempInst = new Instance(inst.numAttributes());
                    tempInst.setDataset(inst);
                    for (int x = 0; x < inst.numAttributes() - 1; x++) {
                        //-1 because class attribute can be ignored
                        //Give the sequence and the featureList to get the feature freqs on the sequence
                        Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2,
                                featureDataArrayList.get(x), applicationData.getScoringMatrixIndex(),
                                applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix());
                        if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                            tempInst.setValue(x, (Integer) obj);
                        else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                            tempInst.setValue(x, (Double) obj);
                        else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                            tempInst.setValue(x, (String) obj);
                        else {
                            dataset3OutputFile.close();
                            throw new Error("Unknown: " + obj.getClass().getName());
                        }
                    }
                    tempInst.setValue(inst.numAttributes() - 1, _class);
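                    //distributionForInstance returns the class membership probabilities;
                    //results[0] below is the probability of the first class value ("pos" here)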
                    double[] results = classifierOne.distributionForInstance(tempInst);
                    dataset3OutputFile.write("," + currentPosition + "=" + results[0]);
                    //AHFU_DEBUG 
                    /*if(currentPosition >= setClassifierTwoUpstreamInt && currentPosition <= setClassifierTwoDownstreamInt)
                       testClassifierTwoArff.write(results[0] + ",");*/
                    //AHFU_DEBUG_END
                    currentPosition++;
                    if (currentPosition == 0)
                        currentPosition++;
                } // end of while((line2 = seq.nextShift())!=null) 
                  //AHFU_DEBUG
                  /*testClassifierTwoArff.write(_class);
                  testClassifierTwoArff.newLine();
                  testClassifierTwoArff.flush();*/
                  //AHFU_DEBUG_END
                dataset3OutputFile.newLine();
                dataset3OutputFile.flush();
                if (lineCounter == totalDataset3PositiveInstances)
                    _class = "neg";
            } //end of inside non -1                                  
        } // end of while((fastaFormat = fastaFile.nextSequence(_class))!=null)       
        dataset3OutputFile.close();
        PredictionStats classifierOneStatsOnBlindTest = new PredictionStats(
                applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores", range,
                threshold);
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                        + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
        classifierOneStatsOnBlindTest.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        applicationData.setClassifierOneStats(classifierOneStatsOnBlindTest);
        myGraph.setMyStats(classifierOneStatsOnBlindTest);
        statusPane.setText("Done!");
        fastaFile.cleanUp();
        return classifierOne;
    } catch (Exception ex) {
        ex.printStackTrace();
        JOptionPane.showMessageDialog(parent, ex.getMessage() + "Classifier One on Blind Test Set",
                "Evaluate classifier", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}

From source file: sirius.trainer.step4.RunClassifier.java

License: Open Source License

public static Classifier startClassifierTwo(JInternalFrame parent, ApplicationData applicationData,
        JTextArea classifierTwoDisplayTextArea, GenericObjectEditor m_ClassifierEditor2,
        Classifier classifierOne, GraphPane myGraph, boolean test, ClassifierResults classifierResults,
        int range, double threshold) {
    int arraySize = 0;
    int lineCount = 0;
    try {
        StatusPane statusPane = applicationData.getStatusPane();
        //Initialising      
        long totalTimeStart = System.currentTimeMillis();
        Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();
        int positiveDataset3FromInt = applicationData.getPositiveDataset3FromField();
        int positiveDataset3ToInt = applicationData.getPositiveDataset3ToField();
        int negativeDataset3FromInt = applicationData.getNegativeDataset3FromField();
        int negativeDataset3ToInt = applicationData.getNegativeDataset3ToField();

        //Preparing Dataset2.arff to train Classifier Two
        statusPane.setText("Preparing Dataset2.arff...");
        //This step generates Dataset2.arff
        if (DatasetGenerator.generateDataset2(parent, applicationData, applicationData.getSetUpstream(),
                applicationData.getSetDownstream(), classifierOne) == false) {
            //Interrupted or Error occurred
            return null;
        }

        //Training Classifier Two
        statusPane.setText("Training Classifier Two... May take a while... Please wait...");
        Instances inst2 = new Instances(new BufferedReader(
                new FileReader(applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff")));
        inst2.setClassIndex(inst2.numAttributes() - 1);
        long trainTimeStart = 0;
        long trainTimeElapsed = 0;

        Classifier classifierTwo = (Classifier) m_ClassifierEditor2.getValue();
        trainTimeStart = System.currentTimeMillis();
        applicationData.setDataset2Instances(inst2);
        classifierTwo.buildClassifier(inst2);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;

        //Running Classifier Two   
        String classifierName = m_ClassifierEditor2.getValue().getClass().getName();
        classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
        classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff");
        classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");

        if (test == false) {
            statusPane.setText("Classifier Two Trained...Done...");
            return classifierTwo;
        }
        if (applicationData.terminateThread == true) {
            statusPane.setText("Interrupted - Classifier One Training Completed");
            return classifierTwo;
        }
        statusPane.setText("Running Classifier Two on Dataset 3...");

        //Generate the header for ClassifierTwo.scores on Dataset3            
        BufferedWriter classifierTwoOutput = new BufferedWriter(new FileWriter(
                applicationData.getWorkingDirectory() + File.separator + "ClassifierTwo.scores"));
        if (m_ClassifierEditor2.getValue() instanceof OptionHandler)
            classifierName += " "
                    + Utils.joinOptions(((OptionHandler) m_ClassifierEditor2.getValue()).getOptions());

        //Generating an Instance given a sequence with the current attributes
        int setClassifierTwoUpstreamInt = applicationData.getSetUpstream();
        int setClassifierTwoDownstreamInt = applicationData.getSetDownstream();
        int classifierTwoWindowSize;
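        //The window size counts positions from upstream to downstream inclusive,
        //skipping position 0 (the coordinate system has no position 0, as the
        //currentPosition skip further below shows)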
        if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt > 0)
            classifierTwoWindowSize = (setClassifierTwoUpstreamInt * -1) + setClassifierTwoDownstreamInt;
        else if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt < 0)
            classifierTwoWindowSize = (setClassifierTwoUpstreamInt - setClassifierTwoDownstreamInt - 1) * -1;
        else//both +ve
            classifierTwoWindowSize = (setClassifierTwoDownstreamInt - setClassifierTwoUpstreamInt + 1);

        Instances inst = applicationData.getDataset1Instances();

        //NOTE: need to take care of this function;    
        FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                negativeStep1TableModel, positiveDataset3FromInt, positiveDataset3ToInt,
                negativeDataset3FromInt, negativeDataset3ToInt, applicationData.getWorkingDirectory());

        //loading in all the features..
        ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>();
        for (int x = 0; x < inst.numAttributes() - 1; x++) {
            //-1 because class attribute must be ignored
            featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name()));
        }

        //Reading the fastaFile                                
        String _class = "pos";
        lineCount = 0;
        int totalPosSequences = positiveDataset3ToInt - positiveDataset3FromInt + 1;
        FastaFormat fastaFormat;
        while ((fastaFormat = fastaFile.nextSequence(_class)) != null) {
            if (applicationData.terminateThread == true) {
                statusPane.setText("Interrupted - Classifier Two Trained");
                classifierTwoOutput.close();
                return classifierTwo;
            }
            lineCount++;
            classifierTwoOutput.write(fastaFormat.getHeader());
            classifierTwoOutput.newLine();
            classifierTwoOutput.write(fastaFormat.getSequence());
            classifierTwoOutput.newLine();
            //if((lineCount % 100) == 0){                      
            statusPane.setText("Running ClassifierTwo on Dataset 3...@ " + lineCount + " / "
                    + applicationData.getTotalSequences(3) + " Sequences");
            //}
            arraySize = fastaFormat.getArraySize(applicationData.getLeftMostPosition(),
                    applicationData.getRightMostPosition());
            //This area always generates a -ve arraySize! WHY?? The exception always occurs here
            double scores[] = new double[arraySize];
            int predictPosition[] = fastaFormat.getPredictPositionForClassifierOne(
                    applicationData.getLeftMostPosition(), applicationData.getRightMostPosition());
            //Doing shift from upstream till downstream   
            SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(), predictPosition[0],
                    predictPosition[1]);
            int scoreCount = 0;
            String line2;
            while ((line2 = seq.nextShift()) != null) {
                Instance tempInst = new Instance(inst.numAttributes());
                tempInst.setDataset(inst);
                //-1 because class attribute can be ignored
                for (int x = 0; x < inst.numAttributes() - 1; x++) {
                    Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2,
                            featureDataArrayList.get(x), applicationData.getScoringMatrixIndex(),
                            applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix());
                    if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                        tempInst.setValue(x, (Integer) obj);
                    else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                        tempInst.setValue(x, (Double) obj);
                    else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                        tempInst.setValue(x, (String) obj);
                    else {
                        classifierTwoOutput.close();
                        throw new Error("Unknown: " + obj.getClass().getName());
                    }
                }
                tempInst.setValue(inst.numAttributes() - 1, _class);
                //Run classifierOne                 
                double[] results = classifierOne.distributionForInstance(tempInst);
                scores[scoreCount++] = results[0];
            }
            //Run classifierTwo                 
            int currentPosition = fastaFormat.getPredictionFromForClassifierTwo(
                    applicationData.getLeftMostPosition(), applicationData.getRightMostPosition(),
                    applicationData.getSetUpstream());
            classifierTwoOutput.write(_class);
            for (int y = 0; y < arraySize - classifierTwoWindowSize + 1; y++) {
                //+1 is for the class index
                Instance tempInst2 = new Instance(classifierTwoWindowSize + 1);
                tempInst2.setDataset(inst2);
                for (int x = 0; x < classifierTwoWindowSize; x++) {
                    tempInst2.setValue(x, scores[x + y]);
                }
                tempInst2.setValue(tempInst2.numAttributes() - 1, _class);
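                //classifierTwo scores a window of classifier-one scores;
                //results[0] is again taken as the positive-class probability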
                double[] results = classifierTwo.distributionForInstance(tempInst2);
                classifierTwoOutput.write("," + currentPosition + "=" + results[0]);
                currentPosition++;
                if (currentPosition == 0)
                    currentPosition++;
            }
            classifierTwoOutput.newLine();
            classifierTwoOutput.flush();
            if (lineCount == totalPosSequences)
                _class = "neg";
        }
        classifierTwoOutput.close();
        statusPane.setText("Done!");
        PredictionStats classifierTwoStatsOnBlindTest = new PredictionStats(
                applicationData.getWorkingDirectory() + File.separator + "ClassifierTwo.scores", range,
                threshold);
        //display(double range)
        long totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                        + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
        classifierTwoStatsOnBlindTest.updateDisplay(classifierResults, classifierTwoDisplayTextArea, true);
        applicationData.setClassifierTwoStats(classifierTwoStatsOnBlindTest);
        myGraph.setMyStats(classifierTwoStatsOnBlindTest);
        fastaFile.cleanUp();
        return classifierTwo;
    } catch (Exception ex) {
        ex.printStackTrace();
        JOptionPane.showMessageDialog(parent,
                ex.getMessage() + "Classifier Two On Blind Test Set - Check Console Output",
                "Evaluate classifier two", JOptionPane.ERROR_MESSAGE);
        System.err.println("applicationData.getLeftMostPosition(): " + applicationData.getLeftMostPosition());
        System.err.println("applicationData.getRightMostPosition(): " + applicationData.getRightMostPosition());
        System.err.println("arraySize: " + arraySize);
        System.err.println("lineCount: " + lineCount);
        return null;
    }
}

From source file: sirius.trainer.step4.RunClassifier.java

License: Open Source License

public static Classifier xValidateClassifierOne(JInternalFrame parent, ApplicationData applicationData,
        JTextArea classifierOneDisplayTextArea, GenericObjectEditor m_ClassifierEditor, int folds,
        GraphPane myGraph, ClassifierResults classifierResults, int range, double threshold,
        boolean outputClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        //Classifier tempClassifier = (Classifier) m_ClassifierEditor.getValue();
        int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
        int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
        int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
        int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();

        Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();

        Instances inst = new Instances(applicationData.getDataset1Instances());
        inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);

        //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy                    
        long trainTimeStart = 0, trainTimeElapsed = 0;
        Classifier classifierOne = (Classifier) m_ClassifierEditor.getValue();
        statusPane.setText("Training Classifier One... May take a while... Please wait...");
        //Record Start Time
        trainTimeStart = System.currentTimeMillis();
        inst.deleteAttributeType(Attribute.STRING);
        if (outputClassifier)
            classifierOne.buildClassifier(inst);
        //Record Total Time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
        //Training Done                        

        String classifierName = m_ClassifierEditor.getValue().getClass().getName();
        classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
        classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                folds + " fold cross-validation on Dataset1.arff");
        classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");

        //Reading and Storing the featureList
        ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>();
        for (int y = 0; y < inst.numAttributes() - 1; y++) {
            featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(y).name()));
        }

        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(
                applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores"));

        for (int x = 0; x < folds; x++) {
            File trainFile = new File(applicationData.getWorkingDirectory() + File.separator
                    + "trainingDataset1_" + (x + 1) + ".arff");
            File testFile = new File(applicationData.getWorkingDirectory() + File.separator + "testingDataset1_"
                    + (x + 1) + ".fasta");
            //AHFU_DEBUG
            //Generate also the training file in fasta format for debugging purpose
            File trainFileFasta = new File(applicationData.getWorkingDirectory() + File.separator
                    + "trainingDataset1_" + (x + 1) + ".fasta");
            //AHFU_DEBUG_END

            //AHFU_DEBUG - This part is to generate the TestClassifierTwo.arff for use in WEKA to test classifierTwo
            //TestClassifierTwo.arff - predictions scores from Set Upstream Field to Set Downstream Field
            //Now first generate the header for TestClassifierTwo.arff
            BufferedWriter testClassifierTwoArff = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "TestClassifierTwo_"
                            + (x + 1) + ".arff"));
            int setClassifierTwoUpstreamInt = -40;
            int setClassifierTwoDownstreamInt = 41;
            testClassifierTwoArff.write("@relation \'Used to Test Classifier Two\'");
            testClassifierTwoArff.newLine();
            for (int d = setClassifierTwoUpstreamInt; d <= setClassifierTwoDownstreamInt; d++) {
                if (d == 0)
                    continue;
                testClassifierTwoArff.write("@attribute (" + d + ") numeric");
                testClassifierTwoArff.newLine();
            }
            if (positiveDataset1FromInt > 0 && negativeDataset1FromInt > 0)
                testClassifierTwoArff.write("@attribute Class {pos,neg}");
            else if (positiveDataset1FromInt > 0 && negativeDataset1FromInt == 0)
                testClassifierTwoArff.write("@attribute Class {pos}");
            else if (positiveDataset1FromInt == 0 && negativeDataset1FromInt > 0)
                testClassifierTwoArff.write("@attribute Class {neg}");
            testClassifierTwoArff.newLine();
            testClassifierTwoArff.newLine();
            testClassifierTwoArff.write("@data");
            testClassifierTwoArff.newLine();
            testClassifierTwoArff.newLine();
            //END of AHFU_DEBUG
            statusPane.setText("Building Fold " + (x + 1) + "...");
            FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                    negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt,
                    negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory());

            //1) generate trainingDatasetX.arff headings
            BufferedWriter trainingOutputFile = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_"
                            + (x + 1) + ".arff"));
            trainingOutputFile.write("@relation 'A temp file for X-validation purpose' ");
            trainingOutputFile.newLine();
            trainingOutputFile.newLine();
            trainingOutputFile.flush();

            for (int y = 0; y < inst.numAttributes() - 1; y++) {
                if (inst.attribute(y).type() == Attribute.NUMERIC)
                    trainingOutputFile.write("@attribute " + inst.attribute(y).name() + " numeric");
                else if (inst.attribute(y).type() == Attribute.STRING)
                    trainingOutputFile.write("@attribute " + inst.attribute(y).name() + " String");
                else {
                    testClassifierTwoArff.close();
                    outputCrossValidation.close();
                    trainingOutputFile.close();
                    throw new Error("Unknown type: " + inst.attribute(y).name());
                }
                trainingOutputFile.newLine();
                trainingOutputFile.flush();
            }
            if (positiveDataset1FromInt > 0 && negativeDataset1FromInt > 0)
                trainingOutputFile.write("@attribute Class {pos,neg}");
            else if (positiveDataset1FromInt > 0 && negativeDataset1FromInt == 0)
                trainingOutputFile.write("@attribute Class {pos}");
            else if (positiveDataset1FromInt == 0 && negativeDataset1FromInt > 0)
                trainingOutputFile.write("@attribute Class {neg}");
            trainingOutputFile.newLine();
            trainingOutputFile.newLine();
            trainingOutputFile.write("@data");
            trainingOutputFile.newLine();
            trainingOutputFile.newLine();
            trainingOutputFile.flush();

            //2) generate testingDataset1.fasta
            BufferedWriter testingOutputFile = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "testingDataset1_"
                            + (x + 1) + ".fasta"));

            //AHFU_DEBUG
            //Open the IOStream for training file (fasta format)
            BufferedWriter trainingOutputFileFasta = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "trainingDataset1_"
                            + (x + 1) + ".fasta"));
            //AHFU_DEBUG_END

            //Now, populating data for both the training and testing files            
            int fastaFileLineCounter = 0;
            int posTestSequenceCounter = 0;
            int totalTestSequenceCounter = 0;
            //For pos sequences   
            FastaFormat fastaFormat;
            while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
                if ((fastaFileLineCounter % folds) == x) {//This sequence for testing
                    testingOutputFile.write(fastaFormat.getHeader());
                    testingOutputFile.newLine();
                    testingOutputFile.write(fastaFormat.getSequence());
                    testingOutputFile.newLine();
                    testingOutputFile.flush();
                    posTestSequenceCounter++;
                    totalTestSequenceCounter++;
                } else {//for training
                    for (int z = 0; z < inst.numAttributes() - 1; z++) {
                        trainingOutputFile.write(GenerateArff.getMatchCount(fastaFormat,
                                featureDataArrayList.get(z), applicationData.getScoringMatrixIndex(),
                                applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix())
                                + ",");
                    }
                    trainingOutputFile.write("pos");
                    trainingOutputFile.newLine();
                    trainingOutputFile.flush();

                    //AHFU_DEBUG
                    //Write the data into the training file in fasta format
                    trainingOutputFileFasta.write(fastaFormat.getHeader());
                    trainingOutputFileFasta.newLine();
                    trainingOutputFileFasta.write(fastaFormat.getSequence());
                    trainingOutputFileFasta.newLine();
                    trainingOutputFileFasta.flush();
                    //AHFU_DEBUG_END
                }
                fastaFileLineCounter++;
            }
            //For neg sequences
            fastaFileLineCounter = 0;
            while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
                if ((fastaFileLineCounter % folds) == x) {//This sequence for testing
                    testingOutputFile.write(fastaFormat.getHeader());
                    testingOutputFile.newLine();
                    testingOutputFile.write(fastaFormat.getSequence());
                    testingOutputFile.newLine();
                    testingOutputFile.flush();
                    totalTestSequenceCounter++;
                } else {//for training
                    for (int z = 0; z < inst.numAttributes() - 1; z++) {
                        trainingOutputFile.write(GenerateArff.getMatchCount(fastaFormat,
                                featureDataArrayList.get(z), applicationData.getScoringMatrixIndex(),
                                applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix())
                                + ",");
                    }
                    trainingOutputFile.write("neg");
                    trainingOutputFile.newLine();
                    trainingOutputFile.flush();

                    //AHFU_DEBUG
                    //Write the data into the training file in fasta format
                    trainingOutputFileFasta.write(fastaFormat.getHeader());
                    trainingOutputFileFasta.newLine();
                    trainingOutputFileFasta.write(fastaFormat.getSequence());
                    trainingOutputFileFasta.newLine();
                    trainingOutputFileFasta.flush();
                    //AHFU_DEBUG_END
                }
                fastaFileLineCounter++;
            }
            trainingOutputFileFasta.close();
            trainingOutputFile.close();
            testingOutputFile.close();
            //3) train and test the classifier then store the statistics              
            Classifier foldClassifier = (Classifier) m_ClassifierEditor.getValue();
            Instances instFoldTrain = new Instances(
                    new BufferedReader(new FileReader(applicationData.getWorkingDirectory() + File.separator
                            + "trainingDataset1_" + (x + 1) + ".arff")));
            instFoldTrain.setClassIndex(instFoldTrain.numAttributes() - 1);
            foldClassifier.buildClassifier(instFoldTrain);

            //Reading the test file
            statusPane.setText("Evaluating fold " + (x + 1) + "..");
            BufferedReader testingInput = new BufferedReader(
                    new FileReader(applicationData.getWorkingDirectory() + File.separator + "testingDataset1_"
                            + (x + 1) + ".fasta"));
            int lineCounter = 0;
            String lineHeader;
            String lineSequence;
            while ((lineHeader = testingInput.readLine()) != null) {
                if (applicationData.terminateThread == true) {
                    statusPane.setText("Interrupted - Classifier One Training Completed");
                    testingInput.close();
                    testClassifierTwoArff.close();
                    return classifierOne;
                }
                lineSequence = testingInput.readLine();
                outputCrossValidation.write(lineHeader);
                outputCrossValidation.newLine();
                outputCrossValidation.write(lineSequence);
                outputCrossValidation.newLine();
                lineCounter++;
                //For each sequence, you want to shift from upstream till downstream 
                //ie changing the +1 location
                //to get the scores by classifier one so that can use it to train classifier two later
                //Doing shift from upstream till downstream    
                //if(lineCounter % 100 == 0)
                statusPane.setText("Evaluating fold " + (x + 1) + ".. @ " + lineCounter + " / "
                        + totalTestSequenceCounter);

                fastaFormat = new FastaFormat(lineHeader, lineSequence);
                int predictPosition[] = fastaFormat.getPredictPositionForClassifierOne(
                        applicationData.getLeftMostPosition(), applicationData.getRightMostPosition());

                SequenceManipulation seq = new SequenceManipulation(lineSequence, predictPosition[0],
                        predictPosition[1]);
                int currentPosition = predictPosition[0];
                String line2;
                if (lineCounter > posTestSequenceCounter)
                    outputCrossValidation.write("neg");
                else
                    outputCrossValidation.write("pos");
                while ((line2 = seq.nextShift()) != null) {
                    Instance tempInst;
                    tempInst = new Instance(inst.numAttributes());
                    tempInst.setDataset(inst);
                    for (int i = 0; i < inst.numAttributes() - 1; i++) {
                        //-1 because class attribute can be ignored
                        //Give the sequence and the featureList to get the feature freqs on the sequence
                        Object obj = GenerateArff.getMatchCount(lineHeader, line2, featureDataArrayList.get(i),
                                applicationData.getScoringMatrixIndex(),
                                applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix());
                        //use the attribute index i here (not the fold counter x)
                        if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                            tempInst.setValue(i, (Integer) obj);
                        else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                            tempInst.setValue(i, (Double) obj);
                        else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                            tempInst.setValue(i, (String) obj);
                        else {
                            testingInput.close();
                            testClassifierTwoArff.close();
                            outputCrossValidation.close();
                            throw new Error("Unknown: " + obj.getClass().getName());
                        }
                    }
                    if (lineCounter > posTestSequenceCounter)
                        tempInst.setValue(inst.numAttributes() - 1, "neg");
                    else
                        tempInst.setValue(inst.numAttributes() - 1, "pos");
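                    //score the held-out test sequence with the classifier trained on the other folds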
                    double[] results = foldClassifier.distributionForInstance(tempInst);
                    outputCrossValidation.write("," + currentPosition + "=" + results[0]);
                    //AHFU_DEBUG 
                    double[] resultsDebug = classifierOne.distributionForInstance(tempInst);
                    if (currentPosition >= setClassifierTwoUpstreamInt
                            && currentPosition <= setClassifierTwoDownstreamInt)
                        testClassifierTwoArff.write(resultsDebug[0] + ",");
                    //AHFU_DEBUG_END
                    currentPosition++;
                    if (currentPosition == 0)
                        currentPosition++;
                } //end of sequence shift                               
                outputCrossValidation.newLine();
                outputCrossValidation.flush();
                //AHFU_DEBUG
                if (lineCounter > posTestSequenceCounter)
                    testClassifierTwoArff.write("neg");
                else
                    testClassifierTwoArff.write("pos");
                testClassifierTwoArff.newLine();
                testClassifierTwoArff.flush();
                //AHFU_DEBUG_END
            } //end of reading test file
            testingInput.close();
            testClassifierTwoArff.close();
            fastaFile.cleanUp();

            //NORMAL MODE
            //trainFile.delete();
            //testFile.delete();
            //NORMAL MODE END
            //AHFU_DEBUG MODE
            //testClassifierTwoArff.close();            
            trainFile.deleteOnExit();
            testFile.deleteOnExit();
            trainFileFasta.deleteOnExit();
            //AHFU_DEBUG_MODE_END
        } //end of for loop for xvalidation
        outputCrossValidation.close(); //shared across all folds, so close only after the last fold

        PredictionStats classifierOneStatsOnXValidation = new PredictionStats(
                applicationData.getWorkingDirectory() + File.separator + "ClassifierOne.scores", range,
                threshold);
        //display(double range)
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                        + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
        classifierOneStatsOnXValidation.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        applicationData.setClassifierOneStats(classifierOneStatsOnXValidation);
        myGraph.setMyStats(classifierOneStatsOnXValidation);

        statusPane.setText("Done!");

        return classifierOne;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}

From source file: sirius.trainer.step4.RunClassifier.java

License: Open Source License

public static Classifier xValidateClassifierTwo(JInternalFrame parent, ApplicationData applicationData,
        JTextArea classifierTwoDisplayTextArea, GenericObjectEditor m_ClassifierEditor2,
        Classifier classifierOne, int folds, GraphPane myGraph, ClassifierResults classifierResults, int range,
        double threshold, boolean outputClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        //Classifier tempClassifier = (Classifier) m_ClassifierEditor2.getValue();
        final int positiveDataset2FromInt = applicationData.getPositiveDataset2FromField();
        final int positiveDataset2ToInt = applicationData.getPositiveDataset2ToField();
        final int negativeDataset2FromInt = applicationData.getNegativeDataset2FromField();
        final int negativeDataset2ToInt = applicationData.getNegativeDataset2ToField();

        final int totalDataset2Sequences = (positiveDataset2ToInt - positiveDataset2FromInt + 1)
                + (negativeDataset2ToInt - negativeDataset2FromInt + 1);

        final int classifierTwoUpstream = applicationData.getSetUpstream();
        final int classifierTwoDownstream = applicationData.getSetDownstream();

        Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();

        //Train classifier two with the full dataset first then do cross-validation to gauge its accuracy                      
        //Preparing Dataset2.arff to train Classifier Two
        long trainTimeStart = 0, trainTimeElapsed = 0;
        statusPane.setText("Preparing Dataset2.arff...");
        //This step generates Dataset2.arff
        if (DatasetGenerator.generateDataset2(parent, applicationData, applicationData.getSetUpstream(),
                applicationData.getSetDownstream(), classifierOne) == false) {
            //Interrupted or Error occurred
            return null;
        }
        Instances instOfDataset2 = new Instances(new BufferedReader(
                new FileReader(applicationData.getWorkingDirectory() + File.separator + "Dataset2.arff")));
        instOfDataset2.setClassIndex(instOfDataset2.numAttributes() - 1);
        applicationData.setDataset2Instances(instOfDataset2);
        Classifier classifierTwo = (Classifier) m_ClassifierEditor2.getValue();
        statusPane.setText("Training Classifier Two... May take a while... Please wait...");
        //Record Start Time
        trainTimeStart = System.currentTimeMillis();
        if (outputClassifier)
            classifierTwo.buildClassifier(instOfDataset2);
        //Record Total Time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
        //Training Done          

        String classifierName = m_ClassifierEditor2.getValue().getClass().getName();
        classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
        classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                folds + " fold cross-validation on Dataset2.arff");
        classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");

        Instances instOfDataset1 = new Instances(applicationData.getDataset1Instances());
        instOfDataset1.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        //Reading and Storing the featureList
        ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>();
        for (int y = 0; y < instOfDataset1.numAttributes() - 1; y++) {
            featureDataArrayList.add(Feature.levelOneClassifierPane(instOfDataset1.attribute(y).name()));
        }

        //Generating an Instance given a sequence with the current attributes
        int setClassifierTwoUpstreamInt = applicationData.getSetUpstream();
        int setClassifierTwoDownstreamInt = applicationData.getSetDownstream();
        int classifierTwoWindowSize;
        if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt > 0)
            classifierTwoWindowSize = (setClassifierTwoUpstreamInt * -1) + setClassifierTwoDownstreamInt;
        else if (setClassifierTwoUpstreamInt < 0 && setClassifierTwoDownstreamInt < 0)
            classifierTwoWindowSize = (setClassifierTwoUpstreamInt - setClassifierTwoDownstreamInt - 1) * -1;
        else//both +ve
            classifierTwoWindowSize = (setClassifierTwoDownstreamInt - setClassifierTwoUpstreamInt + 1);

        int posTestSequenceCounter = 0;

        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(
                applicationData.getWorkingDirectory() + File.separator + "classifierTwo.scores"));

        for (int x = 0; x < folds; x++) {
            File trainFile = new File(applicationData.getWorkingDirectory() + File.separator
                    + "trainingDataset2_" + (x + 1) + ".arff");
            File testFile = new File(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_"
                    + (x + 1) + ".fasta");

            statusPane.setText("Preparing Training Data for Fold " + (x + 1) + "..");
            FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                    negativeStep1TableModel, positiveDataset2FromInt, positiveDataset2ToInt,
                    negativeDataset2FromInt, negativeDataset2ToInt, applicationData.getWorkingDirectory());

            //1) generate trainingDataset2.arff headings
            BufferedWriter trainingOutputFile = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "trainingDataset2_"
                            + (x + 1) + ".arff"));
            trainingOutputFile.write("@relation 'A temp file for X-validation purpose' ");
            trainingOutputFile.newLine();
            trainingOutputFile.newLine();
            trainingOutputFile.flush();
            for (int y = classifierTwoUpstream; y <= classifierTwoDownstream; y++) {
                if (y != 0) {
                    trainingOutputFile.write("@attribute (" + y + ") numeric");
                    trainingOutputFile.newLine();
                    trainingOutputFile.flush();
                }
            }
            if (positiveDataset2FromInt > 0 && negativeDataset2FromInt > 0)
                trainingOutputFile.write("@attribute Class {pos,neg}");
            else if (positiveDataset2FromInt > 0 && negativeDataset2FromInt == 0)
                trainingOutputFile.write("@attribute Class {pos}");
            else if (positiveDataset2FromInt == 0 && negativeDataset2FromInt > 0)
                trainingOutputFile.write("@attribute Class {neg}");
            trainingOutputFile.newLine();
            trainingOutputFile.newLine();
            trainingOutputFile.write("@data");
            trainingOutputFile.newLine();
            trainingOutputFile.newLine();
            trainingOutputFile.flush();
            //AHFU_DEBUG 
            BufferedWriter testingOutputFileArff = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_"
                            + (x + 1) + ".arff"));
            testingOutputFileArff.write("@relation 'A temp file for X-validation purpose' ");
            testingOutputFileArff.newLine();
            testingOutputFileArff.newLine();
            testingOutputFileArff.flush();
            for (int y = classifierTwoUpstream; y <= classifierTwoDownstream; y++) {
                if (y != 0) {
                    testingOutputFileArff.write("@attribute (" + y + ") numeric");
                    testingOutputFileArff.newLine();
                    testingOutputFileArff.flush();
                }
            }
            if (positiveDataset2FromInt > 0 && negativeDataset2FromInt > 0)
                testingOutputFileArff.write("@attribute Class {pos,neg}");
            else if (positiveDataset2FromInt > 0 && negativeDataset2FromInt == 0)
                testingOutputFileArff.write("@attribute Class {pos}");
            else if (positiveDataset2FromInt == 0 && negativeDataset2FromInt > 0)
                testingOutputFileArff.write("@attribute Class {neg}");
            testingOutputFileArff.newLine();
            testingOutputFileArff.newLine();
            testingOutputFileArff.write("@data");
            testingOutputFileArff.newLine();
            testingOutputFileArff.newLine();
            testingOutputFileArff.flush();
            //AHFU_DEBUG END
            //2) generate testingDataset2.fasta
            BufferedWriter testingOutputFile = new BufferedWriter(
                    new FileWriter(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_"
                            + (x + 1) + ".fasta"));

            //Now, populating data for both the training and testing files
            int fastaFileLineCounter = 0;
            posTestSequenceCounter = 0;
            int totalTestSequenceCounter = 0;
            int totalTrainTestSequenceCounter = 0;
            FastaFormat fastaFormat;
            //For pos sequences   
            while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
                if (applicationData.terminateThread == true) {
                    statusPane.setText("Interrupted - Classifier Two Trained");
                    outputCrossValidation.close();
                    testingOutputFileArff.close();
                    testingOutputFile.close();
                    trainingOutputFile.close();
                    return classifierTwo;
                }
                totalTrainTestSequenceCounter++;
                //if(totalTrainTestSequenceCounter%100 == 0)
                statusPane.setText("Preparing Training Data for Fold " + (x + 1) + ".. @ "
                        + totalTrainTestSequenceCounter + " / " + totalDataset2Sequences);
                if ((fastaFileLineCounter % folds) == x) {//This sequence is for testing
                    testingOutputFile.write(fastaFormat.getHeader());
                    testingOutputFile.newLine();
                    testingOutputFile.write(fastaFormat.getSequence());
                    testingOutputFile.newLine();
                    testingOutputFile.flush();
                    posTestSequenceCounter++;
                    totalTestSequenceCounter++;
                    //AHFU DEBUG
                    SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(),
                            classifierTwoUpstream, classifierTwoDownstream);
                    String line2;
                    while ((line2 = seq.nextShift()) != null) {
                        Instance tempInst = new Instance(instOfDataset1.numAttributes());
                        tempInst.setDataset(instOfDataset1);
                        //-1 because class attribute can be ignored
                        for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) {
                            Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2,
                                    featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(),
                                    applicationData.getCountingStyleIndex(),
                                    applicationData.getScoringMatrix());
                            if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                                tempInst.setValue(w, (Integer) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                                tempInst.setValue(w, (Double) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                                tempInst.setValue(w, (String) obj);
                            else {
                                outputCrossValidation.close();
                                testingOutputFileArff.close();
                                testingOutputFile.close();
                                trainingOutputFile.close();
                                throw new Error("Unknown: " + obj.getClass().getName());
                            }
                        }
                        tempInst.setValue(tempInst.numAttributes() - 1, "pos");
                        double[] results = classifierOne.distributionForInstance(tempInst);
                        testingOutputFileArff.write(results[0] + ",");
                    }
                    testingOutputFileArff.write("pos");
                    testingOutputFileArff.newLine();
                    testingOutputFileArff.flush();
                    //AHFU DEBUG END
                } else {//This sequence is for training
                    SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(),
                            classifierTwoUpstream, classifierTwoDownstream);
                    String line2;
                    while ((line2 = seq.nextShift()) != null) {
                        Instance tempInst = new Instance(instOfDataset1.numAttributes());
                        tempInst.setDataset(instOfDataset1);
                        //-1 because class attribute can be ignored
                        for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) {
                            Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2,
                                    featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(),
                                    applicationData.getCountingStyleIndex(),
                                    applicationData.getScoringMatrix());
                            if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                                tempInst.setValue(w, (Integer) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                                tempInst.setValue(w, (Double) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                                tempInst.setValue(w, (String) obj);
                            else {
                                outputCrossValidation.close();
                                testingOutputFileArff.close();
                                testingOutputFile.close();
                                trainingOutputFile.close();
                                throw new Error("Unknown: " + obj.getClass().getName());
                            }
                        }
                        tempInst.setValue(tempInst.numAttributes() - 1, "pos");
                        double[] results = classifierOne.distributionForInstance(tempInst);
                        trainingOutputFile.write(results[0] + ",");
                    }
                    trainingOutputFile.write("pos");
                    trainingOutputFile.newLine();
                    trainingOutputFile.flush();
                }
                fastaFileLineCounter++;
            }
            //For neg sequences
            fastaFileLineCounter = 0;
            while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
                if (applicationData.terminateThread == true) {
                    statusPane.setText("Interrupted - Classifier Two Trained");
                    outputCrossValidation.close();
                    testingOutputFileArff.close();
                    testingOutputFile.close();
                    trainingOutputFile.close();
                    return classifierTwo;
                }
                totalTrainTestSequenceCounter++;
                //if(totalTrainTestSequenceCounter%100 == 0)
                statusPane.setText("Preparing Training Data for Fold " + (x + 1) + ".. @ "
                        + totalTrainTestSequenceCounter + " / " + totalDataset2Sequences);
                if ((fastaFileLineCounter % folds) == x) {//This sequence is for testing
                    testingOutputFile.write(fastaFormat.getHeader());
                    testingOutputFile.newLine();
                    testingOutputFile.write(fastaFormat.getSequence());
                    testingOutputFile.newLine();
                    testingOutputFile.flush();
                    totalTestSequenceCounter++;
                    //AHFU DEBUG
                    SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(),
                            classifierTwoUpstream, classifierTwoDownstream);
                    String line2;
                    while ((line2 = seq.nextShift()) != null) {
                        Instance tempInst = new Instance(instOfDataset1.numAttributes());
                        tempInst.setDataset(instOfDataset1);
                        //-1 because class attribute can be ignored
                        for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) {
                            Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2,
                                    featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(),
                                    applicationData.getCountingStyleIndex(),
                                    applicationData.getScoringMatrix());
                            if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                                tempInst.setValue(w, (Integer) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                                tempInst.setValue(w, (Double) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                                tempInst.setValue(w, (String) obj);
                            else {
                                outputCrossValidation.close();
                                testingOutputFileArff.close();
                                testingOutputFile.close();
                                trainingOutputFile.close();
                                throw new Error("Unknown: " + obj.getClass().getName());
                            }
                        }
                        tempInst.setValue(tempInst.numAttributes() - 1, "pos");//pos or neg does not matter here - not used         
                        double[] results = classifierOne.distributionForInstance(tempInst);
                        testingOutputFileArff.write(results[0] + ",");
                    }
                    testingOutputFileArff.write("neg");
                    testingOutputFileArff.newLine();
                    testingOutputFileArff.flush();
                    //AHFU DEBUG END
                } else {//This sequence is for training
                    SequenceManipulation seq = new SequenceManipulation(fastaFormat.getSequence(),
                            classifierTwoUpstream, classifierTwoDownstream);
                    String line2;
                    while ((line2 = seq.nextShift()) != null) {
                        Instance tempInst = new Instance(instOfDataset1.numAttributes());
                        tempInst.setDataset(instOfDataset1);
                        //-1 because class attribute can be ignored
                        for (int w = 0; w < instOfDataset1.numAttributes() - 1; w++) {
                            Object obj = GenerateArff.getMatchCount(fastaFormat.getHeader(), line2,
                                    featureDataArrayList.get(w), applicationData.getScoringMatrixIndex(),
                                    applicationData.getCountingStyleIndex(),
                                    applicationData.getScoringMatrix());
                            if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                                tempInst.setValue(w, (Integer) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                                tempInst.setValue(w, (Double) obj);
                            else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                                tempInst.setValue(w, (String) obj);
                            else {
                                outputCrossValidation.close();
                                testingOutputFileArff.close();
                                testingOutputFile.close();
                                trainingOutputFile.close();
                                throw new Error("Unknown: " + obj.getClass().getName());
                            }
                        }
                        tempInst.setValue(tempInst.numAttributes() - 1, "pos");//pos or neg does not matter here - not used              
                        double[] results = classifierOne.distributionForInstance(tempInst);
                        trainingOutputFile.write(results[0] + ",");
                    }
                    trainingOutputFile.write("neg");
                    trainingOutputFile.newLine();
                    trainingOutputFile.flush();
                }
                fastaFileLineCounter++;
            }
            trainingOutputFile.close();
            testingOutputFile.close();

            //AHFU_DEBUG
            testingOutputFileArff.close();
            //AHFU DEBUG END
            //3) train and test classifier two then store the statistics
            statusPane.setText("Building Fold " + (x + 1) + "..");
            //open an input stream to the arff file 
            BufferedReader trainingInput = new BufferedReader(
                    new FileReader(applicationData.getWorkingDirectory() + File.separator + "trainingDataset2_"
                            + (x + 1) + ".arff"));
            //getting ready to train a foldClassifier using arff file
            Instances instOfTrainingDataset2 = new Instances(
                    new BufferedReader(new FileReader(applicationData.getWorkingDirectory() + File.separator
                            + "trainingDataset2_" + (x + 1) + ".arff")));
            instOfTrainingDataset2.setClassIndex(instOfTrainingDataset2.numAttributes() - 1);
            Classifier foldClassifier = (Classifier) m_ClassifierEditor2.getValue();
            foldClassifier.buildClassifier(instOfTrainingDataset2);
            trainingInput.close();

            //Reading the test file
            statusPane.setText("Evaluating fold " + (x + 1) + "..");
            BufferedReader testingInput = new BufferedReader(
                    new FileReader(applicationData.getWorkingDirectory() + File.separator + "testingDataset2_"
                            + (x + 1) + ".fasta"));
            int lineCounter = 0;
            String lineHeader;
            String lineSequence;
            while ((lineHeader = testingInput.readLine()) != null) {
                if (applicationData.terminateThread == true) {
                    statusPane.setText("Interrupted - Classifier Two Not Trained");
                    outputCrossValidation.close();
                    testingOutputFileArff.close();
                    testingOutputFile.close();
                    trainingOutputFile.close();
                    testingInput.close();
                    return classifierTwo;
                }
                lineSequence = testingInput.readLine();
                outputCrossValidation.write(lineHeader);
                outputCrossValidation.newLine();
                outputCrossValidation.write(lineSequence);
                outputCrossValidation.newLine();
                lineCounter++;
                fastaFormat = new FastaFormat(lineHeader, lineSequence);
                int arraySize = fastaFormat.getArraySize(applicationData.getLeftMostPosition(),
                        applicationData.getRightMostPosition());
                double scores[] = new double[arraySize];
                int predictPosition[] = fastaFormat.getPredictPositionForClassifierOne(
                        applicationData.getLeftMostPosition(), applicationData.getRightMostPosition());
                //For each sequence, shift the +1 location from upstream to downstream
                //to collect classifier one's scores, which are later used to train classifier two
                //if(lineCounter % 100 == 0)
                statusPane.setText("Evaluating fold " + (x + 1) + ".. @ " + lineCounter + " / "
                        + totalTestSequenceCounter);
                SequenceManipulation seq = new SequenceManipulation(lineSequence, predictPosition[0],
                        predictPosition[1]);
                int scoreCount = 0;
                String line2;
                while ((line2 = seq.nextShift()) != null) {
                    Instance tempInst = new Instance(instOfDataset1.numAttributes());
                    tempInst.setDataset(instOfDataset1);
                    for (int i = 0; i < instOfDataset1.numAttributes() - 1; i++) {
                        //-1 because class attribute can be ignored
                        //Give the sequence and the featureList to get the feature freqs on the sequence
                        Object obj = GenerateArff.getMatchCount(lineHeader, line2, featureDataArrayList.get(i),
                                applicationData.getScoringMatrixIndex(),
                                applicationData.getCountingStyleIndex(), applicationData.getScoringMatrix());
                        if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                            tempInst.setValue(i, (Integer) obj);
                        else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                            tempInst.setValue(i, (Double) obj);
                        else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                            tempInst.setValue(i, (String) obj);
                        else {
                            outputCrossValidation.close();
                            testingOutputFileArff.close();
                            testingOutputFile.close();
                            trainingOutputFile.close();
                            testingInput.close();
                            throw new Error("Unknown: " + obj.getClass().getName());
                        }
                    }
                    if (lineCounter > posTestSequenceCounter) {//for neg
                        tempInst.setValue(tempInst.numAttributes() - 1, "neg");
                    } else {
                        tempInst.setValue(tempInst.numAttributes() - 1, "pos");
                    }
                    double[] results = classifierOne.distributionForInstance(tempInst);
                    scores[scoreCount++] = results[0];
                } //end of sequence shift 
                  //Run classifierTwo                 
                int currentPosition = fastaFormat.getPredictionFromForClassifierTwo(
                        applicationData.getLeftMostPosition(), applicationData.getRightMostPosition(),
                        applicationData.getSetUpstream());
                if (lineCounter > posTestSequenceCounter)//neg
                    outputCrossValidation.write("neg");
                else
                    outputCrossValidation.write("pos");
                for (int y = 0; y < arraySize - classifierTwoWindowSize + 1; y++) {
                    //+1 is for the class index
                    Instance tempInst2 = new Instance(classifierTwoWindowSize + 1);
                    tempInst2.setDataset(instOfTrainingDataset2);
                    for (int l = 0; l < classifierTwoWindowSize; l++) {
                        tempInst2.setValue(l, scores[l + y]);
                    }
                    if (lineCounter > posTestSequenceCounter)//for neg
                        tempInst2.setValue(tempInst2.numAttributes() - 1, "neg");
                    else//for pos                          
                        tempInst2.setValue(tempInst2.numAttributes() - 1, "pos");
                    double[] results = foldClassifier.distributionForInstance(tempInst2);
                    outputCrossValidation.write("," + currentPosition + "=" + results[0]);
                    currentPosition++;
                    if (currentPosition == 0)
                        currentPosition++;
                }
                outputCrossValidation.newLine();
                outputCrossValidation.flush();
            } //end of reading test file
            outputCrossValidation.close();
            testingOutputFileArff.close();
            testingOutputFile.close();
            trainingOutputFile.close();
            testingInput.close();
            fastaFile.cleanUp();

            //AHFU_DEBUG
            trainFile.deleteOnExit();
            testFile.deleteOnExit();

            //NORMAL MODE
            //trainFile.delete();
            //testFile.delete();
        } //end of for loop for xvalidation      

        PredictionStats classifierTwoStatsOnXValidation = new PredictionStats(
                applicationData.getWorkingDirectory() + File.separator + "classifierTwo.scores", range,
                threshold);
        //display(double range)
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                        + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
        classifierTwoStatsOnXValidation.updateDisplay(classifierResults, classifierTwoDisplayTextArea, true);
        applicationData.setClassifierTwoStats(classifierTwoStatsOnXValidation);
        myGraph.setMyStats(classifierTwoStatsOnXValidation);

        statusPane.setText("Done!");

        return classifierTwo;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}
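
The pattern repeated in the example above is always the same: compute each feature value, place it into a fresh Instance whose header is borrowed from the training Instances via setDataset, give it a class label (ignored at prediction time), and read the positive-class probability from results[0] returned by distributionForInstance. Below is a minimal, self-contained sketch of that pattern using the same pre-3.7 Weka API; the two numeric features, the toy training rows and J48 are illustrative stand-ins, not anything Sirius itself produces.

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

public class DistributionForInstancePattern {
    public static void main(String[] args) throws Exception {
        //Hypothetical two-feature header with a "pos"/"neg" class attribute
        FastVector classValues = new FastVector();
        classValues.addElement("pos");
        classValues.addElement("neg");
        FastVector attributes = new FastVector();
        attributes.addElement(new Attribute("featureA"));
        attributes.addElement(new Attribute("featureB"));
        attributes.addElement(new Attribute("class", classValues));
        Instances train = new Instances("toy", attributes, 0);
        train.setClassIndex(train.numAttributes() - 1);

        //A few made-up training rows
        double[][] rows = { { 1, 5 }, { 2, 4 }, { 8, 1 }, { 9, 0 } };
        String[] labels = { "pos", "pos", "neg", "neg" };
        for (int i = 0; i < rows.length; i++) {
            Instance inst = new Instance(train.numAttributes());
            inst.setDataset(train);
            inst.setValue(0, rows[i][0]);
            inst.setValue(1, rows[i][1]);
            inst.setValue(train.numAttributes() - 1, labels[i]);
            train.add(inst);
        }

        Classifier classifier = new J48();//any Weka classifier would do here
        classifier.buildClassifier(train);

        //The prediction pattern used throughout: fill the attributes, attach the
        //header, then ask for the class-membership probabilities
        Instance tempInst = new Instance(train.numAttributes());
        tempInst.setDataset(train);
        tempInst.setValue(0, 1.5);
        tempInst.setValue(1, 4.5);
        double[] results = classifier.distributionForInstance(tempInst);
        System.out.println("P(pos) = " + results[0] + ", P(neg) = " + results[1]);
    }
}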

From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java

License:Open Source License

public static Object startClassifierOneWithNoLocationIndex(JInternalFrame parent,
        ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, GraphPane myGraph,
        boolean test, ClassifierResults classifierResults, int range, double threshold, String classifierName,
        String[] classifierOptions, boolean returnClassifier, GeneticAlgorithmDialog gaDialog,
        int randomNumberForClassifier) {
    try {

        if (gaDialog != null) {
            //Run GA then load the result maxMCCFeatures into applicationData->Dataset1Instances
            int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
            int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
            int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
            int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();
            FastaFileManipulation fastaFile = new FastaFileManipulation(
                    applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(),
                    positiveDataset1FromInt, positiveDataset1ToInt, negativeDataset1FromInt,
                    negativeDataset1ToInt, applicationData.getWorkingDirectory());
            FastaFormat fastaFormat;
            List<FastaFormat> posFastaList = new ArrayList<FastaFormat>();
            List<FastaFormat> negFastaList = new ArrayList<FastaFormat>();
            while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
                posFastaList.add(fastaFormat);
            }
            while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
                negFastaList.add(fastaFormat);
            }
            applicationData.setDataset1Instances(
                    runDAandLoadResult(applicationData, gaDialog, posFastaList, negFastaList));
        }

        StatusPane statusPane = applicationData.getStatusPane();
        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        //Setting up training data set 1 for classifier one      
        if (statusPane != null)
            statusPane.setText("Setting up...");
        //Load Dataset1 Instances
        Instances inst = new Instances(applicationData.getDataset1Instances());
        inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        applicationData.getDataset1Instances()
                .setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        // for recording of time
        long trainTimeStart = 0, trainTimeElapsed = 0;
        Classifier classifierOne = Classifier.forName(classifierName, classifierOptions);
        /*//Used to show the classifierName and options so that I can use them for qsub
        System.out.println(classifierName);
        String[] optionString = classifierOne.getOptions();
        for(int x = 0; x < optionString.length; x++)
           System.out.println(optionString[x]);*/
        if (statusPane != null)
            statusPane.setText("Training Classifier One... May take a while... Please wait...");
        //Record Start Time
        trainTimeStart = System.currentTimeMillis();
        //Train Classifier One            
        inst.deleteAttributeType(Attribute.STRING);
        classifierOne.buildClassifier(inst);
        //Record Total Time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;

        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
            classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                    applicationData.getWorkingDirectory() + File.separator + "Dataset1.arff");
            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                    Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        }
        if (test == false) {
            //If Need Not Test option is selected
            if (statusPane != null)
                statusPane.setText("Done!");
            return classifierOne;
        }
        if (applicationData.terminateThread == true) {
            //If Stop button is pressed
            if (statusPane != null)
                statusPane.setText("Interrupted - Classifier One Training Completed");
            return classifierOne;
        }
        //Running classifier one on dataset3
        if (statusPane != null)
            statusPane.setText("Running ClassifierOne on Dataset 3..");
        int positiveDataset3FromInt = applicationData.getPositiveDataset3FromField();
        int positiveDataset3ToInt = applicationData.getPositiveDataset3ToField();
        int negativeDataset3FromInt = applicationData.getNegativeDataset3FromField();
        int negativeDataset3ToInt = applicationData.getNegativeDataset3ToField();

        //Generate the header for ClassifierOne.scores on Dataset3      
        String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_"
                + randomNumberForClassifier + ".scores";
        BufferedWriter dataset3OutputFile = new BufferedWriter(new FileWriter(classifierOneFilename));
        FastaFileManipulation fastaFile = new FastaFileManipulation(
                applicationData.getPositiveStep1TableModel(), applicationData.getNegativeStep1TableModel(),
                positiveDataset3FromInt, positiveDataset3ToInt, negativeDataset3FromInt, negativeDataset3ToInt,
                applicationData.getWorkingDirectory());

        //Reading and Storing the featureList
        ArrayList<Feature> featureDataArrayList = new ArrayList<Feature>();
        for (int x = 0; x < inst.numAttributes() - 1; x++) {
            //-1 because class attribute must be ignored
            featureDataArrayList.add(Feature.levelOneClassifierPane(inst.attribute(x).name()));
        }

        //Reading the fastaFile      
        int lineCounter = 0;
        String _class = "pos";
        int totalDataset3PositiveInstances = positiveDataset3ToInt - positiveDataset3FromInt + 1;
        FastaFormat fastaFormat;
        while ((fastaFormat = fastaFile.nextSequence(_class)) != null) {
            if (applicationData.terminateThread == true) {
                if (statusPane != null)
                    statusPane.setText("Interrupted - Classifier One Training Completed");
                dataset3OutputFile.close();
                return classifierOne;
            }
            dataset3OutputFile.write(fastaFormat.getHeader());
            dataset3OutputFile.newLine();
            dataset3OutputFile.write(fastaFormat.getSequence());
            dataset3OutputFile.newLine();
            lineCounter++;//incremented here so that lineCounter x corresponds to sequence x
            dataset3OutputFile.flush();
            if (statusPane != null)
                statusPane.setText("Running Classifier One on Dataset 3.. @ " + lineCounter + " / "
                        + applicationData.getTotalSequences(3) + " Sequences");
            Instance tempInst;
            tempInst = new Instance(inst.numAttributes());
            tempInst.setDataset(inst);
            for (int x = 0; x < inst.numAttributes() - 1; x++) {
                //-1 because class attribute can be ignored
                //Give the sequence and the featureList to get the feature freqs on the sequence
                Object obj = GenerateArff.getMatchCount(fastaFormat, featureDataArrayList.get(x),
                        applicationData.getScoringMatrixIndex(), applicationData.getCountingStyleIndex(),
                        applicationData.getScoringMatrix());
                if (obj.getClass().getName().equalsIgnoreCase("java.lang.Integer"))
                    tempInst.setValue(x, (Integer) obj);
                else if (obj.getClass().getName().equalsIgnoreCase("java.lang.Double"))
                    tempInst.setValue(x, (Double) obj);
                else if (obj.getClass().getName().equalsIgnoreCase("java.lang.String"))
                    tempInst.setValue(x, (String) obj);
                else {
                    dataset3OutputFile.close();
                    throw new Error("Unknown: " + obj.getClass().getName());
                }
            }
            tempInst.setValue(inst.numAttributes() - 1, _class);
            double[] results = classifierOne.distributionForInstance(tempInst);
            dataset3OutputFile.write(_class + ",0=" + results[0]);
            dataset3OutputFile.newLine();
            dataset3OutputFile.flush();
            if (lineCounter == totalDataset3PositiveInstances)
                _class = "neg";
        }
        dataset3OutputFile.close();

        //Display Statistics by reading the ClassifierOne.scores
        PredictionStats classifierOneStatsOnBlindTest = new PredictionStats(classifierOneFilename, range,
                threshold);
        //display(double range)
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                    Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                            + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
            classifierOneStatsOnBlindTest.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        } else
            classifierOneStatsOnBlindTest.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        applicationData.setClassifierOneStats(classifierOneStatsOnBlindTest);
        if (myGraph != null)
            myGraph.setMyStats(classifierOneStatsOnBlindTest);
        if (statusPane != null)
            statusPane.setText("Done!");
        fastaFile.cleanUp();
        if (returnClassifier)
            return classifierOne;
        else
            return classifierOneStatsOnBlindTest;
    } catch (Exception ex) {
        ex.printStackTrace();
        JOptionPane.showMessageDialog(parent, ex.getMessage(), "Evaluate classifier",
                JOptionPane.ERROR_MESSAGE);
        return null;
    }
}
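
Stripped of the FASTA handling and feature generation, the blind-test scoring above amounts to: train on the Dataset1 features, walk through the test instances, and write one "class,0=score" line per instance from results[0]. A minimal sketch follows, under the assumption that both feature sets already exist as ARFF files; the file names are placeholders and J48 stands in for the user-chosen classifier.

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;

public class BlindTestScores {
    public static void main(String[] args) throws Exception {
        //Placeholder paths - in Sirius these feature files are generated, not assumed
        Instances train = new Instances(new BufferedReader(new FileReader("Dataset1.arff")));
        train.setClassIndex(train.numAttributes() - 1);
        Instances test = new Instances(new BufferedReader(new FileReader("Dataset3.arff")));
        test.setClassIndex(test.numAttributes() - 1);

        Classifier classifierOne = new J48();
        classifierOne.buildClassifier(train);

        BufferedWriter scores = new BufferedWriter(new FileWriter("ClassifierOne.scores"));
        for (int i = 0; i < test.numInstances(); i++) {
            String classValue = test.instance(i).toString(test.classIndex());
            double[] results = classifierOne.distributionForInstance(test.instance(i));
            //Same "class,0=score" line format as the .scores files above
            scores.write(classValue + ",0=" + results[0]);
            scores.newLine();
        }
        scores.close();
    }
}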

From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java

License:Open Source License

public static Classifier xValidateClassifierOneWithNoLocationIndex(JInternalFrame parent,
        ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, String classifierName,
        String[] classifierOptions, int folds, GraphPane myGraph, ClassifierResults classifierResults,
        int range, double threshold, boolean outputClassifier, GeneticAlgorithmDialog gaDialog,
        GASettingsInterface gaSettings, int randomNumberForClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();
        if (statusPane == null)
            System.out.println("Null");
        //else
        //   stats

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        Classifier tempClassifier = (Classifier) Classifier.forName(classifierName, classifierOptions);

        Instances inst = null;
        if (applicationData.getDataset1Instances() != null) {
            inst = new Instances(applicationData.getDataset1Instances());
            inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        }

        //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy   
        long trainTimeStart = 0, trainTimeElapsed = 0;
        Classifier classifierOne = (Classifier) Classifier.forName(classifierName, classifierOptions);
        if (statusPane != null)
            statusPane.setText("Training Classifier One... May take a while... Please wait...");
        //Record Start Time
        trainTimeStart = System.currentTimeMillis();
        if (outputClassifier && gaSettings == null)
            classifierOne.buildClassifier(inst);
        //Record Total Time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
        //Training Done
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
            classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                    folds + " fold cross-validation");
            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                    Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        }
        int startRandomNumber;
        if (gaSettings != null)
            startRandomNumber = gaSettings.getRandomNumber();
        else
            startRandomNumber = 1;
        String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_"
                + randomNumberForClassifier + "_" + startRandomNumber + ".scores";
        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename));

        Instances foldTrainingInstance = null;
        Instances foldTestingInstance = null;
        int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
        int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
        int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
        int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();
        Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();
        FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt,
                negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory());
        FastaFormat fastaFormat;
        String header[] = null;
        String data[] = null;
        if (inst != null) {
            header = new String[inst.numInstances()];
            data = new String[inst.numInstances()];
        }
        List<FastaFormat> allPosList = new ArrayList<FastaFormat>();
        List<FastaFormat> allNegList = new ArrayList<FastaFormat>();
        int counter = 0;
        while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
            if (inst != null) {
                header[counter] = fastaFormat.getHeader();
                data[counter] = fastaFormat.getSequence();
                counter++;
            }
            allPosList.add(fastaFormat);
        }
        while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
            if (inst != null) {
                header[counter] = fastaFormat.getHeader();
                data[counter] = fastaFormat.getSequence();
                counter++;
            }
            allNegList.add(fastaFormat);
        }
        //run x folds
        for (int x = 0; x < folds; x++) {
            if (applicationData.terminateThread == true) {
                if (statusPane != null)
                    statusPane.setText("Interrupted - Classifier One Training Completed");
                outputCrossValidation.close();
                return classifierOne;
            }
            if (statusPane != null)
                statusPane.setPrefix("Running Fold " + (x + 1) + ": ");
            if (inst != null) {
                foldTrainingInstance = new Instances(inst, 0);
                foldTestingInstance = new Instances(inst, 0);
            }
            List<FastaFormat> trainPosList = new ArrayList<FastaFormat>();
            List<FastaFormat> trainNegList = new ArrayList<FastaFormat>();
            List<FastaFormat> testPosList = new ArrayList<FastaFormat>();
            List<FastaFormat> testNegList = new ArrayList<FastaFormat>();
            //split data into training and testing
            //This is for normal run
            int testInstanceIndex[] = null;
            if (inst != null)
                testInstanceIndex = new int[(inst.numInstances() / folds) + 1];
            if (gaSettings == null) {
                int testIndexCounter = 0;
                for (int y = 0; y < inst.numInstances(); y++) {
                    if ((y % folds) == x) {//this instance is for testing
                        foldTestingInstance.add(inst.instance(y));
                        testInstanceIndex[testIndexCounter] = y;
                        testIndexCounter++;
                    } else {//this instance is for training
                        foldTrainingInstance.add(inst.instance(y));
                    }
                }
            } else {
                //This is for GA run
                for (int y = 0; y < allPosList.size(); y++) {
                    if ((y % folds) == x) {//this instance is for testing
                        testPosList.add(allPosList.get(y));
                    } else {//this instance is for training
                        trainPosList.add(allPosList.get(y));
                    }
                }
                for (int y = 0; y < allNegList.size(); y++) {
                    if ((y % folds) == x) {//this instance is for testing
                        testNegList.add(allNegList.get(y));
                    } else {//this instance is for training
                        trainNegList.add(allNegList.get(y));
                    }
                }
                if (gaDialog != null)
                    foldTrainingInstance = runDAandLoadResult(applicationData, gaDialog, trainPosList,
                            trainNegList, x + 1, startRandomNumber);
                else
                    foldTrainingInstance = runDAandLoadResult(applicationData, gaSettings, trainPosList,
                            trainNegList, x + 1, startRandomNumber);
                foldTrainingInstance.setClassIndex(foldTrainingInstance.numAttributes() - 1);
                //Reading and Storing the featureList
                ArrayList<Feature> featureList = new ArrayList<Feature>();
                for (int y = 0; y < foldTrainingInstance.numAttributes() - 1; y++) {
                    //-1 because class attribute must be ignored
                    featureList.add(Feature.levelOneClassifierPane(foldTrainingInstance.attribute(y).name()));
                }
                String outputFilename;
                if (gaDialog != null)
                    outputFilename = gaDialog.getOutputLocation().getText() + File.separator
                            + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1)
                            + ".arff";
                else
                    outputFilename = gaSettings.getOutputLocation() + File.separator
                            + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1)
                            + ".arff";
                new GenerateFeatures(applicationData, featureList, testPosList, testNegList, outputFilename);
                foldTestingInstance = new Instances(new FileReader(outputFilename));
                foldTestingInstance.setClassIndex(foldTestingInstance.numAttributes() - 1);
            }

            Classifier foldClassifier = tempClassifier;
            foldClassifier.buildClassifier(foldTrainingInstance);
            for (int y = 0; y < foldTestingInstance.numInstances(); y++) {
                if (applicationData.terminateThread == true) {
                    if (statusPane != null)
                        statusPane.setText("Interrupted - Classifier One Training Completed");
                    outputCrossValidation.close();
                    return classifierOne;
                }
                double[] results = foldClassifier.distributionForInstance(foldTestingInstance.instance(y));
                int classIndex = foldTestingInstance.instance(y).classIndex();
                String classValue = foldTestingInstance.instance(y).toString(classIndex);
                if (inst != null) {
                    outputCrossValidation.write(header[testInstanceIndex[y]]);
                    outputCrossValidation.newLine();
                    outputCrossValidation.write(data[testInstanceIndex[y]]);
                    outputCrossValidation.newLine();
                } else {
                    if (y < testPosList.size()) {
                        outputCrossValidation.write(testPosList.get(y).getHeader());
                        outputCrossValidation.newLine();
                        outputCrossValidation.write(testPosList.get(y).getSequence());
                        outputCrossValidation.newLine();
                    } else {
                        outputCrossValidation.write(testNegList.get(y - testPosList.size()).getHeader());
                        outputCrossValidation.newLine();
                        outputCrossValidation.write(testNegList.get(y - testPosList.size()).getSequence());
                        outputCrossValidation.newLine();
                    }
                }
                if (classValue.equals("pos"))
                    outputCrossValidation.write("pos,0=" + results[0]);
                else if (classValue.equals("neg"))
                    outputCrossValidation.write("neg,0=" + results[0]);
                else {
                    outputCrossValidation.close();
                    throw new Error("Invalid Class Type!");
                }
                outputCrossValidation.newLine();
                outputCrossValidation.flush();
            }
        }
        outputCrossValidation.close();
        PredictionStats classifierOneStatsOnXValidation = new PredictionStats(classifierOneFilename, range,
                threshold);
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                    Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                            + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
            classifierOneStatsOnXValidation.updateDisplay(classifierResults, classifierOneDisplayTextArea,
                    true);
        }
        applicationData.setClassifierOneStats(classifierOneStatsOnXValidation);
        if (myGraph != null)
            myGraph.setMyStats(classifierOneStatsOnXValidation);
        if (statusPane != null)
            statusPane.setText("Done!");
        //Note: this will be null if a GA run is used; it might be better to run all sequences through GA and then build the classifier, but that would be a waste of time
        return classifierOne;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}
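
The fold assignment in the method above is a plain modulo split: instance y lands in the test fold whenever y % folds == x, otherwise in the training fold. The sketch below keeps only that split plus the per-fold distributionForInstance call; the ARFF path and J48 are placeholders, and a fresh classifier is built for every fold.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;

public class ModuloCrossValidation {
    public static void main(String[] args) throws Exception {
        Instances inst = new Instances(new BufferedReader(new FileReader("Dataset1.arff")));
        inst.setClassIndex(inst.numAttributes() - 1);
        int folds = 5;

        for (int x = 0; x < folds; x++) {
            Instances foldTrainingInstance = new Instances(inst, 0);
            Instances foldTestingInstance = new Instances(inst, 0);
            //Same y % folds == x rule as the code above
            for (int y = 0; y < inst.numInstances(); y++) {
                if (y % folds == x)
                    foldTestingInstance.add(inst.instance(y));
                else
                    foldTrainingInstance.add(inst.instance(y));
            }
            Classifier foldClassifier = new J48();//fresh classifier per fold
            foldClassifier.buildClassifier(foldTrainingInstance);
            for (int y = 0; y < foldTestingInstance.numInstances(); y++) {
                double[] results = foldClassifier.distributionForInstance(foldTestingInstance.instance(y));
                System.out.println("fold " + (x + 1) + " instance " + y + " score=" + results[0]);
            }
        }
    }
}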

From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java

License:Open Source License

public static Object jackKnifeClassifierOneWithNoLocationIndex(JInternalFrame parent,
        ApplicationData applicationData, JTextArea classifierOneDisplayTextArea,
        GenericObjectEditor m_ClassifierEditor, double ratio, GraphPane myGraph,
        ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier,
        String classifierName, String[] classifierOptions, boolean returnClassifier,
        int randomNumberForClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        Classifier tempClassifier;
        if (m_ClassifierEditor != null)
            tempClassifier = (Classifier) m_ClassifierEditor.getValue();
        else
            tempClassifier = Classifier.forName(classifierName, classifierOptions);

        //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files               
        //split the instances into positive and negative
        Instances posInst = new Instances(applicationData.getDataset1Instances());
        posInst.setClassIndex(posInst.numAttributes() - 1);
        for (int x = 0; x < posInst.numInstances();)
            if (posInst.instance(x).stringValue(posInst.numAttributes() - 1).equalsIgnoreCase("pos"))
                x++;
            else
                posInst.delete(x);
        posInst.deleteAttributeType(Attribute.STRING);
        Instances negInst = new Instances(applicationData.getDataset1Instances());
        negInst.setClassIndex(negInst.numAttributes() - 1);
        for (int x = 0; x < negInst.numInstances();)
            if (negInst.instance(x).stringValue(negInst.numAttributes() - 1).equalsIgnoreCase("neg"))
                x++;
            else
                negInst.delete(x);
        negInst.deleteAttributeType(Attribute.STRING);
        //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy   
        long trainTimeStart = 0, trainTimeElapsed = 0;
        if (statusPane != null)
            statusPane.setText("Training Classifier One... May take a while... Please wait...");
        //Record Start Time
        trainTimeStart = System.currentTimeMillis();
        Instances fullInst = new Instances(applicationData.getDataset1Instances());
        fullInst.setClassIndex(fullInst.numAttributes() - 1);
        Classifier classifierOne;
        if (m_ClassifierEditor != null)
            classifierOne = (Classifier) m_ClassifierEditor.getValue();
        else
            classifierOne = Classifier.forName(classifierName, classifierOptions);
        if (outputClassifier)
            classifierOne.buildClassifier(fullInst);
        //Record Total Time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
        //Training Done

        String tclassifierName;
        if (m_ClassifierEditor != null)
            tclassifierName = m_ClassifierEditor.getValue().getClass().getName();
        else
            tclassifierName = classifierName;
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ",
                    tclassifierName);
            classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                    " Jack Knife Validation");
            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                    Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        }
        String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_"
                + randomNumberForClassifier + ".scores";
        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename));

        //Instances foldTrainingInstance;
        //Instances foldTestingInstance;
        int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
        int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
        int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
        int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();
        Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
        Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();
        FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt,
                negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory());
        FastaFormat fastaFormat;
        String header[] = new String[fullInst.numInstances()];
        String data[] = new String[fullInst.numInstances()];
        int counter = 0;
        while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
            header[counter] = fastaFormat.getHeader();
            data[counter] = fastaFormat.getSequence();
            counter++;
        }
        while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
            header[counter] = fastaFormat.getHeader();
            data[counter] = fastaFormat.getSequence();
            counter++;
        }

        //run jack knife validation
        for (int x = 0; x < fullInst.numInstances(); x++) {
            if (applicationData.terminateThread == true) {
                if (statusPane != null)
                    statusPane.setText("Interrupted - Classifier One Training Completed");
                outputCrossValidation.close();
                return classifierOne;
            }
            if (statusPane != null)
                statusPane.setText("Running " + (x + 1) + " / " + fullInst.numInstances());
            Instances trainPosInst = new Instances(posInst);
            Instances trainNegInst = new Instances(negInst);
            Instance testInst;
            //split data into training and testing
            if (x < trainPosInst.numInstances()) {
                testInst = posInst.instance(x);
                trainPosInst.delete(x);
            } else {
                testInst = negInst.instance(x - posInst.numInstances());
                trainNegInst.delete(x - posInst.numInstances());
            }
            Instances trainInstances;
            if (trainPosInst.numInstances() < trainNegInst.numInstances()) {
                trainInstances = new Instances(trainPosInst);
                int max = (int) (ratio * trainPosInst.numInstances());
                if (ratio == -1)
                    max = trainNegInst.numInstances();
                Random rand = new Random(1);
                for (int y = 0; y < trainNegInst.numInstances() && y < max; y++) {
                    int index = rand.nextInt(trainNegInst.numInstances());
                    trainInstances.add(trainNegInst.instance(index));
                    trainNegInst.delete(index);
                }
            } else {
                trainInstances = new Instances(trainNegInst);
                int max = (int) (ratio * trainNegInst.numInstances());
                if (ratio == -1)
                    max = trainPosInst.numInstances();
                Random rand = new Random(1);
                for (int y = 0; y < trainPosInst.numInstances() && y < max; y++) {
                    int index = rand.nextInt(trainPosInst.numInstances());
                    trainInstances.add(trainPosInst.instance(index));
                    trainPosInst.delete(index);
                }
            }
            Classifier foldClassifier = tempClassifier;
            foldClassifier.buildClassifier(trainInstances);
            double[] results = foldClassifier.distributionForInstance(testInst);
            int classIndex = testInst.classIndex();
            String classValue = testInst.toString(classIndex);
            outputCrossValidation.write(header[x]);
            outputCrossValidation.newLine();
            outputCrossValidation.write(data[x]);
            outputCrossValidation.newLine();
            if (classValue.equals("pos"))
                outputCrossValidation.write("pos,0=" + results[0]);
            else if (classValue.equals("neg"))
                outputCrossValidation.write("neg,0=" + results[0]);
            else {
                outputCrossValidation.close();
                throw new Error("Invalid Class Type!");
            }
            outputCrossValidation.newLine();
            outputCrossValidation.flush();
        }
        outputCrossValidation.close();
        PredictionStats classifierOneStatsOnJackKnife = new PredictionStats(classifierOneFilename, range,
                threshold);
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        if (classifierResults != null)
            classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                    Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                            + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");

        //if(classifierOneDisplayTextArea != null)
        classifierOneStatsOnJackKnife.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        applicationData.setClassifierOneStats(classifierOneStatsOnJackKnife);
        if (myGraph != null)
            myGraph.setMyStats(classifierOneStatsOnJackKnife);

        if (statusPane != null)
            statusPane.setText("Done!");
        if (returnClassifier)
            return classifierOne;
        else
            return classifierOneStatsOnJackKnife;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}
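
Jack-knife validation holds each instance out in turn, rebuilds the classifier on the remainder, and scores the held-out instance once. The sketch below is a plain leave-one-out loop without the positive/negative ratio balancing used above; the ARFF path and J48 are placeholders.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;

public class JackKnifeSketch {
    public static void main(String[] args) throws Exception {
        Instances fullInst = new Instances(new BufferedReader(new FileReader("Dataset1.arff")));
        fullInst.setClassIndex(fullInst.numAttributes() - 1);

        for (int x = 0; x < fullInst.numInstances(); x++) {
            Instance testInst = fullInst.instance(x);
            Instances trainInstances = new Instances(fullInst);
            trainInstances.delete(x);//everything except the held-out instance
            Classifier foldClassifier = new J48();//rebuilt from scratch each round
            foldClassifier.buildClassifier(trainInstances);
            double[] results = foldClassifier.distributionForInstance(testInst);
            String classValue = testInst.toString(testInst.classIndex());
            System.out.println(classValue + ",0=" + results[0]);
        }
    }
}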

From source file:SupervisedMetablocking.SupervisedCEP.java

License:Open Source License

@Override
protected void applyClassifier(Classifier classifier) throws Exception {
    for (AbstractBlock block : blocks) {
        ComparisonIterator iterator = block.getComparisonIterator();
        while (iterator.hasNext()) {
            Comparison comparison = iterator.next();
            final List<Integer> commonBlockIndices = entityIndex.getCommonBlockIndices(block.getBlockIndex(),
                    comparison);
            if (commonBlockIndices == null) {
                continue;
            }

            if (trainingSet.contains(comparison)) {
                continue;
            }

            Instance currentInstance = getFeatures(NON_DUPLICATE, commonBlockIndices, comparison);
            double[] probabilities = classifier.distributionForInstance(currentInstance);
            if (probabilities[NON_DUPLICATE] < probabilities[DUPLICATE]) {
                comparison.setUtilityMeasure(probabilities[DUPLICATE]);
                addComparison(comparison);
            }
        }
    }
}
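
Here distributionForInstance drives a simple binary decision: the comparison is kept only when the duplicate class is more probable than the non-duplicate class. The DUPLICATE and NON_DUPLICATE constants are just indices into the returned probability array; the sketch below assumes the non-duplicate value is declared first in the class attribute, which may differ in the real SupervisedCEP code.

import weka.classifiers.Classifier;
import weka.core.Instance;

public class DuplicateDecisionSketch {
    //Assumed ordering of the two class values; check the actual class attribute
    static final int NON_DUPLICATE = 0;
    static final int DUPLICATE = 1;

    //Returns the duplicate probability when the pair should be kept, or -1 when it should be skipped
    static double duplicateScore(Classifier classifier, Instance currentInstance) throws Exception {
        double[] probabilities = classifier.distributionForInstance(currentInstance);
        if (probabilities[NON_DUPLICATE] < probabilities[DUPLICATE])
            return probabilities[DUPLICATE];
        return -1;
    }
}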