Example usage for weka.core.Instances.add

List of usage examples for weka.core.Instances.add

Introduction

On this page you can find example usages of weka.core.Instances.add.

Prototype

@Override
public boolean add(Instance instance) 

Source Link

Document

Adds one instance to the end of the set.

Usage

From source file:shawn.gcbi.com.kea.main.KEAKeyphraseExtractor.java

License:Open Source License

/**
 * Extracts keyphrases for every document stem in {@code stems} using the configured KEA filter.
 *
 * <p>For each stem, {@code m_dirName/<stem>.txt} is read as the document text and, if it can be
 * read, {@code m_dirName/<stem>.key} as the existing keyphrases. The filter's output phrases with
 * rank 1..{@code m_numPhrases} are counted against the existing keyphrases. When no {@code .key}
 * file exists yet, those phrases are written to a new {@code .key} file. The mean and standard
 * deviation of the number of matching keyphrases per document are printed to {@code System.err}.
 *
 * @param stems table whose keys are the document stems (file names without extension)
 * @throws Exception if {@code stems} is empty, or if configuring/running the filter or writing an
 *                   output {@code .key} file fails
 */
public void extractKeyphrases(Hashtable stems) throws Exception {

    Vector stats = new Vector();

    // Check whether there is actually any data
    // = whether there are any files in the directory
    if (stems.size() == 0) {
        throw new Exception("Couldn't find any data!");
    }
    m_KEAFilter.setNumPhrases(m_numPhrases);
    m_KEAFilter.setVocabulary(m_vocabulary);
    m_KEAFilter.setVocabularyFormat(m_vocabularyFormat);
    m_KEAFilter.setDocumentLanguage(getDocumentLanguage());
    m_KEAFilter.setStemmer(m_Stemmer);
    m_KEAFilter.setStopwords(m_Stopwords);

    if (getVocabulary().equals("none")) {
        m_KEAFilter.m_NODEfeature = false;
    } else {
        m_KEAFilter.loadThesaurus(m_Stemmer, m_Stopwords);
    }

    FastVector atts = new FastVector(3);
    atts.addElement(new Attribute("doc", (FastVector) null));
    atts.addElement(new Attribute("keyphrases", (FastVector) null));
    atts.addElement(new Attribute("filename", (String) null));
    Instances data = new Instances("keyphrase_training_data", atts, 0);

    if (m_KEAFilter.m_Dictionary == null) {
        buildGlobalDictionaries(stems);
    }

    System.err.println("-- Extracting keyphrases... ");
    // Enumerate over all files in the directory (now keys of the hash)
    Enumeration elem = stems.keys();
    while (elem.hasMoreElements()) {
        String str = (String) elem.nextElement();

        double[] newInst = new double[2];
        try {
            File txt = new File(m_dirName + "/" + str + ".txt");
            newInst[0] = (double) data.attribute(0).addStringValue(readFileWithEncoding(txt));
        } catch (Exception e) {
            if (m_debug) {
                System.err.println("Can't read document " + str + ".txt");
            }
            newInst[0] = Instance.missingValue();
        }
        try {
            // The str.key file holds keyphrases that Kea assumes were assigned by the author;
            // extracted keyphrases are evaluated against these.
            File key = new File(m_dirName + "/" + str + ".key");
            newInst[1] = (double) data.attribute(1).addStringValue(readFileWithEncoding(key));
        } catch (Exception e) {
            if (m_debug) {
                System.err.println("No existing keyphrases for stem " + str + ".");
            }
            newInst[1] = Instance.missingValue();
        }

        data.add(new Instance(1.0, newInst));

        m_KEAFilter.input(data.instance(0));

        // Reset to an empty header so the next document is again instance 0.
        data = data.stringFreeStructure();
        if (m_debug) {
            System.err.println("-- Document: " + str);
        }
        Instance[] topRankedInstances = new Instance[m_numPhrases];
        Instance inst;

        // Iterate over all extracted keyphrases, keeping those ranked 1..m_numPhrases.
        // Ranks are 1-based; a rank of 0 or a missing rank would give a negative index.
        while ((inst = m_KEAFilter.output()) != null) {
            int index = (int) inst.value(m_KEAFilter.getRankIndex()) - 1;
            if (index >= 0 && index < m_numPhrases) {
                topRankedInstances[index] = inst;
            }
        }

        if (m_debug) {
            System.err.println("-- Keyphrases and feature values:");
        }
        FileOutputStream out = null;
        PrintWriter printer = null;
        double numExtracted = 0, numCorrect = 0;
        try {
            File key = new File(m_dirName + "/" + str + ".key");
            // Extracted keyphrases are only written when no .key file exists, so existing
            // keyphrase files are never overwritten.
            if (!key.exists()) {
                out = new FileOutputStream(m_dirName + "/" + str + ".key");
                if (!m_encoding.equals("default")) {
                    printer = new PrintWriter(new OutputStreamWriter(out, m_encoding));
                } else {
                    printer = new PrintWriter(out);
                }
            }

            for (int i = 0; i < m_numPhrases; i++) {
                Instance ranked = topRankedInstances[i];
                if (ranked == null) {
                    continue;
                }
                int lastIndex = ranked.numAttributes() - 1;
                if (!ranked.isMissing(lastIndex)) {
                    numExtracted += 1.0;
                }
                if ((int) ranked.value(lastIndex) == 1) {
                    numCorrect += 1.0;
                }
                if (printer != null) {
                    String phrase = ranked.stringValue(m_KEAFilter.getUnstemmedPhraseIndex());
                    printer.print(phrase);
                    System.out.print(phrase);
                    System.out.println("\t" + Utils
                            .doubleToString(ranked.value(m_KEAFilter.getProbabilityIndex()), 4));

                    if (m_AdditionalInfo) {
                        printer.print("\t");
                        printer.print(ranked.stringValue(m_KEAFilter.getStemmedPhraseIndex()));
                        printer.print("\t");
                        printer.print(Utils.doubleToString(
                                ranked.value(m_KEAFilter.getProbabilityIndex()), 4));
                    }
                    printer.println();
                }
                if (m_debug) {
                    System.err.println(ranked);
                }
            }
        } finally {
            // Release the output file even if writing a phrase fails.
            if (printer != null) {
                printer.flush();
                printer.close();
            }
            if (out != null) {
                out.close();
            }
        }
        if (numExtracted > 0) {
            if (m_debug) {
                System.err.println("-- " + numCorrect + " correct");
            }
            stats.addElement(Double.valueOf(numCorrect));
        }
    }
    double[] st = new double[stats.size()];
    for (int i = 0; i < stats.size(); i++) {
        st[i] = ((Double) stats.elementAt(i)).doubleValue();
    }
    double avg = Utils.mean(st);
    double stdDev = Math.sqrt(Utils.variance(st));

    System.err.println("Avg. number of matching keyphrases compared to existing ones : "
            + Utils.doubleToString(avg, 2) + " +/- " + Utils.doubleToString(stdDev, 2));
    System.err.println("Based on " + stats.size() + " documents");
    // m_KEAFilter.batchFinished();
}

