Example usage for weka.core Instances randomize

Introduction

This page collects example usages of the weka.core.Instances method randomize.

Prototype

public void randomize(Random random) 

Document

Shuffles the instances in the set so that they are ordered randomly.
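
A minimal sketch of calling randomize directly is shown below; the file name data.arff and the seed value are placeholders for illustration:

import java.util.Random;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class RandomizeExample {
    public static void main(String[] args) throws Exception {
        // load a dataset (data.arff is a placeholder file name)
        Instances data = DataSource.read("data.arff");

        // seed the generator so the shuffle is reproducible across runs
        Random rand = new Random(42);
        data.randomize(rand);
    }
}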

Usage

From source file: Bilbo.java

License: Open Source License

/**
 * Returns a training set for a particular iteration.
 * @param iteration the number of the iteration for the requested training set.
 * @return the training set for the supplied iteration number
 * @throws Exception if something goes wrong when generating a training set.
 */
@Override
protected synchronized Instances getTrainingSet(Instances p_data, int iteration) throws Exception {
    int bagSize = (int) (p_data.numInstances() * (m_BagSizePercent / 100.0));
    Instances bagData = null;
    Random r = new Random(m_Seed + iteration);

    // create the in-bag dataset
    if (m_CalcOutOfBag && p_data.classIndex() != -1) {
        m_inBag[iteration] = new boolean[p_data.numInstances()];
        bagData = p_data.resampleWithWeights(r, m_inBag[iteration], getRepresentCopiesUsingWeights());
    } else {
        bagData = p_data.resampleWithWeights(r, getRepresentCopiesUsingWeights());
        if (bagSize < p_data.numInstances()) {
            bagData.randomize(r);
            Instances newBagData = new Instances(bagData, 0, bagSize);
            bagData = newBagData;
        }
    }

    return bagData;
}

From source file: BaggingImprove.java

/**
 * Bagging method.
 *
 * @param data the training data to be used for generating the bagged
 * classifier.
 * @throws Exception if the classifier could not be built successfully
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    //data.deleteWithMissingClass();

    super.buildClassifier(data);

    if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
        throw new IllegalArgumentException(
                "Bag size needs to be 100% if " + "out-of-bag error is to be calculated!");
    }
    //+
    System.out.println("Classifier length" + m_Classifiers.length);

    int bagSize = data.numInstances() * m_BagSizePercent / 100;
    //+
    System.out.println("Bag Size " + bagSize);

    Random random = new Random(m_Seed);

    boolean[][] inBag = null;
    if (m_CalcOutOfBag) {
        inBag = new boolean[m_Classifiers.length][];
    }

    //+
    //initialize model naming
    BufferedWriter writer = new BufferedWriter(new FileWriter("Bootstrap.txt"));

    for (int j = 0; j < m_Classifiers.length; j++) {

        Instances bagData = null;

        // create the in-bag dataset
        if (m_CalcOutOfBag) {
            inBag[j] = new boolean[data.numInstances()];

            //System.out.println("Inbag1 " + inBag[0][1]);
            //bagData = resampleWithWeights(data, random, inBag[j]);
            bagData = data.resampleWithWeights(random, inBag[j]);
            //System.out.println("num after resample " + bagData.numInstances());
            //+
            //                for (int k = 0; k < bagData.numInstances(); k++) {
            //                    System.out.println("Bag Data after resample [calc out bag]" + bagData.instance(k));
            //                }

        } else {
            //+
            System.out.println("Not m_Calc out of bag");
            System.out.println("Please configure code inside!");

            bagData = data.resampleWithWeights(random);
            if (bagSize < data.numInstances()) {
                bagData.randomize(random);
                Instances newBagData = new Instances(bagData, 0, bagSize);
                bagData = newBagData;
            }
        }

        if (m_Classifier instanceof Randomizable) {
            //+
            System.out.println("Randomizable");
            ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());
        }

        //write bootstrap into file
        writer.write("Bootstrap " + j);
        writer.newLine();
        writer.write(bagData.toString());
        writer.newLine();

        System.out.println("Berhasil menyimpan bootstrap ke file ");

        System.out.println("Bootstrap " + j + 1);
        //            textarea.append("\nBootsrap " + (j + 1));
        //System.out.println("num instance kedua kali "+bagData.numInstances());

        // start at 0 so the first instance is printed as well
        for (int b = 0; b < bagData.numInstances(); b++) {
            System.out.println("" + bagData.instance(b));
            //                textarea.append("\n" + bagData.instance(b));
        }
        //            //+

        // build the classifier
        m_Classifiers[j].buildClassifier(bagData);
        //            //+
        //            
        //            SerializationHelper serialization = new SerializationHelper();
        //            serialization.write("KnnData"+model+".model", m_Classifiers[j]);
        //            System.out.println("Finish write into model");
        //            model++;
    }

    writer.flush();
    writer.close();
    // calc OOB error?
    if (getCalcOutOfBag()) {
        double outOfBagCount = 0.0;
        double errorSum = 0.0;
        boolean numeric = data.classAttribute().isNumeric();

        for (int i = 0; i < data.numInstances(); i++) {
            double vote;
            double[] votes;
            if (numeric) {
                votes = new double[1];
            } else {
                votes = new double[data.numClasses()];
            }

            // determine predictions for instance
            int voteCount = 0;
            for (int j = 0; j < m_Classifiers.length; j++) {
                if (inBag[j][i]) {
                    continue;
                }
                voteCount++;
                // double pred = m_Classifiers[j].classifyInstance(data.instance(i));
                if (numeric) {
                    // votes[0] += pred;
                    // accumulate predictions so they can be averaged below
                    votes[0] += m_Classifiers[j].classifyInstance(data.instance(i));
                } else {
                    // votes[(int) pred]++;
                    double[] newProbs = m_Classifiers[j].distributionForInstance(data.instance(i));
                    //-
                    //                        for(double a : newProbs)
                    //                        {
                    //                            System.out.println("Double new probs %.f "+a);
                    //                        }
                    // average the probability estimates
                    for (int k = 0; k < newProbs.length; k++) {
                        votes[k] += newProbs[k];
                    }

                }
            }
            System.out.println("Vote count %d" + voteCount);

            // "vote"
            if (numeric) {
                vote = votes[0];
                if (voteCount > 0) {
                    vote /= voteCount; // average
                }
            } else {
                if (!Utils.eq(Utils.sum(votes), 0)) {
                    Utils.normalize(votes);
                }
                vote = Utils.maxIndex(votes); // predicted class
                //-
                System.out.println("Vote " + vote);

            }

            // error for instance
            outOfBagCount += data.instance(i).weight();
            if (numeric) {
                errorSum += StrictMath.abs(vote - data.instance(i).classValue()) * data.instance(i).weight();
            } else if (vote != data.instance(i).classValue()) {
                //+
                System.out.println("Vote terakhir" + data.instance(i).classValue());
                errorSum += data.instance(i).weight();
            }
        }

        m_OutOfBagError = errorSum / outOfBagCount;
    } else {
        m_OutOfBagError = 0;
    }
}

From source file: CrossValidationMultipleRuns.java

License: Open Source License

/**
 * Performs the cross-validation. See Javadoc of class for information
 * on command-line parameters.
 *
 * @param args   the command-line parameters
 * @throws Exception   if something goes wrong
 */
public static void main(String[] args) throws Exception {
    // loads data and set class index
    Instances data = DataSource.read(Utils.getOption("t", args));
    String clsIndex = Utils.getOption("c", args);
    if (clsIndex.length() == 0)
        clsIndex = "last";
    if (clsIndex.equals("first"))
        data.setClassIndex(0);
    else if (clsIndex.equals("last"))
        data.setClassIndex(data.numAttributes() - 1);
    else
        data.setClassIndex(Integer.parseInt(clsIndex) - 1);

    // classifier
    String[] tmpOptions;
    String classname;
    tmpOptions = Utils.splitOptions(Utils.getOption("W", args));
    classname = tmpOptions[0];
    tmpOptions[0] = "";
    Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);

    // other options
    int runs = Integer.parseInt(Utils.getOption("r", args));
    int folds = Integer.parseInt(Utils.getOption("x", args));

    // perform cross-validation
    for (int i = 0; i < runs; i++) {
        // randomize data
        int seed = i + 1;
        Random rand = new Random(seed);
        Instances randData = new Instances(data);
        randData.randomize(rand);
        //if (randData.classAttribute().isNominal())
        //   randData.stratify(folds);

        Evaluation eval = new Evaluation(randData);

        StringBuilder optionsString = new StringBuilder();
        for (String s : cls.getOptions()) {
            optionsString.append(s);
            optionsString.append(" ");
        }

        // output evaluation
        System.out.println();
        System.out.println("=== Setup run " + (i + 1) + " ===");
        System.out.println("Classifier: " + optionsString.toString());
        System.out.println("Dataset: " + data.relationName());
        System.out.println("Folds: " + folds);
        System.out.println("Seed: " + seed);
        System.out.println();

        for (int n = 0; n < folds; n++) {
            Instances train = randData.trainCV(folds, n);
            Instances test = randData.testCV(folds, n);

            // build and evaluate classifier
            Classifier clsCopy = Classifier.makeCopy(cls);
            clsCopy.buildClassifier(train);
            eval.evaluateModel(clsCopy, test);
            System.out.println(eval.toClassDetailsString());
        }

        System.out.println(
                eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i + 1) + " ===", false));
    }
}
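
A hypothetical way to drive the main method above programmatically; the dataset name and classifier class are placeholders, and CrossValidationMultipleRuns is assumed to be on the classpath:

public class RunCV {
    public static void main(String[] args) throws Exception {
        // equivalent to: -t data.arff -c last -W weka.classifiers.trees.J48 -r 10 -x 10
        CrossValidationMultipleRuns.main(new String[] { "-t", "data.arff", "-c", "last",
                "-W", "weka.classifiers.trees.J48", "-r", "10", "-x", "10" });
    }
}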

From source file: REPTree.java

License: Open Source License

/**
 * Builds classifier.
 * 
 * @param data the data to train with
 * @throws Exception if building fails
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
        train = data.trainCV(m_NumFolds, 0, random);
        prune = data.testCV(m_NumFolds, 0);
    } else {
        train = data;
    }

    // Create array of sorted indices and weights
    int[][][] sortedIndices = new int[1][train.numAttributes()][0];
    double[][][] weights = new double[1][train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
        if (j != train.classIndex()) {
            weights[0][j] = new double[train.numInstances()];
            if (train.attribute(j).isNominal()) {

                // Handling nominal attributes. Putting indices of
                // instances with missing values at the end.
                sortedIndices[0][j] = new int[train.numInstances()];
                int count = 0;
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (!inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
            } else {

                // Sorted indices are computed for numeric attributes
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    vals[i] = inst.value(j);
                }
                sortedIndices[0][j] = Utils.sort(vals);
                for (int i = 0; i < train.numInstances(); i++) {
                    weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
                }
            }
        }
    }

    // Compute initial class counts
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        if (data.classAttribute().isNominal()) {
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        } else {
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
        classProbs[0] /= totalWeight;
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum,
            m_MinVarianceProp * trainVariance, 0, m_MaxDepth);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);
        m_Tree.reducedErrorPrune();
        m_Tree.backfitHoldOutSet();
    }
}

From source file: Pair.java

License: Open Source License

/**
 * Boosting method. Boosts any classifier that can handle weighted
 * instances.
 *
 * @param data the training data to be used for generating the
 * boosted classifier.
 * @exception Exception if the classifier could not be built successfully
 */
protected void buildClassifierWithWeights(Instances data) throws Exception {

    Random randomInstance = new Random(0);
    double epsilon, reweight, beta = 0;
    Evaluation evaluation;
    Instances sample;
    // Initialize data
    m_Betas = new double[m_Classifiers.length];
    m_NumIterationsPerformed = 0;
    int numSourceInstances = m_SourceInstances.numInstances();

    // Do bootstrap iterations
    for (m_NumIterationsPerformed = 0; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++) {
        // Build the classifier
        sample = null;
        if (!resample || m_NumIterationsPerformed == 0) {
            sample = data;
        } else {
            double sum = data.sumOfWeights();
            double[] weights = new double[data.numInstances()];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = data.instance(i).weight() / sum;
            }
            sample = data.resampleWithWeights(randomInstance, weights);

            if (doSampleSize) {
                int effectiveInstances = (int) (sourceFraction * weights.length + numTargetInstances);
                if (effectiveInstances > numSourceInstances + numTargetInstances)
                    effectiveInstances = numSourceInstances + numTargetInstances;
                //System.out.println(effectiveInstances);               
                sample.randomize(randomInstance);
                Instances q = new Instances(sample, 0, effectiveInstances);
                sample = q;
            }
        }
        try {
            m_Classifiers[m_NumIterationsPerformed].buildClassifier(sample);
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("E: " + e);
        }

        if (doBagging)
            beta = 0.4 / .6; //always same beta
        else
            beta = setWeights(data, m_Classifiers[m_NumIterationsPerformed], -1, numSourceInstances, true);

        // Stop if error too small or error too big and ignore this model
        if (beta < 0) { //setWeights indicates a problem with negative beta
            if (m_NumIterationsPerformed == 0) {
                m_NumIterationsPerformed = 1; // If we're the first, we have to use it
            }
            break;
        }

        // Determine the weight to assign to this model

        m_Betas[m_NumIterationsPerformed] = Math.log(1 / beta);

    }

    betaSum = 0;

    for (int i = 0; i < m_NumIterationsPerformed; i++)
        betaSum += m_Betas[i];
}

From source file: CopiaSeg3.java

public static Instances[] split(Instances data, int numberOfFolds) {
    Instances[] split = new Instances[2];

    Random semilla = new Random();
    int seed = semilla.nextInt(20); // Generates a random seed between 0 and 19
    Random rand = new Random(seed); // Create seeded number generator
    Instances randData = new Instances(data); // Creates a copy of the original data
    randData.randomize(rand); // Shuffles the data into random order

    split[0] = randData.trainCV(numberOfFolds, 0);
    split[1] = randData.testCV(numberOfFolds, 0);

    return split;
}
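
A hypothetical call to the helper above, assuming an Instances object named data has already been loaded:

Instances[] parts = CopiaSeg3.split(data, 10); // 10 folds: roughly 90% train, 10% test
Instances train = parts[0];
Instances test = parts[1];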

From source file: REPRandomTree.java

License: Open Source License

/**
 * Builds classifier.
 * 
 * @param data the data to train with
 * @throws Exception if building fails
 */
public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(data);
        return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
        data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
        train = data.trainCV(m_NumFolds, 0, random);
        prune = data.testCV(m_NumFolds, 0);
    } else {
        train = data;
    }

    // Create array of sorted indices and weights
    int[][][] sortedIndices = new int[1][train.numAttributes()][0];
    double[][][] weights = new double[1][train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
        if (j != train.classIndex()) {
            weights[0][j] = new double[train.numInstances()];
            if (train.attribute(j).isNominal()) {

                // Handling nominal attributes. Putting indices of
                // instances with missing values at the end.
                sortedIndices[0][j] = new int[train.numInstances()];
                int count = 0;
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (!inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    if (inst.isMissing(j)) {
                        sortedIndices[0][j][count] = i;
                        weights[0][j][count] = inst.weight();
                        count++;
                    }
                }
            } else {

                // Sorted indices are computed for numeric attributes
                for (int i = 0; i < train.numInstances(); i++) {
                    Instance inst = train.instance(i);
                    vals[i] = inst.value(j);
                }
                sortedIndices[0][j] = Utils.sort(vals);
                for (int i = 0; i < train.numInstances(); i++) {
                    weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
                }
            }
        }
    }

    // Compute initial class counts
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
        Instance inst = train.instance(i);
        if (data.classAttribute().isNominal()) {
            classProbs[(int) inst.classValue()] += inst.weight();
            totalWeight += inst.weight();
        } else {
            classProbs[0] += inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
        trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
        classProbs[0] /= totalWeight;
    }

    // Build tree
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum,
            m_MinVarianceProp * trainVariance, 0, m_MaxDepth, m_FeatureFrac, random);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
        m_Tree.insertHoldOutSet(prune);
        m_Tree.reducedErrorPrune();
        m_Tree.backfitHoldOutSet();
    }
}

From source file: SMO.java

License: Open Source License

/**
 * Method for building the classifier. Implements a one-against-one
 * wrapper for multi-class problems.
 *
 * @param insts the set of training instances
 * @throws Exception if the classifier can't be built successfully
 */
public void buildClassifier(Instances insts) throws Exception {

    if (!m_checksTurnedOff) {
        // can classifier handle the data?
        getCapabilities().testWithFail(insts);

        // remove instances with missing class
        insts = new Instances(insts);
        insts.deleteWithMissingClass();

        /* Removes all the instances with weight equal to 0.
         MUST be done since condition (8) of Keerthi's paper
         is made with the assertion Ci > 0 (see equation (3a)). */
        Instances data = new Instances(insts, insts.numInstances());
        for (int i = 0; i < insts.numInstances(); i++) {
            if (insts.instance(i).weight() > 0)
                data.add(insts.instance(i));
        }
        if (data.numInstances() == 0) {
            throw new Exception("No training instances left after removing " + "instances with weight 0!");
        }
        insts = data;
    }

    if (!m_checksTurnedOff) {
        m_Missing = new ReplaceMissingValues();
        m_Missing.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Missing);
    } else {
        m_Missing = null;
    }

    if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) {
        boolean onlyNumeric = true;
        if (!m_checksTurnedOff) {
            for (int i = 0; i < insts.numAttributes(); i++) {
                if (i != insts.classIndex()) {
                    if (!insts.attribute(i).isNumeric()) {
                        onlyNumeric = false;
                        break;
                    }
                }
            }
        }

        if (!onlyNumeric) {
            m_NominalToBinary = new NominalToBinary();
            m_NominalToBinary.setInputFormat(insts);
            insts = Filter.useFilter(insts, m_NominalToBinary);
        } else {
            m_NominalToBinary = null;
        }
    } else {
        m_NominalToBinary = null;
    }

    if (m_filterType == FILTER_STANDARDIZE) {
        m_Filter = new Standardize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    } else if (m_filterType == FILTER_NORMALIZE) {
        m_Filter = new Normalize();
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    } else {
        m_Filter = null;
    }

    m_classIndex = insts.classIndex();
    m_classAttribute = insts.classAttribute();
    m_KernelIsLinear = (m_kernel instanceof PolyKernel) && (((PolyKernel) m_kernel).getExponent() == 1.0);

    // Generate subsets representing each class
    Instances[] subsets = new Instances[insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
        subsets[i] = new Instances(insts, insts.numInstances());
    }
    for (int j = 0; j < insts.numInstances(); j++) {
        Instance inst = insts.instance(j);
        subsets[(int) inst.classValue()].add(inst);
    }
    for (int i = 0; i < insts.numClasses(); i++) {
        subsets[i].compactify();
    }

    // Build the binary classifiers
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
        for (int j = i + 1; j < insts.numClasses(); j++) {
            m_classifiers[i][j] = new BinarySMO();
            m_classifiers[i][j].setKernel(Kernel.makeCopy(getKernel()));
            Instances data = new Instances(insts, insts.numInstances());
            for (int k = 0; k < subsets[i].numInstances(); k++) {
                data.add(subsets[i].instance(k));
            }
            for (int k = 0; k < subsets[j].numInstances(); k++) {
                data.add(subsets[j].instance(k));
            }
            data.compactify();
            data.randomize(rand);
            m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels, m_numFolds, m_randomSeed);
        }
    }
}

From source file: WekaClassify.java

License: Open Source License

public void SplitRun(Instances anInstance, double split) throws Exception {

    anInstance.randomize(new java.util.Random(0));

    int trainSize = (int) Math.round(anInstance.numInstances() * split);
    int testSize = anInstance.numInstances() - trainSize;
    Instances train = new Instances(anInstance, 0, trainSize);
    Instances test = new Instances(anInstance, trainSize, testSize);

    NumberFormat percentFormat = NumberFormat.getPercentInstance();
    percentFormat.setMaximumFractionDigits(1);
    String result = percentFormat.format(split);

    System.out.println("Evaluating with : Split percentage : " + result);
    // build on the training split and evaluate on the held-out test split
    m_Classifier.buildClassifier(train);
    m_Evaluation.evaluateModel(m_Classifier, test);
}

From source file: adams.flow.transformer.WekaAttributeSelection.java

License: Open Source License

/**
 * Executes the flow item.
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    String result;
    Instances data;
    Instances reduced;
    Instances transformed;
    AttributeSelection eval;
    boolean crossValidate;
    int fold;
    Instances train;
    WekaAttributeSelectionContainer cont;
    SpreadSheet stats;
    int i;
    Row row;
    int[] selected;
    double[][] ranked;
    Range range;
    String rangeStr;
    boolean useReduced;

    result = null;

    try {
        if (m_InputToken.getPayload() instanceof Instances)
            data = (Instances) m_InputToken.getPayload();
        else
            data = (Instances) ((WekaTrainTestSetContainer) m_InputToken.getPayload())
                    .getValue(WekaTrainTestSetContainer.VALUE_TRAIN);

        if (result == null) {
            crossValidate = (m_Folds >= 2);

            // setup evaluation
            eval = new AttributeSelection();
            eval.setEvaluator(m_Evaluator);
            eval.setSearch(m_Search);
            eval.setFolds(m_Folds);
            eval.setSeed((int) m_Seed);
            eval.setXval(crossValidate);

            // select attributes
            if (crossValidate) {
                Random random = new Random(m_Seed);
                data = new Instances(data);
                data.randomize(random);
                if ((data.classIndex() > -1) && data.classAttribute().isNominal()) {
                    if (isLoggingEnabled())
                        getLogger().info("Stratifying instances...");
                    data.stratify(m_Folds);
                }
                for (fold = 0; fold < m_Folds; fold++) {
                    if (isLoggingEnabled())
                        getLogger().info("Creating splits for fold " + (fold + 1) + "...");
                    train = data.trainCV(m_Folds, fold, random);
                    if (isLoggingEnabled())
                        getLogger().info("Selecting attributes using all but fold " + (fold + 1) + "...");
                    eval.selectAttributesCVSplit(train);
                }
            } else {
                eval.SelectAttributes(data);
            }

            // generate reduced/transformed dataset
            reduced = null;
            transformed = null;
            if (!crossValidate) {
                reduced = eval.reduceDimensionality(data);
                if (m_Evaluator instanceof AttributeTransformer)
                    transformed = ((AttributeTransformer) m_Evaluator).transformedData(data);
            }

            // generated stats
            stats = null;
            if (!crossValidate) {
                stats = new DefaultSpreadSheet();
                row = stats.getHeaderRow();

                useReduced = false;
                if (m_Search instanceof RankedOutputSearch) {
                    i = reduced.numAttributes();
                    if (reduced.classIndex() > -1)
                        i--;
                    ranked = eval.rankedAttributes();
                    useReduced = (ranked.length == i);
                }

                if (useReduced) {
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell("" + i).setContent(reduced.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < reduced.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                } else {
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell("" + i).setContent(data.attribute(i).name());
                    row = stats.addRow();
                    for (i = 0; i < data.numAttributes(); i++)
                        row.addCell(i).setContent(0.0);
                }

                if (m_Search instanceof RankedOutputSearch) {
                    ranked = eval.rankedAttributes();
                    for (i = 0; i < ranked.length; i++)
                        row.getCell((int) ranked[i][0]).setContent(ranked[i][1]);
                } else {
                    selected = eval.selectedAttributes();
                    for (i = 0; i < selected.length; i++)
                        row.getCell(selected[i]).setContent(1.0);
                }
            }

            // selected attributes
            rangeStr = null;
            if (!crossValidate) {
                range = new Range();
                range.setIndices(eval.selectedAttributes());
                rangeStr = range.getRange();
            }

            // setup container
            if (crossValidate)
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, m_Seed, m_Folds);
            else
                cont = new WekaAttributeSelectionContainer(data, reduced, transformed, eval, stats, rangeStr);
            m_OutputToken = new Token(cont);
        }
    } catch (Exception e) {
        m_OutputToken = null;
        result = handleException("Failed to process data:", e);
    }

    return result;
}