Example usage for weka.core Instances stratify

Introduction

On this page you can find usage examples for the weka.core Instances method stratify.

Prototype

public void stratify(int numFolds) 

Document

Stratifies a set of instances according to its class values if the class attribute is nominal (so that afterwards a stratified cross-validation can be performed).
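
A minimal sketch of the typical call sequence, assuming a dataset with a nominal class attribute (the file name, seed, and fold count are illustrative, not taken from the examples below): randomize first, stratify, then cut the folds with trainCV/testCV, just as the usage examples below do.

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class StratifyExample {
    public static void main(String[] args) throws Exception {
        // load a dataset and use the last attribute as the class (file name is illustrative)
        Instances data = DataSource.read("dataset.arff");
        data.setClassIndex(data.numAttributes() - 1);

        int numFolds = 10;
        data.randomize(new Random(42)); // shuffle before stratifying
        if (data.classAttribute().isNominal()) {
            data.stratify(numFolds); // reorder so each fold gets a similar class distribution
        }

        for (int n = 0; n < numFolds; n++) {
            Instances train = data.trainCV(numFolds, n);
            Instances test = data.testCV(numFolds, n);
            // build and evaluate a classifier on train/test here
        }
    }
}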

Usage

From source file:REPTree.java

License:Open Source License

/**
 * Builds classifier.
 * 
 * @param data the data to train with
 * @throws Exception if building fails
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
        train = data.trainCV(m_NumFolds, 0, random);
        prune = data.testCV(m_NumFolds, 0);
    } else {
        train = data;
    }

    // Create array of sorted indices and weights
    int[][][] sortedIndices = new int[1][train.numAttributes()][0];
    double[][][] weights = new double[1][train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
        if (j != train.classIndex()) {
            weights[0][j] = new double[train.numInstances()];
            if (train.attribute(j).isNominal()) {

                // Handling nominal attributes. Putting indices of
                // instances with missing values at the end.
                sortedIndices[0][j] = new int[train.numInstances()];
                int count = 0;
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (!inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
            } else {

                // Sorted indices are computed for numeric attributes
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    vals[i] = inst.value(j);
                }
                sortedIndices[0][j] = Utils.sort(vals);
                for (int i = 0; i < train.numInstances(); i++) {
                    weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
                }
            }
        }
    }

    // Compute initial class counts
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        if (data.classAttribute().isNominal()) {
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        } else {
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
        classProbs[0] /= totalWeight;
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum,
            m_MinVarianceProp * trainVariance, 0, m_MaxDepth);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);
        m_Tree.reducedErrorPrune();
        m_Tree.backfitHoldOutSet();
    }
}

From source file:REPRandomTree.java

License:Open Source License

/**
 * Builds classifier.
 * 
 * @param data the data to train with
 * @throws Exception if building fails
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
        train = data.trainCV(m_NumFolds, 0, random);
        prune = data.testCV(m_NumFolds, 0);
    } else {
        train = data;
    }

    // Create array of sorted indices and weights
    int[][][] sortedIndices = new int[1][train.numAttributes()][0];
    double[][][] weights = new double[1][train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
        if (j != train.classIndex()) {
            weights[0][j] = new double[train.numInstances()];
            if (train.attribute(j).isNominal()) {

                // Handling nominal attributes. Putting indices of
                // instances with missing values at the end.
                sortedIndices[0][j] = new int[train.numInstances()];
                int count = 0;
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (!inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
            } else {

                // Sorted indices are computed for numeric attributes
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    vals[i] = inst.value(j);
                }
                sortedIndices[0][j] = Utils.sort(vals);
                for (int i = 0; i < train.numInstances(); i++) {
                    weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
                }
            }
        }
    }

    // Compute initial class counts
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        if (data.classAttribute().isNominal()) {
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        } else {
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
        classProbs[0] /= totalWeight;
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum,
            m_MinVarianceProp * trainVariance, 0, m_MaxDepth, m_FeatureFrac, random);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);
        m_Tree.reducedErrorPrune();
        m_Tree.backfitHoldOutSet();
    }
}

From source file:adams.flow.transformer.WekaAttributeSelection.java

License:Open Source License

/**
 * Executes the flow item.
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances data;
    Instances reduced;
    Instances transformed;
    AttributeSelection eval;
    boolean crossValidate;
    int fold;
    Instances train;
    WekaAttributeSelectionContainer cont;
    SpreadSheet stats;
    int i;
    Row row;
    int[] selected;
    double[][] ranked;
    Range range;
    String rangeStr;
    boolean useReduced;

    result = null;

    try {
        if (m_InputToken.getPayload() instanceof Instances)
            data = (Instances) m_InputToken.getPayload();
        else
            data = (Instances) ((WekaTrainTestSetContainer) m_InputToken.getPayload())
                    .getValue(WekaTrainTestSetContainer.VALUE_TRAIN);

        if (result == null) {
            crossValidate = (m_Folds >= 2);

            // setup evaluation
            eval = new AttributeSelection();
            eval.setEvaluator(m_Evaluator);
            eval.setSearch(m_Search);
            eval.setFolds(m_Folds);
            eval.setSeed((int) m_Seed);
            eval.setXval(crossValidate);

            // select attributes
            if (crossValidate) {
                Random random = new Random(m_Seed);
                data = new Instances(data);
                data.randomize(random);
                if ((data.classIndex() > -1) && data.classAttribute().isNominal()) {
                    if (isLoggingEnabled())
                        getLogger().info("Stratifying instances...");
                    data.stratify(m_Folds);
                }
                for (fold = 0; fold < m_Folds; fold++) {
                    if (isLoggingEnabled())
                        getLogger().info("Creating splits for fold " + (fold + 1) + "...");
                    train = data.trainCV(m_Folds, fold, random);
                    if (isLoggingEnabled())
                        getLogger().info("Selecting attributes using all but fold " + (fold + 1) + "...");
                    eval.selectAttributesCVSplit(train);
                }
            } else {
                eval.SelectAttributes(data);
            }

            // generate reduced/transformed dataset
            reduced = null;
            transformed = null;
            if (!crossValidate) {
                reduced = eval.reduceDimensionality(data);
                if (m_Evaluator instanceof AttributeTransformer)
                    transformed = ((AttributeTransformer) m_Evaluator).transformedData(data);
            }

            // generated stats
            stats = null;
            if (!crossValidate) {
                stats = new DefaultSpreadSheet();
                row = stats.getHeaderRow();

                useReduced = false;
                if (m_Search instanceof RankedOutputSearch) {
                    i = reduced.numAttributes();
                    if (reduced.classIndex() > -1)
                        i--;
                    ranked = eval.rankedAttributes();
                    useReduced = (ranked.length == i);
                }

                if (useReduced) {
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell("" + i).setContent(reduced.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                } else {
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell("" + i).setContent(data.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                }

                if (m_Search instanceof RankedOutputSearch) {
                    ranked = eval.rankedAttributes();
                    for (i = 0; i < ranked.length; i++)
                        row.getCell((int) ranked[i][0]).setContent(ranked[i][1]);
                } else {
                    selected = eval.selectedAttributes();
                    for (i = 0; i < selected.length; i++)
                        row.getCell(selected[i]).setContent(1.0);
                }
            }

            // selected attributes
            rangeStr = null;
            if (!crossValidate) {
                range = new Range();
                range.setIndices(eval.selectedAttributes());
                rangeStr = range.getRange();
            }

            // setup container
            if (crossValidate)
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, m_Seed, m_Folds);
            else
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, stats, rangeStr);
            m_OutputToken = new Token(cont);
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to process data:", e);
    }

    return result;
}

From source file:ann.MyANN.java

/**
 * Evaluates the model by splitting the instances into training and test sets over numFold folds.
 * @param instances the data to evaluate
 * @param numFold the number of folds
 * @param rand the random number generator used for shuffling
 * @return confusion matrix
 */
public int[][] crossValidation(Instances instances, int numFold, Random rand) {
    int[][] totalResult = null;
    instances = new Instances(instances);
    instances.randomize(rand);
    if (instances.classAttribute().isNominal()) {
        instances.stratify(numFold);
    }
    for (int i = 0; i < numFold; i++) {
        try {
            // split the instances according to the number of folds
            Instances train = instances.trainCV(numFold, i, rand);
            Instances test = instances.testCV(numFold, i);
            MyANN cc = new MyANN(this);
            cc.buildClassifier(train);
            // evaluate each fold once and accumulate the confusion matrices
            int[][] result = cc.evaluate(test);
            if (i == 0) {
                totalResult = result;
            } else {
                for (int j = 0; j < totalResult.length; j++) {
                    for (int k = 0; k < totalResult[0].length; k++) {
                        totalResult[j][k] += result[j][k];
                    }
                }
            }
        } catch (Exception ex) {
            Logger.getLogger(MyANN.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    return totalResult;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidation(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation ST");

    PerformanceCounters.startTimer("cross-validation init ST");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    Instances trainSets[] = new Instances[folds];
    Instances testSets[] = new Instances[folds];
    Classifier foldCls[] = new Classifier[folds];

    for (int n = 0; n < folds; n++) {
        trainSets[n] = randData.trainCV(folds, n);
        testSets[n] = randData.testCV(folds, n);
        foldCls[n] = AbstractClassifier.makeCopy(cls);
    }

    PerformanceCounters.stopTimer("cross-validation init ST");
    PerformanceCounters.startTimer("cross-validation folds+train ST");
    // parallelize: ------------------------------------------------------------
    for (int n = 0; n < folds; n++) {
        Instances train = trainSets[n];
        Instances test = testSets[n];

        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);
        // build and evaluate classifier
        Classifier clsCopy = foldCls[n];
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
    }

    cls.buildClassifier(data);
    //until here!-----------------------------------------------------------------

    PerformanceCounters.stopTimer("cross-validation folds+train ST");
    PerformanceCounters.startTimer("cross-validation post ST");
    // output evaluation
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post ST");
    PerformanceCounters.stopTimer("cross-validation ST");

    return out;
}

From source file:asap.CrossValidation.java

/**
 *
 * @param dataInput
 * @param classIndex
 * @param removeIndices
 * @param cls
 * @param seed
 * @param folds
 * @param modelOutputFile
 * @return
 * @throws Exception
 */
public static String performCrossValidationMT(String dataInput, String classIndex, String removeIndices,
        AbstractClassifier cls, int seed, int folds, String modelOutputFile) throws Exception {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // loads data and set class index
    Instances data = DataSource.read(dataInput);
    String clsIndex = classIndex;

    switch (clsIndex) {
    case "first":
        data.setClassIndex(0);
        break;
    case "last":
        data.setClassIndex(data.numAttributes() - 1);
        break;
    default:
        try {
            data.setClassIndex(Integer.parseInt(clsIndex) - 1);
        } catch (NumberFormatException e) {
            data.setClassIndex(data.attribute(clsIndex).index());
        }
        break;
    }

    Remove removeFilter = new Remove();
    removeFilter.setAttributeIndices(removeIndices);
    removeFilter.setInputFormat(data);
    data = Filter.useFilter(data, removeFilter);

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval = new Evaluation(randData);
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                AbstractClassifier.makeCopy(cls)));

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    // parallelize: ------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        //use the current thread to run the cross-validation instead of using the Thread instance created here:
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    cls.buildClassifier(data);

    for (Thread foldThread : foldThreads) {
        foldThread.join();
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (!modelOutputFile.isEmpty()) {
        SerializationHelper.write(modelOutputFile, cls);
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.CrossValidation.java

static String performCrossValidationMT(Instances data, AbstractClassifier cls, int seed, int folds,
        String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation MT");

    PerformanceCounters.startTimer("cross-validation init MT");

    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(cls)));
        } catch (Exception ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }

        // TODO: use Config.getNumThreads() to limit these:
        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init MT");
    PerformanceCounters.startTimer("cross-validation folds+train MT");
    // parallelize: ------------------------------------------------------------
    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    try {
        cls.buildClassifier(data);
    } catch (Exception ex) {
        Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
    }

    for (Thread foldThread : foldThreads) {
        try {
            foldThread.join();
        } catch (InterruptedException ex) {
            Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    //until here!-----------------------------------------------------------------
    PerformanceCounters.stopTimer("cross-validation folds+train MT");
    PerformanceCounters.startTimer("cross-validation post MT");
    // evaluation for output:
    String out = "\n" + "=== Setup ===\n" + "Classifier: " + cls.getClass().getName() + " "
            + Utils.joinOptions(cls.getOptions()) + "\n" + "Dataset: " + data.relationName() + "\n" + "Folds: "
            + folds + "\n" + "Seed: " + seed + "\n" + "\n"
            + eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false) + "\n";

    if (modelOutputFile != null) {
        if (!modelOutputFile.isEmpty()) {
            try {
                SerializationHelper.write(modelOutputFile, cls);
            } catch (Exception ex) {
                Logger.getLogger(CrossValidation.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation post MT");
    PerformanceCounters.stopTimer("cross-validation MT");
    return out;
}

From source file:asap.NLPSystem.java

private String crossValidate(int seed, int folds, String modelOutputFile) {

    PerformanceCounters.startTimer("cross-validation");
    PerformanceCounters.startTimer("cross-validation init");

    AbstractClassifier abstractClassifier = (AbstractClassifier) classifier;
    // randomize data
    Random rand = new Random(seed);
    Instances randData = new Instances(trainingSet);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Evaluation eval;
    try {
        eval = new Evaluation(randData);
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        return "Error creating evaluation instance for given data!";
    }
    List<Thread> foldThreads = (List<Thread>) Collections.synchronizedList(new LinkedList<Thread>());

    List<FoldSet> foldSets = (List<FoldSet>) Collections.synchronizedList(new LinkedList<FoldSet>());

    for (int n = 0; n < folds; n++) {
        try {
            foldSets.add(new FoldSet(randData.trainCV(folds, n), randData.testCV(folds, n),
                    AbstractClassifier.makeCopy(abstractClassifier)));
        } catch (Exception ex) {
            Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (n < Config.getNumThreads() - 1) {
            Thread foldThread = new Thread(new CrossValidationFoldThread(n, foldSets, eval));
            foldThreads.add(foldThread);
        }
    }

    PerformanceCounters.stopTimer("cross-validation init");
    PerformanceCounters.startTimer("cross-validation folds+train");

    if (Config.getNumThreads() > 1) {
        for (Thread foldThread : foldThreads) {
            foldThread.start();
        }
    } else {
        new CrossValidationFoldThread(0, foldSets, eval).run();
    }

    for (Thread foldThread : foldThreads) {
        while (foldThread.isAlive()) {
            try {
                foldThread.join();
            } catch (InterruptedException ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    PerformanceCounters.stopTimer("cross-validation folds+train");
    PerformanceCounters.startTimer("cross-validation post");
    // evaluation for output:
    String out = String.format(
            "\n=== Setup ===\nClassifier: %s %s\n" + "Dataset: %s\nFolds: %s\nSeed: %s\n\n%s\n",
            abstractClassifier.getClass().getName(), Utils.joinOptions(abstractClassifier.getOptions()),
            trainingSet.relationName(), folds, seed,
            eval.toSummaryString(String.format("=== %s-fold Cross-validation ===", folds), false));

    try {
        crossValidationPearsonsCorrelation = eval.correlationCoefficient();
    } catch (Exception ex) {
        Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
    }
    if (modelOutputFile != null) {
        if (!modelOutputFile.isEmpty()) {
            try {
                SerializationHelper.write(modelOutputFile, abstractClassifier);
            } catch (Exception ex) {
                Logger.getLogger(NLPSystem.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    classifierBuiltWithCrossValidation = true;
    PerformanceCounters.stopTimer("cross-validation post");
    PerformanceCounters.stopTimer("cross-validation");
    return out;
}

From source file:au.edu.usyd.it.yangpy.sampling.BPSO.java

License:Open Source License

/**
 * This method starts the under-sampling procedure.
 */
public void underSampling() {
    // create a copy of original data set for cross validation
    Instances randData = new Instances(dataset);

    // divide the data set into 3 stratified folds
    randData.stratify(3);

    for (int fold = 0; fold < 3; fold++) {
        // use the first 2 folds as the internal training set and the last fold as the internal test set
        internalTrain = randData.trainCV(3, fold);
        internalTest = randData.testCV(3, fold);

        // count the majority-class samples in the internal training set
        majorSize = 0;

        for (int i = 0; i < internalTrain.numInstances(); i++) {
            if (internalTrain.instance(i).classValue() == majorLabel) {
                majorSize++;
            }
        }

        // class variable initialization
        dec = new DecimalFormat("##.####");
        localBest = new double[popSize];
        localBestParticles = new int[popSize][majorSize];
        globalBest = Double.MIN_VALUE;
        globalBestParticle = new int[majorSize];
        velocity = new double[popSize][majorSize];
        particles = new int[popSize][majorSize];
        searchSpace = new double[popSize][majorSize];

        System.out.println("-------------------- parameters ----------------------");
        System.out.println("CV fold = " + fold);
        System.out.println("inertia weight = " + w);
        System.out.println("c1,c2 = " + c1);
        System.out.println("iteration time = " + iteration);
        System.out.println("population size = " + popSize);

        // initialize BPSO
        initialization();

        // perform optimization process
        findMaxFit();

        // save optimization results to array list
        saveResults();
    }

    // rank the selected samples and build the balanced dataset
    try {
        createBalanceData();
    } catch (IOException ioe) {
        ioe.printStackTrace();
    }

}

From source file:au.edu.usyd.it.yangpy.snp.ParallelGenetic.java

License:Open Source License

public void crossValidate() {
    // create a copy of original training set for CV
    Instances randData = new Instances(data);

    // divide the data set into foldSize stratified folds
    randData.stratify(foldSize);

    try {

        cvTrain = randData.trainCV(foldSize, foldIndex);
        cvTest = randData.testCV(foldSize, foldIndex);

        foldIndex++;

        if (foldIndex >= foldSize) {
            foldIndex = 0;
        }

    } catch (Exception e) {
        System.out.println(cvTest.toString());
    }
}