/**
 * Reads the whole of {@code file} as text.
 *
 * <p>The charset named by {@code m_encoding} is used unless it is {@code "default"}, in which
 * case the platform default charset is used. The underlying stream is always closed.
 *
 * @param file the file to read
 * @return the complete file contents
 * @throws java.io.IOException if the file cannot be opened or read, or the encoding is unknown
 */
private String readFileWithEncoding(File file) throws java.io.IOException {
    FileInputStream in = new FileInputStream(file);
    try {
        InputStreamReader reader;
        if (!m_encoding.equals("default")) {
            reader = new InputStreamReader(in, m_encoding);
        } else {
            reader = new InputStreamReader(in);
        }
        StringBuffer text = new StringBuffer();
        char[] buffer = new char[4096];
        int n;
        while ((n = reader.read(buffer)) != -1) {
            text.append(buffer, 0, n);
        }
        return text.toString();
    } finally {
        in.close();
    }
}

From source file:sim.app.ubik.behaviors.sharedservices.EMClustering.java

License:Open Source License

/**
 * Builds the training data: one instance for every person who is currently using a shared
 * service, created from that person's preferences via {@code getInstance}.
 *
 * @return a new dataset named "usersProfile" holding one instance per user returned by
 *         {@code getUsers()} for each service in {@code slist}
 */
private Instances generateTrainingData() {
    // 1000 is passed as the initial capacity argument of the Instances constructor
    final Instances profiles = new Instances("usersProfile", attributes, 1000);
    for (SharedService service : slist) {
        for (UserInterface user : service.getUsers()) {
            profiles.add(getInstance(user));
        }
    }
    return profiles;
}

From source file:sirius.misc.zscore.ZscoreTableModel.java

License:Open Source License

/**
 * Removes redundant numeric attributes in a minimum-priority background thread, using an overlap
 * measure on differences between consecutive instances.
 *
 * <p>Attribute pairs (A, B) are taken in {@code scoreList} order. For each pair of consecutive
 * instances, the absolute change of each attribute is divided by that attribute's standard
 * deviation; the sign of the change is ignored. A pair "overlaps" when the two normalised changes
 * differ by less than {@code stdDevDist}. If the percentage of overlapping pairs exceeds
 * {@code maxOverlapPercent}, attribute B is deleted from the positive and negative datasets and
 * from {@code scoreList}. Cost is O(a*a*n) for n instances and a attributes.
 *
 * @param stdDevDist        maximum difference, in standard deviations, for two changes to overlap
 * @param maxOverlapPercent overlap percentage above which the second attribute of a pair is removed
 * @param includeNegatives  whether the negative instances take part in the overlap measurement
 */
public void siriusCorrelationFiltering(final double stdDevDist, final double maxOverlapPercent,
        final boolean includeNegatives) {
    Thread thread = new Thread() {
        public void run() {
            // Work on a copy: adding the negatives to posInstances itself would permanently
            // merge them into the positive dataset.
            Instances instances = new Instances(ZscoreTableModel.this.posInstances);
            if (includeNegatives)
                for (int x = 0; x < ZscoreTableModel.this.negInstances.numInstances(); x++)
                    instances.add(ZscoreTableModel.this.negInstances.instance(x));
            MessageDialog m = new MessageDialog(null, "Progress", "0%");
            try {
                for (int a = 0; a < instances.numAttributes(); a++) {
                    int indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index();
                    if (instances.attribute(indexA).isNumeric() == false)
                        continue;
                    double attributeAStddev = instances.attributeStats(indexA).numericStats.stdDev;
                    for (int b = a + 1; b < instances.numAttributes();) {
                        m.update(a + "/" + instances.numAttributes());
                        int indexB = instances.attribute(ZscoreTableModel.this.scoreList.get(b).getName()).index();
                        if (instances.attribute(indexB).isNumeric() == false) {
                            b++;
                            continue;
                        }
                        int numOfOverlap = 0;
                        double attributeBStddev = instances.attributeStats(indexB).numericStats.stdDev;
                        // An overlap: the absolute change from the previous instance is the same
                        // (within stdDevDist) when measured in each attribute's stddev units.
                        for (int x = 0; x < instances.numInstances() - 1; x++) {
                            double attributeADifference = Math.abs(
                                    ((instances.instance(x).value(indexA) - instances.instance(x + 1).value(indexA))
                                            / attributeAStddev));
                            double attributeBDifference = Math.abs(
                                    ((instances.instance(x).value(indexB) - instances.instance(x + 1).value(indexB))
                                            / attributeBStddev));
                            if (Math.abs(attributeADifference - attributeBDifference) < stdDevDist)
                                numOfOverlap++;
                        }
                        int numOfPairs = instances.numInstances() - 1;
                        // Floating-point division: integer division truncated the percentage and
                        // threw ArithmeticException for a single-instance dataset.
                        double overlapPercent = numOfPairs > 0 ? (numOfOverlap * 100.0) / numOfPairs : 0.0;
                        if (overlapPercent > maxOverlapPercent) {
                            // Keep the working copy's attribute layout in step with both datasets.
                            instances.deleteAttributeAt(indexB);
                            ZscoreTableModel.this.posInstances.deleteAttributeAt(indexB);
                            ZscoreTableModel.this.negInstances.deleteAttributeAt(indexB);
                            ZscoreTableModel.this.scoreList.remove(b);
                            indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index();
                        } else
                            b++;
                    }
                }
            } finally {
                m.dispose();
            }
            ZscoreTableModel.this.label.setText("" + instances.numAttributes());
            ZscoreTableModel.this.fireTableDataChanged();
        }
    };
    thread.setPriority(Thread.MIN_PRIORITY); // UI has most priority
    thread.start();
}

From source file:sirius.misc.zscore.ZscoreTableModel.java

License:Open Source License

/**
 * Removes numeric attributes that are strongly correlated with an earlier attribute. The work
 * runs in a minimum-priority background thread.
 *
 * <p>Attribute pairs (A, B) are taken in {@code scoreList} order. For each pair the absolute
 * Pearson correlation is computed as
 * {@code |sum((a - meanA) * (b - meanB)) / ((n - 1) * stdDevA * stdDevB)|}. If it exceeds
 * {@code score}, attribute B is deleted from the positive and negative datasets and from
 * {@code scoreList}. Cost is O(a*a*n) for n instances and a attributes.
 *
 * @param score            absolute correlation above which the second attribute of a pair is removed
 * @param includeNegatives whether the negative instances take part in the correlation computation
 */
public void pearsonCorrelationFiltering(final double score, final boolean includeNegatives) {
    Thread thread = new Thread() {
        public void run() {
            // Work on a copy: adding the negatives to posInstances itself would permanently
            // merge them into the positive dataset.
            Instances instances = new Instances(ZscoreTableModel.this.posInstances);
            if (includeNegatives)
                for (int x = 0; x < ZscoreTableModel.this.negInstances.numInstances(); x++)
                    instances.add(ZscoreTableModel.this.negInstances.instance(x));
            MessageDialog m = new MessageDialog(null, "Progress", "0%");
            try {
                for (int a = 0; a < instances.numAttributes(); a++) {
                    int indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index();
                    if (instances.attribute(indexA).isNumeric() == false)
                        continue;
                    double attributeAStddev = instances.attributeStats(indexA).numericStats.stdDev;
                    double attributeAMean = instances.attributeStats(indexA).numericStats.mean;
                    for (int b = a + 1; b < instances.numAttributes();) {
                        m.update(a + "/" + instances.numAttributes());
                        int indexB = instances.attribute(ZscoreTableModel.this.scoreList.get(b).getName()).index();
                        if (instances.attribute(indexB).isNumeric() == false) {
                            b++;
                            continue;
                        }
                        double attributeBStddev = instances.attributeStats(indexB).numericStats.stdDev;
                        double attributeBMean = instances.attributeStats(indexB).numericStats.mean;
                        double nominator = 0.0;
                        for (int x = 0; x < instances.numInstances(); x++) {
                            nominator += ((instances.instance(x).value(indexA) - attributeAMean)
                                    * (instances.instance(x).value(indexB) - attributeBMean));
                        }
                        double pScore = Math.abs(
                                nominator / ((instances.numInstances() - 1) * attributeAStddev * attributeBStddev));
                        if (pScore > score) {
                            // Keep the working copy's attribute layout in step with both datasets.
                            instances.deleteAttributeAt(indexB);
                            ZscoreTableModel.this.posInstances.deleteAttributeAt(indexB);
                            ZscoreTableModel.this.negInstances.deleteAttributeAt(indexB);
                            ZscoreTableModel.this.scoreList.remove(b);
                            indexA = instances.attribute(ZscoreTableModel.this.scoreList.get(a).getName()).index();
                        } else
                            b++;
                    }
                }
            } finally {
                m.dispose();
            }
            ZscoreTableModel.this.label.setText("" + instances.numAttributes());
            ZscoreTableModel.this.fireTableDataChanged();
        }
    };
    thread.setPriority(Thread.MIN_PRIORITY); // UI has most priority
    thread.start();
}

From source file:sirius.misc.zscore.ZscoreTableModel.java

License:Open Source License

/**
 * Computes per-attribute z-score statistics of the positive instances against the negative
 * distribution, plus each attribute's gain ratio. The work runs in a minimum-priority background
 * thread.
 *
 * <p>For each numeric attribute, the z-score of every positive value is
 * {@code |value - negMean| / negStdDev}; a zero negative stddev is replaced by 0.01. The mean
 * z-score and the percentage of positives with z-score above 0.5, 1, 2 and 3 are stored in
 * {@code scoreList}, which is then sorted with {@code SortByMeanZScore}. Non-numeric attributes
 * get a name-only {@code Scores} entry. An error dialog is shown, and nothing is computed, if
 * either dataset is null or their attribute counts differ.
 *
 * @param posInstances the positive dataset
 * @param negInstances the negative dataset; must have as many attributes as {@code posInstances}
 */
public void compute(final Instances posInstances, final Instances negInstances) {
    if (posInstances == null || negInstances == null) {
        JOptionPane.showMessageDialog(null, "Please load file before computing.", "Error",
                JOptionPane.ERROR_MESSAGE);
        return;
    }
    if (posInstances.numAttributes() != negInstances.numAttributes()) {
        JOptionPane.showMessageDialog(null, "Number of attributes between the two files does not tally.",
                "Error", JOptionPane.ERROR_MESSAGE);
        return;
    }
    this.scoreList = new ArrayList<Scores>();
    this.posInstances = posInstances;
    this.negInstances = negInstances;
    Thread thread = new Thread() {
        public void run() {
            MessageDialog m = new MessageDialog(null, "Progress", "0%");
            try {
                int numAttributes = posInstances.numAttributes();
                int numPosInstances = posInstances.numInstances();
                // Refresh the progress text roughly once per percent of attributes processed
                int percentCount = numAttributes / 100;
                if (percentCount == 0)
                    percentCount = 1;
                for (int x = 0; x < numAttributes; x++) {
                    if (x % percentCount == 0) {
                        // Derive the percentage from x directly; x / percentCount does not track it
                        // (e.g. it reached 149% for 150 attributes).
                        m.update((x * 100) / numAttributes + "%");
                    }
                    if (posInstances.attribute(x).isNumeric() == false) {
                        ZscoreTableModel.this.scoreList.add(new Scores(posInstances.attribute(x).name()));
                        continue;
                    }
                    String name = posInstances.attribute(x).name();
                    double posMean = posInstances.attributeStats(x).numericStats.mean;
                    double posStdDev = posInstances.attributeStats(x).numericStats.stdDev;
                    double negMean = negInstances.attributeStats(x).numericStats.mean;
                    double negStdDev = negInstances.attributeStats(x).numericStats.stdDev;
                    if (negStdDev == 0)
                        negStdDev = 0.01;
                    double totalZScore = 0.0;
                    int numGTZScore0_5 = 0;
                    int numGTZScore1 = 0;
                    int numGTZScore2 = 0;
                    int numGTZScore3 = 0;
                    for (int y = 0; y < numPosInstances; y++) {
                        double zScore = Math.abs(((posInstances.instance(y).value(x) - negMean) / negStdDev));
                        totalZScore += zScore;
                        if (zScore > 0.5)
                            numGTZScore0_5++;
                        if (zScore > 1)
                            numGTZScore1++;
                        if (zScore > 2)
                            numGTZScore2++;
                        if (zScore > 3)
                            numGTZScore3++;
                    }
                    double meanZScore = totalZScore / numPosInstances;
                    // Floating-point division: integer division truncated the percentages and
                    // threw ArithmeticException when there were no positive instances.
                    double percentGTZScore0_5 = (numGTZScore0_5 * 100.0) / numPosInstances;
                    double percentGTZScore1 = (numGTZScore1 * 100.0) / numPosInstances;
                    double percentGTZScore2 = (numGTZScore2 * 100.0) / numPosInstances;
                    double percentGTZScore3 = (numGTZScore3 * 100.0) / numPosInstances;
                    ZscoreTableModel.this.scoreList
                            .add(new Scores(name, posMean, posStdDev, negMean, negStdDev, meanZScore,
                                    percentGTZScore0_5, percentGTZScore1, percentGTZScore2, percentGTZScore3, -1));
                }
                try {
                    // Merge both datasets on a copy to evaluate each attribute's gain ratio
                    // against the class (last) attribute.
                    Instances instances = new Instances(posInstances);
                    for (int x = 0; x < negInstances.numInstances(); x++)
                        instances.add(negInstances.instance(x));
                    instances.setClassIndex(instances.numAttributes() - 1);
                    GainRatioAttributeEval gainRatio = new GainRatioAttributeEval();
                    if (instances.numAttributes() > 0) {
                        gainRatio.buildEvaluator(instances);
                    }
                    // scoreList is still in attribute order here (it is sorted afterwards)
                    for (int x = 0; x < (instances.numAttributes() - 1); x++) {
                        ZscoreTableModel.this.scoreList.get(x).setGainRatio(gainRatio.evaluateAttribute(x));
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
                Collections.sort(ZscoreTableModel.this.scoreList, new SortByMeanZScore());
                fireTableDataChanged();
            } finally {
                m.dispose();
            }
            ZscoreTableModel.this.label.setText("" + ZscoreTableModel.this.scoreList.size());
        }
    };
    thread.setPriority(Thread.MIN_PRIORITY); // UI has most priority
    thread.start();
}

From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java

License:Open Source License

/**
 * Runs {@code folds}-fold cross-validation of classifier one without a location index.
 *
 * <p>For every test instance, the results are written to
 * {@code <workingDirectory>/ClassifierOne_<randomNumberForClassifier>_<startRandomNumber>.scores}.
 * Each entry is the sequence header, the sequence, and a {@code pos,0=...} / {@code neg,0=...}
 * line holding {@code results[0]} of the fold classifier's distribution. The resulting
 * {@code PredictionStats} are then stored in {@code applicationData}, shown in
 * {@code classifierResults} and set on {@code myGraph} when those are non-null.
 *
 * <p>Without GA settings, the dataset-1 instances are split so that instance {@code y} is tested
 * in fold {@code y % folds}. With GA settings, the positive and negative sequences are split the
 * same way. Training features then come from {@code runDAandLoadResult}, and the test fold is
 * featurised into a generated ARFF file.
 *
 * @return the classifier built on the full dataset, which is only trained when
 *         {@code outputClassifier} is true and {@code gaSettings} is null. It is also returned
 *         early when {@code applicationData.terminateThread} is set. Returns {@code null} if an
 *         exception occurred; an error dialog is shown in that case.
 */
public static Classifier xValidateClassifierOneWithNoLocationIndex(JInternalFrame parent,
        ApplicationData applicationData, JTextArea classifierOneDisplayTextArea, String classifierName,
        String[] classifierOptions, int folds, GraphPane myGraph, ClassifierResults classifierResults,
        int range, double threshold, boolean outputClassifier, GeneticAlgorithmDialog gaDialog,
        GASettingsInterface gaSettings, int randomNumberForClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();
        if (statusPane == null)
            System.out.println("Null");

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        Classifier tempClassifier = (Classifier) Classifier.forName(classifierName, classifierOptions);

        Instances inst = null;
        if (applicationData.getDataset1Instances() != null) {
            inst = new Instances(applicationData.getDataset1Instances());
            inst.setClassIndex(applicationData.getDataset1Instances().numAttributes() - 1);
        }

        //Train classifier one with the full dataset first then do cross-validation to gauge its accuracy
        long trainTimeStart = 0, trainTimeElapsed = 0;
        Classifier classifierOne = (Classifier) Classifier.forName(classifierName, classifierOptions);
        if (statusPane != null)
            statusPane.setText("Training Classifier One... May take a while... Please wait...");
        //Record Start Time
        trainTimeStart = System.currentTimeMillis();
        if (outputClassifier && gaSettings == null)
            classifierOne.buildClassifier(inst);
        //Record Total Time used to build classifier one
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
        //Training Done
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ", classifierName);
            classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                    folds + " fold cross-validation");
            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                    Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        }
        int startRandomNumber;
        if (gaSettings != null)
            startRandomNumber = gaSettings.getRandomNumber();
        else
            startRandomNumber = 1;
        String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_"
                + randomNumberForClassifier + "_" + startRandomNumber + ".scores";
        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename));
        // Everything that writes the scores file runs inside this try so that the writer is closed
        // on every exit path: normal completion, early termination, invalid class, or exception.
        try {
            Instances foldTrainingInstance = null;
            Instances foldTestingInstance = null;
            int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
            int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
            int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
            int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();
            Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
            Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();
            FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                    negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt,
                    negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory());
            FastaFormat fastaFormat;
            // header/data are indexed like inst's instances: positives first, then negatives
            String header[] = null;
            String data[] = null;
            if (inst != null) {
                header = new String[inst.numInstances()];
                data = new String[inst.numInstances()];
            }
            List<FastaFormat> allPosList = new ArrayList<FastaFormat>();
            List<FastaFormat> allNegList = new ArrayList<FastaFormat>();
            int counter = 0;
            while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
                if (inst != null) {
                    header[counter] = fastaFormat.getHeader();
                    data[counter] = fastaFormat.getSequence();
                    counter++;
                }
                allPosList.add(fastaFormat);
            }
            while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
                if (inst != null) {
                    header[counter] = fastaFormat.getHeader();
                    data[counter] = fastaFormat.getSequence();
                    counter++;
                }
                allNegList.add(fastaFormat);
            }
            //run x folds
            for (int x = 0; x < folds; x++) {
                if (applicationData.terminateThread == true) {
                    if (statusPane != null)
                        statusPane.setText("Interrupted - Classifier One Training Completed");
                    return classifierOne;
                }
                if (statusPane != null)
                    statusPane.setPrefix("Running Fold " + (x + 1) + ": ");
                if (inst != null) {
                    foldTrainingInstance = new Instances(inst, 0);
                    foldTestingInstance = new Instances(inst, 0);
                }
                List<FastaFormat> trainPosList = new ArrayList<FastaFormat>();
                List<FastaFormat> trainNegList = new ArrayList<FastaFormat>();
                List<FastaFormat> testPosList = new ArrayList<FastaFormat>();
                List<FastaFormat> testNegList = new ArrayList<FastaFormat>();
                //split data into training and testing
                int testInstanceIndex[] = null;
                if (inst != null)
                    testInstanceIndex = new int[(inst.numInstances() / folds) + 1];
                if (gaSettings == null) {
                    //This is for normal run
                    int testIndexCounter = 0;
                    for (int y = 0; y < inst.numInstances(); y++) {
                        if ((y % folds) == x) {//this instance is for testing
                            foldTestingInstance.add(inst.instance(y));
                            testInstanceIndex[testIndexCounter] = y;
                            testIndexCounter++;
                        } else {//this instance is for training
                            foldTrainingInstance.add(inst.instance(y));
                        }
                    }
                } else {
                    //This is for GA run
                    for (int y = 0; y < allPosList.size(); y++) {
                        if ((y % folds) == x) {//this instance is for testing
                            testPosList.add(allPosList.get(y));
                        } else {//this instance is for training
                            trainPosList.add(allPosList.get(y));
                        }
                    }
                    for (int y = 0; y < allNegList.size(); y++) {
                        if ((y % folds) == x) {//this instance is for testing
                            testNegList.add(allNegList.get(y));
                        } else {//this instance is for training
                            trainNegList.add(allNegList.get(y));
                        }
                    }
                    if (gaDialog != null)
                        foldTrainingInstance = runDAandLoadResult(applicationData, gaDialog, trainPosList,
                                trainNegList, x + 1, startRandomNumber);
                    else
                        foldTrainingInstance = runDAandLoadResult(applicationData, gaSettings, trainPosList,
                                trainNegList, x + 1, startRandomNumber);
                    foldTrainingInstance.setClassIndex(foldTrainingInstance.numAttributes() - 1);
                    //Reading and Storing the featureList
                    ArrayList<Feature> featureList = new ArrayList<Feature>();
                    for (int y = 0; y < foldTrainingInstance.numAttributes() - 1; y++) {
                        //-1 because class attribute must be ignored
                        featureList.add(Feature.levelOneClassifierPane(foldTrainingInstance.attribute(y).name()));
                    }
                    String outputFilename;
                    if (gaDialog != null)
                        outputFilename = gaDialog.getOutputLocation().getText() + File.separator
                                + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1)
                                + ".arff";
                    else
                        outputFilename = gaSettings.getOutputLocation() + File.separator
                                + "GeneticAlgorithmFeatureGenerationTest" + new Random().nextInt() + "_" + (x + 1)
                                + ".arff";
                    new GenerateFeatures(applicationData, featureList, testPosList, testNegList, outputFilename);
                    // Close the ARFF reader once the test fold is loaded (it was previously leaked).
                    FileReader testFoldReader = new FileReader(outputFilename);
                    try {
                        foldTestingInstance = new Instances(testFoldReader);
                    } finally {
                        testFoldReader.close();
                    }
                    foldTestingInstance.setClassIndex(foldTestingInstance.numAttributes() - 1);
                }

                Classifier foldClassifier = tempClassifier;
                foldClassifier.buildClassifier(foldTrainingInstance);
                for (int y = 0; y < foldTestingInstance.numInstances(); y++) {
                    if (applicationData.terminateThread == true) {
                        if (statusPane != null)
                            statusPane.setText("Interrupted - Classifier One Training Completed");
                        return classifierOne;
                    }
                    double[] results = foldClassifier.distributionForInstance(foldTestingInstance.instance(y));
                    int classIndex = foldTestingInstance.instance(y).classIndex();
                    String classValue = foldTestingInstance.instance(y).toString(classIndex);
                    if (inst != null) {
                        outputCrossValidation.write(header[testInstanceIndex[y]]);
                        outputCrossValidation.newLine();
                        outputCrossValidation.write(data[testInstanceIndex[y]]);
                        outputCrossValidation.newLine();
                    } else {
                        // GA test folds list the positives first, then the negatives
                        if (y < testPosList.size()) {
                            outputCrossValidation.write(testPosList.get(y).getHeader());
                            outputCrossValidation.newLine();
                            outputCrossValidation.write(testPosList.get(y).getSequence());
                            outputCrossValidation.newLine();
                        } else {
                            outputCrossValidation.write(testNegList.get(y - testPosList.size()).getHeader());
                            outputCrossValidation.newLine();
                            outputCrossValidation.write(testNegList.get(y - testPosList.size()).getSequence());
                            outputCrossValidation.newLine();
                        }
                    }
                    if (classValue.equals("pos"))
                        outputCrossValidation.write("pos,0=" + results[0]);
                    else if (classValue.equals("neg"))
                        outputCrossValidation.write("neg,0=" + results[0]);
                    else
                        throw new Error("Invalid Class Type!");
                    outputCrossValidation.newLine();
                    outputCrossValidation.flush();
                }
            }
        } finally {
            outputCrossValidation.close();
        }
        // The scores file must be closed before it is read back here.
        PredictionStats classifierOneStatsOnXValidation = new PredictionStats(classifierOneFilename, range,
                threshold);
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        if (classifierResults != null) {
            // Whole minutes (integer division) followed by the remaining seconds
            classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                    Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                            + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");
            classifierOneStatsOnXValidation.updateDisplay(classifierResults, classifierOneDisplayTextArea,
                    true);
        }
        applicationData.setClassifierOneStats(classifierOneStatsOnXValidation);
        if (myGraph != null)
            myGraph.setMyStats(classifierOneStatsOnXValidation);
        if (statusPane != null)
            statusPane.setText("Done!");
        //Note that classifierOne is untrained if GA is run; building it from all GA sequences would be a waste of time
        return classifierOne;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}

From source file:sirius.trainer.step4.RunClassifierWithNoLocationIndex.java

License:Open Source License

/**
 * Runs jack-knife (leave-one-out) validation of classifier one on dataset 1 and writes the
 * per-instance scores to {@code ClassifierOne_<randomNumberForClassifier>.scores} in the working
 * directory.
 *
 * <p>The class attribute is assumed to be the last attribute, with values "pos" or "neg". For every
 * instance of dataset 1 (positives first, then negatives), that instance is held out as the test
 * instance. A classifier is then trained on a subset of the remaining instances. The subset consists of
 * all instances of the class that has fewer remaining instances (the negative class on ties), plus up to
 * {@code ratio * (that class's size)} instances of the other class. These are drawn with a
 * {@code Random} seeded with 1; {@code ratio == -1} means every instance of the other class. For each
 * test instance, three lines are written: its FASTA header, its sequence and
 * {@code "<pos|neg>,0=<score>"}, where the score is element 0 of the predicted distribution.
 *
 * @param parent parent frame for the error dialog
 * @param applicationData source of dataset 1, FASTA ranges, working directory, status pane and the
 *     {@code terminateThread} interruption flag; receives the resulting stats
 * @param classifierOneDisplayTextArea text area passed to {@code PredictionStats.updateDisplay}
 * @param m_ClassifierEditor if non-null, the classifiers are taken from its value; otherwise they
 *     are created from {@code classifierName} and {@code classifierOptions}
 * @param ratio sampling ratio of the larger class relative to the smaller one, or -1 for all
 * @param myGraph optional graph that receives the resulting stats
 * @param classifierResults optional results list updated with classifier and timing information
 * @param range passed through to {@code PredictionStats}
 * @param threshold passed through to {@code PredictionStats}
 * @param outputClassifier whether classifier one is actually trained on the full dataset
 * @param classifierName classifier class name used when {@code m_ClassifierEditor} is null
 * @param classifierOptions classifier options used when {@code m_ClassifierEditor} is null
 * @param returnClassifier selects the return value (see below)
 * @param randomNumberForClassifier number embedded in the scores file name
 * @return classifier one if {@code returnClassifier} is true (also returned early when the run is
 *     interrupted), otherwise the {@code PredictionStats} computed from the scores file; {@code null}
 *     if an exception occurred (the error is shown in a dialog)
 */
public static Object jackKnifeClassifierOneWithNoLocationIndex(JInternalFrame parent,
        ApplicationData applicationData, JTextArea classifierOneDisplayTextArea,
        GenericObjectEditor m_ClassifierEditor, double ratio, GraphPane myGraph,
        ClassifierResults classifierResults, int range, double threshold, boolean outputClassifier,
        String classifierName, String[] classifierOptions, boolean returnClassifier,
        int randomNumberForClassifier) {
    try {
        StatusPane statusPane = applicationData.getStatusPane();

        long totalTimeStart = System.currentTimeMillis(), totalTimeElapsed;
        // Classifier re-trained for every jack-knife fold.
        Classifier tempClassifier;
        if (m_ClassifierEditor != null)
            tempClassifier = (Classifier) m_ClassifierEditor.getValue();
        else
            tempClassifier = Classifier.forName(classifierName, classifierOptions);

        //Assume that class attribute is the last attribute - This should be the case for all Sirius produced Arff files
        //split the instances into positive and negative
        Instances posInst = new Instances(applicationData.getDataset1Instances());
        posInst.setClassIndex(posInst.numAttributes() - 1);
        // Keep only "pos" instances; x advances only when the current instance is kept.
        for (int x = 0; x < posInst.numInstances();)
            if (posInst.instance(x).stringValue(posInst.numAttributes() - 1).equalsIgnoreCase("pos"))
                x++;
            else
                posInst.delete(x);
        posInst.deleteAttributeType(Attribute.STRING);
        Instances negInst = new Instances(applicationData.getDataset1Instances());
        negInst.setClassIndex(negInst.numAttributes() - 1);
        // Keep only "neg" instances; x advances only when the current instance is kept.
        for (int x = 0; x < negInst.numInstances();)
            if (negInst.instance(x).stringValue(negInst.numAttributes() - 1).equalsIgnoreCase("neg"))
                x++;
            else
                negInst.delete(x);
        negInst.deleteAttributeType(Attribute.STRING);

        //Train classifier one with the full dataset first then do jack-knife validation to gauge its accuracy
        if (statusPane != null)
            statusPane.setText("Training Classifier One... May take a while... Please wait...");
        //Record Start Time
        long trainTimeStart = System.currentTimeMillis();
        Instances fullInst = new Instances(applicationData.getDataset1Instances());
        fullInst.setClassIndex(fullInst.numAttributes() - 1);
        Classifier classifierOne;
        if (m_ClassifierEditor != null)
            classifierOne = (Classifier) m_ClassifierEditor.getValue();
        else
            classifierOne = Classifier.forName(classifierName, classifierOptions);
        // Classifier one is only trained on the full dataset when the caller wants it output.
        if (outputClassifier)
            classifierOne.buildClassifier(fullInst);
        //Record Total Time used to build classifier one
        long trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;

        String tclassifierName;
        if (m_ClassifierEditor != null)
            tclassifierName = m_ClassifierEditor.getValue().getClass().getName();
        else
            tclassifierName = classifierName;
        if (classifierResults != null) {
            classifierResults.updateList(classifierResults.getClassifierList(), "Classifier: ",
                    tclassifierName);
            classifierResults.updateList(classifierResults.getClassifierList(), "Training Data: ",
                    " Jack Knife Validation");
            classifierResults.updateList(classifierResults.getClassifierList(), "Time Used: ",
                    Utils.doubleToString(trainTimeElapsed / 1000.0, 2) + " seconds");
        }
        String classifierOneFilename = applicationData.getWorkingDirectory() + File.separator + "ClassifierOne_"
                + randomNumberForClassifier + ".scores";
        BufferedWriter outputCrossValidation = new BufferedWriter(new FileWriter(classifierOneFilename));
        // The writer is closed in the finally clause on every exit path (completion, interruption,
        // exception, or the "Invalid Class Type!" Error), so the file handle is never leaked.
        try {
            int positiveDataset1FromInt = applicationData.getPositiveDataset1FromField();
            int positiveDataset1ToInt = applicationData.getPositiveDataset1ToField();
            int negativeDataset1FromInt = applicationData.getNegativeDataset1FromField();
            int negativeDataset1ToInt = applicationData.getNegativeDataset1ToField();
            Step1TableModel positiveStep1TableModel = applicationData.getPositiveStep1TableModel();
            Step1TableModel negativeStep1TableModel = applicationData.getNegativeStep1TableModel();
            FastaFileManipulation fastaFile = new FastaFileManipulation(positiveStep1TableModel,
                    negativeStep1TableModel, positiveDataset1FromInt, positiveDataset1ToInt,
                    negativeDataset1FromInt, negativeDataset1ToInt, applicationData.getWorkingDirectory());
            // Headers/sequences are collected positives first, then negatives, which is the same
            // order in which the jack-knife loop below walks posInst and then negInst.
            FastaFormat fastaFormat;
            String header[] = new String[fullInst.numInstances()];
            String data[] = new String[fullInst.numInstances()];
            int counter = 0;
            while ((fastaFormat = fastaFile.nextSequence("pos")) != null) {
                header[counter] = fastaFormat.getHeader();
                data[counter] = fastaFormat.getSequence();
                counter++;
            }
            while ((fastaFormat = fastaFile.nextSequence("neg")) != null) {
                header[counter] = fastaFormat.getHeader();
                data[counter] = fastaFormat.getSequence();
                counter++;
            }

            //run jack knife validation
            for (int x = 0; x < fullInst.numInstances(); x++) {
                if (applicationData.terminateThread) {
                    if (statusPane != null)
                        statusPane.setText("Interrupted - Classifier One Training Completed");
                    return classifierOne;
                }
                if (statusPane != null)
                    statusPane.setText("Running " + (x + 1) + " / " + fullInst.numInstances());
                Instances trainPosInst = new Instances(posInst);
                Instances trainNegInst = new Instances(negInst);
                Instance testInst;
                //split data into training and testing: x indexes positives first, then negatives
                if (x < trainPosInst.numInstances()) {
                    testInst = posInst.instance(x);
                    trainPosInst.delete(x);
                } else {
                    testInst = negInst.instance(x - posInst.numInstances());
                    trainNegInst.delete(x - posInst.numInstances());
                }
                // Take the whole class with fewer remaining instances and sample the other class.
                Instances trainInstances;
                if (trainPosInst.numInstances() < trainNegInst.numInstances()) {
                    trainInstances = new Instances(trainPosInst);
                    int max = (int) (ratio * trainPosInst.numInstances());
                    if (ratio == -1)
                        max = trainNegInst.numInstances();
                    Random rand = new Random(1);
                    for (int y = 0; y < trainNegInst.numInstances() && y < max; y++) {
                        int index = rand.nextInt(trainNegInst.numInstances());
                        trainInstances.add(trainNegInst.instance(index));
                        trainNegInst.delete(index);
                    }
                } else {
                    trainInstances = new Instances(trainNegInst);
                    int max = (int) (ratio * trainNegInst.numInstances());
                    if (ratio == -1)
                        max = trainPosInst.numInstances();
                    Random rand = new Random(1);
                    for (int y = 0; y < trainPosInst.numInstances() && y < max; y++) {
                        int index = rand.nextInt(trainPosInst.numInstances());
                        trainInstances.add(trainPosInst.instance(index));
                        trainPosInst.delete(index);
                    }
                }
                tempClassifier.buildClassifier(trainInstances);
                double[] results = tempClassifier.distributionForInstance(testInst);
                int classIndex = testInst.classIndex();
                String classValue = testInst.toString(classIndex);
                outputCrossValidation.write(header[x]);
                outputCrossValidation.newLine();
                outputCrossValidation.write(data[x]);
                outputCrossValidation.newLine();
                if (classValue.equals("pos"))
                    outputCrossValidation.write("pos,0=" + results[0]);
                else if (classValue.equals("neg"))
                    outputCrossValidation.write("neg,0=" + results[0]);
                else
                    throw new Error("Invalid Class Type!");
                outputCrossValidation.newLine();
                outputCrossValidation.flush();
            }
        } finally {
            outputCrossValidation.close();
        }
        // The scores file is complete and closed; compute statistics from it.
        PredictionStats classifierOneStatsOnJackKnife = new PredictionStats(classifierOneFilename, range,
                threshold);
        totalTimeElapsed = System.currentTimeMillis() - totalTimeStart;
        if (classifierResults != null)
            classifierResults.updateList(classifierResults.getResultsList(), "Total Time Used: ",
                    Utils.doubleToString(totalTimeElapsed / 60000, 2) + " minutes "
                            + Utils.doubleToString((totalTimeElapsed / 1000.0) % 60.0, 2) + " seconds");

        classifierOneStatsOnJackKnife.updateDisplay(classifierResults, classifierOneDisplayTextArea, true);
        applicationData.setClassifierOneStats(classifierOneStatsOnJackKnife);
        if (myGraph != null)
            myGraph.setMyStats(classifierOneStatsOnJackKnife);

        if (statusPane != null)
            statusPane.setText("Done!");
        if (returnClassifier)
            return classifierOne;
        else
            return classifierOneStatsOnJackKnife;
    } catch (Exception e) {
        e.printStackTrace();
        JOptionPane.showMessageDialog(parent, e.getMessage(), "ERROR", JOptionPane.ERROR_MESSAGE);
        return null;
    }
}

From source file:smo2.SMO.java

License:Open Source License

/**
 * Method for building the classifier. Implements a one-against-one wrapper
 * for multi-class problems./* www .ja v  a2s. c  o  m*/
 *
 * @param insts
 *            the set of training instances
 * @exception Exception
 *                if the classifier can't be built successfully
 */
public void buildClassifier(Instances insts) throws Exception {

    if (!m_checksTurnedOff) {
        if (insts.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        if (insts.classAttribute().isNumeric()) {
            throw new UnsupportedClassTypeException(
                    "mySMO can't handle a numeric class! Use" + "SMOreg for performing regression.");
        }
        insts = new Instances(insts);
        insts.deleteWithMissingClass();
        if (insts.numInstances() == 0) {
            throw new Exception("No training instances without a missing class!");
        }

        /*
         * Removes all the instances with weight equal to 0. MUST be done
         * since condition (8) of Keerthi's paper is made with the assertion
         * Ci > 0 (See equation (3a).
         */
        Instances data = new Instances(insts, insts.numInstances());
        for (int i = 0; i < insts.numInstances(); i++) {
            if (insts.instance(i).weight() > 0)
                data.add(insts.instance(i));
        }
        if (data.numInstances() == 0) {
            throw new Exception("No training instances left after removing "
                    + "instance with either a weight null or a missing class!");
        }
        insts = data;

    }

    m_onlyNumeric = true;
    if (!m_checksTurnedOff) {
        for (int i = 0; i < insts.numAttributes(); i++) {
            if (i != insts.classIndex()) {
                if (!insts.attribute(i).isNumeric()) {
                    m_onlyNumeric = false;
                    break;
                }
            }
        }
    }

    if (!m_checksTurnedOff) {
        m_Missing = new ReplaceMissingValues();
        m_Missing.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Missing);
    } else {
        m_Missing = null;
    }

    if (!m_onlyNumeric) {
        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_NominalToBinary);
    } else {
        m_NominalToBinary = null;
    }

    if (m_filterType == FILTER_STANDARDIZE) {
        m_Filter = new Standardize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    } else if (m_filterType == FILTER_NORMALIZE) {
        m_Filter = new Normalize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    } else {
        m_Filter = null;
    }

    m_classIndex = insts.classIndex();
    m_classAttribute = insts.classAttribute();

    // Generate subsets representing each class
    Instances[] subsets = new Instances[insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
        subsets[i] = new Instances(insts, insts.numInstances());
    }
    for (int j = 0; j < insts.numInstances(); j++) {
        Instance inst = insts.instance(j);
        subsets[(int) inst.classValue()].add(inst);
    }
    for (int i = 0; i < insts.numClasses(); i++) {
        subsets[i].compactify();
    }

    // Build the binary classifiers
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarymySMO[insts.numClasses()][insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
        for (int j = i + 1; j < insts.numClasses(); j++) {
            m_classifiers[i][j] = new BinarymySMO();
            Instances data = new Instances(insts, insts.numInstances());
            for (int k = 0; k < subsets[i].numInstances(); k++) {
                data.add(subsets[i].instance(k));
            }
            for (int k = 0; k < subsets[j].numInstances(); k++) {
                data.add(subsets[j].instance(k));
            }
            data.compactify();
            data.randomize(rand);
            m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels, m_numFolds, m_randomSeed);
        }
    }
}

From source file:svmal.SVMStrategy.java

/**
 * Copies a set of instances into a new {@link Instances} object whose elements are
 * {@code Instance2}s carrying the same weights and attribute values.
 *
 * @param insts the source instances; its header is reused for the result
 * @return a new set with one {@code Instance2} per source instance, in the original order
 */
public static Instances InstancesToInstances2(Instances insts) {
    // Same header (attributes, relation, class index) as the source, but no instances yet.
    Instances converted = new Instances(insts, 0, 0);
    for (Instance source : insts) {
        Instance2 copy = new Instance2(source.weight(), source.toDoubleArray());
        copy.setDataset(converted);
        converted.add(copy);
    }
    return converted;
}

From source file:svmal.SVMStrategy.java

/**
 * Converts an array of {@code Pattern}s into a new {@link Instances} set of {@code Instance2}
 * objects. The header of the result is copied from the dataset of the first pattern. Each converted
 * instance keeps the weight and attribute values of its pattern and uses the pattern's id as its
 * index.
 *
 * @param patts the patterns to convert; must be non-empty
 * @return a new set containing one {@code Instance2} per pattern, in array order
 * @throws IllegalArgumentException if {@code patts} is empty, because there is then no dataset to
 *     copy the header from
 */
public static Instances PatternsToInstances2(Pattern[] patts) {
    // An empty array would otherwise fail with an opaque ArrayIndexOutOfBoundsException below.
    if (patts.length == 0) {
        throw new IllegalArgumentException(
                "patts must contain at least one pattern to take the dataset header from");
    }
    Instances result = new Instances(patts[0].dataset(), 0, 0);
    for (Pattern orig : patts) {
        Instance2 inst2 = new Instance2(orig.weight(), orig.toDoubleArray());
        inst2.setIndex(orig.id());
        inst2.setDataset(result);
        result.add(inst2);
    }
    return result;
}