Example usage for weka.core.Instances.numInstances()

List of usage examples for weka.core.Instances.numInstances()

Introduction

On this page you can find example usages of weka.core.Instances.numInstances().

Prototype


public int numInstances()

Source Link

Document

Returns the number of instances in the dataset.

Usage

From source file:iris.Network.java

/**
 * Trains the network on the given data set.
 *
 * <p>The network is built with one input per attribute except the last one
 * (4 inputs for the Iris data set). Instances are then processed one at a time:
 * for each instance the non-class attribute values are fed forward and the error
 * is back-propagated {@code numIterations} times before moving on to the next
 * instance.
 *
 * <p>NOTE(review): assumes the class attribute is the last attribute of
 * {@code trainingSet}; the class index itself is not consulted here.
 *
 * @param trainingSet the training instances; also stored in {@code trainingData}
 * @throws Exception if building or training the network fails
 */
@Override
public void buildClassifier(Instances trainingSet) throws Exception {
    trainingData = trainingSet;

    // One network input per non-class attribute, i.e. 4 for the Iris set.
    setInputCount(trainingSet.numAttributes() - 1);

    buildNetwork(inputCount, neuronsInEachLayer);

    // Reusable buffer with the input (non-class) values of the current instance.
    List<Double> values = new ArrayList<>();

    for (int i = 0; i < trainingSet.numInstances(); i++) {
        // Collect every attribute value except the last one (the label).
        for (int j = 0; j < trainingSet.instance(i).numAttributes() - 1; j++) {
            values.add(trainingSet.instance(i).value(j));
        }
        // Repeatedly train on this single instance before moving on.
        for (int k = 0; k < numIterations; k++) {
            getOutputs(values);
            backPropogate(trainingSet.instance(i));
        }
        values.clear(); // reset the buffer for the next instance
    }
}

From source file:irisdriver.IrisDriver.java

/**
 * Interactive driver for the Iris model.
 *
 * <p>Asks the user, via two file choosers, for a saved model and a test data set.
 * Every instance of the test set is classified with {@code Iris.classifySpecies},
 * and each prediction is printed to standard output. One line of the form
 * {@code "The label is: <label>"} per instance is written to the hard-coded file
 * {@code D:\output_file.txt}. If the test set has no class index, it is set to
 * the last attribute. The last attribute is never passed to the model.
 *
 * <p>Nothing is classified if either dialog is cancelled. Any exception is logged
 * at {@code SEVERE} level.
 *
 * <p>Example of the attribute/value pairs used for one instance:
 * {@code sepallength=5.1 sepalwidth=3.5 petallength=1.4 petalwidth=0.2}.
 *
 * @param args the command line arguments (currently unused)
 */
public static void main(String[] args) {
    try {
        // attribute name -> attribute value (as text) for the instance being classified
        Hashtable<String, String> values = new Hashtable<String, String>();

        // Choose and load the model.
        JFileChooser chooserModel = new JFileChooser();
        chooserModel.setCurrentDirectory(new java.io.File("."));
        chooserModel.setDialogTitle("Choose the model");
        chooserModel.setFileSelectionMode(JFileChooser.FILES_AND_DIRECTORIES);
        chooserModel.setAcceptAllFileFilterUsed(true);
        if (chooserModel.showOpenDialog(null) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        String pathModel = chooserModel.getSelectedFile().getPath();
        Iris irisModel = new Iris(pathModel);

        // Choose the test data set.
        JFileChooser chooserTestSet = new JFileChooser();
        chooserTestSet.setDialogTitle("Choose TEST SET");
        chooserTestSet.setFileSelectionMode(JFileChooser.FILES_AND_DIRECTORIES);
        chooserTestSet.setAcceptAllFileFilterUsed(true);
        if (chooserTestSet.showOpenDialog(null) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        String pathTestSet = chooserTestSet.getSelectedFile().getPath();

        // NOTE(review): the output path is hard-coded to a Windows drive.
        // try-with-resources closes the writer even if loading or classification fails.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter("D:\\output_file.txt"))) {
            // Load the test set and make sure a class attribute is defined.
            ConverterUtils.DataSource unlabeledSource = new ConverterUtils.DataSource(pathTestSet);
            Instances unlabeledData = unlabeledSource.getDataSet();
            if (unlabeledData.classIndex() == -1) {
                unlabeledData.setClassIndex(unlabeledData.numAttributes() - 1);
            }

            for (int i = 0; i < unlabeledData.numInstances(); i++) {
                Instance ins = unlabeledData.instance(i);

                // Transform the instance into attribute/value pairs.
                // ins.numAttributes() - 1 --> do not include the label (last attribute).
                for (int j = 0; j < ins.numAttributes() - 1; j++) {
                    String attrib = ins.attribute(j).name();
                    double val = ins.value(ins.attribute(j));
                    values.put(attrib, String.valueOf(val));
                }

                String predictedLabel = irisModel.classifySpecies(values);
                System.out.println("Classification: " + predictedLabel);
                values.clear();

                // One prediction per line in the output file.
                writer.write("The label is: " + predictedLabel);
                writer.newLine();
            }
        }
    } catch (Exception ex) {
        Logger.getLogger(IrisDriver.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:j48.BinC45Split.java

License:Open Source License

/**
 * Rebuilds the distribution associated with this split model.
 *
 * <p>A new {@code Distribution} is created from the instances of {@code data}
 * that {@code whichSubset} assigns to a concrete subset (index > -1). Then
 * {@code addInstWithUnknown(data, m_attIndex)} is applied to it. Presumably that
 * call distributes the instances whose split attribute is missing; confirm
 * against {@code Distribution}. The result replaces {@code m_distribution}.
 *
 * @param data the instances to rebuild the distribution from
 * @throws Exception if building the distribution fails
 */
public void resetDistribution(Instances data) throws Exception {
    final int count = data.numInstances();

    // Keep only the instances that fall into a definite subset.
    Instances assigned = new Instances(data, count);
    for (int idx = 0; idx < count; idx++) {
        boolean hasSubset = whichSubset(data.instance(idx)) > -1;
        if (hasSubset) {
            assigned.add(data.instance(idx));
        }
    }

    Distribution rebuilt = new Distribution(assigned, this);
    rebuilt.addInstWithUnknown(data, m_attIndex);
    m_distribution = rebuilt;
}

From source file:j48.C45PruneableClassifierTreeG.java

License:Open Source License

/**
 * Initializes variables for grafting./*w  w w.  j  ava2s  .c  om*/
 * sets up limits array (for numeric attributes) and calls 
 * the recursive function traverseTree.
 *
 * @param data the data for the tree
 * @throws Exception if anything goes wrong
 */
public void doGrafting(Instances data) throws Exception {

    // 2d array for the limits
    double[][] limits = new double[data.numAttributes()][2];
    // 2nd dimension: index 0 == lower limit, index 1 == upper limit
    // initialise to no limit
    for (int i = 0; i < data.numAttributes(); i++) {
        limits[i][0] = Double.NEGATIVE_INFINITY;
        limits[i][1] = Double.POSITIVE_INFINITY;
    }

    // use an index instead of creating new Insances objects all the time
    // instanceIndex[0] == array for weights at leaf
    // instanceIndex[1] == array for weights in atbop
    double[][] instanceIndex = new double[2][data.numInstances()];
    // initialize the weight for each instance
    for (int x = 0; x < data.numInstances(); x++) {
        instanceIndex[0][x] = 1;
        instanceIndex[1][x] = 1; // leaf instances are in atbop
    }

    // first call to graft
    traverseTree(data, instanceIndex, limits, this, 0, -1);
}

From source file:j48.C45PruneableClassifierTreeG.java

License:Open Source License

/**
 * recursive function./*from  w  w  w  . j a v  a2s .com*/
 * if this node is a leaf then calls findGraft, otherwise sorts 
 * the two sets of instances (tracked in iindex array) and calls
 * sortInstances for each of the child nodes (which then calls
 * this method).
 *
 * @param fulldata all instances
 * @param iindex array the tracks the weight of each instance in
 *        the atbop and at the leaf (0.0 if not present)
 * @param limits array specifying current upper/lower limits for numeric atts
 * @param parent the node immediately before the current one
 * @param pL laplace for node, as calculated by parent (in case leaf is empty)
 * @param nodeClass class of node, determined by parent (in case leaf empty)
 */
private void traverseTree(Instances fulldata, double[][] iindex, double[][] limits,
        C45PruneableClassifierTreeG parent, double pL, int nodeClass) throws Exception {

    if (m_isLeaf) {

        findGraft(fulldata, iindex, limits, (ClassifierTree) parent, pL, nodeClass);

    } else {

        // traverse each branch
        for (int i = 0; i < localModel().numSubsets(); i++) {

            double[][] newiindex = new double[2][fulldata.numInstances()];
            for (int x = 0; x < 2; x++)
                System.arraycopy(iindex[x], 0, newiindex[x], 0, iindex[x].length);
            sortInstances(fulldata, newiindex, limits, i);
        }
    }
}

From source file:j48.C45PruneableClassifierTreeG.java

License:Open Source License

/**
 * Finds new nodes that improve accuracy and grafts them onto the tree.
 *
 * <p>Outline of what the code below does:
 * <ol>
 * <li>Splits {@code fulldata} into two local sets:
 *     <ul>
 *     <li>{@code l}: instances with a non-zero leaf weight, {@code iindex[0]}.</li>
 *     <li>{@code n}: instances with a positive atbop weight, {@code iindex[1]}.</li>
 *     </ul>
 *     Each copy carries the corresponding weight. Both rows of {@code iindex} are
 *     compacted in place, so entry k afterwards holds the weight of the k-th
 *     instance of {@code l} (row 0) or {@code n} (row 1).</li>
 * <li>Returns early unless some class other than the leaf class meets both
 *     conditions: it has a higher Laplace estimate in {@code n} than
 *     {@code leafLaplace}, and it passes the {@code biprob(...) > m_BiProbCrit}
 *     test.</li>
 * <li>For every non-class attribute, collects candidate {@code GraftSplit}s
 *     into {@code t}:
 *     <ul>
 *     <li>a "&lt;=" threshold (type 0) and a "&gt;" threshold (type 1) for
 *         numeric attributes;</li>
 *     <li>"=" tests (type 2) for nominal attributes not yet used on the path.</li>
 *     </ul></li>
 * <li>Sorts {@code t} and removes entries from its front while their
 *     {@code maxClassForSubsetOfInterest()} equals the leaf class. The remaining
 *     grafts are then built on {@code l}, last to first. Finally the list is
 *     handed to {@code parent.setDescendents(t, this)}.</li>
 * </ol>
 *
 * @param fulldata the instances in whole trainset
 * @param iindex per-instance weights: row 0 = weight at this leaf, row 1 =
 *        weight in the atbop (0 if absent). Both rows are overwritten in place
 *        by the compaction described above.
 * @param limits the upper/lower limits for numeric attributes; for a nominal
 *        attribute, {@code limits[a][1] == 1} marks it as already used
 * @param parent the node immediately before the current one
 * @param pLaplace laplace for leaf, calculated by parent (in case leaf empty)
 * @param pLeafClass class of leaf, determined by parent (in case leaf empty)
 * @throws Exception if a called method (e.g. {@code setDescendents}) fails
 */
private void findGraft(Instances fulldata, double[][] iindex, double[][] limits, ClassifierTree parent,
        double pLaplace, int pLeafClass) throws Exception {

    // get the class for this leaf
    int leafClass = (m_isEmpty) ? pLeafClass : localModel().distribution().maxClass();

    // get the laplace value for this leaf
    double leafLaplace = (m_isEmpty) ? pLaplace : laplaceLeaf(leafClass);

    // sort the instances into those at the leaf, those in atbop, and discarded
    // (l = instances at the leaf, n = instances in the atbop)
    Instances l = new Instances(fulldata, fulldata.numInstances());
    Instances n = new Instances(fulldata, fulldata.numInstances());
    int lcount = 0;
    int acount = 0;
    for (int x = 0; x < fulldata.numInstances(); x++) {
        // discard instances with no positive weight in either set
        if (iindex[0][x] <= 0 && iindex[1][x] <= 0)
            continue;
        // NOTE(review): leaf membership tests "!= 0" but atbop membership tests "> 0";
        // they differ only for negative weights - confirm which is intended.
        if (iindex[0][x] != 0) {
            l.add(fulldata.instance(x));
            l.instance(lcount).setWeight(iindex[0][x]);
            // move instance's weight in iindex to same index as in l
            // (safe in place: lcount <= x, so entries not yet read are never overwritten)
            iindex[0][lcount++] = iindex[0][x];
        }
        if (iindex[1][x] > 0) {
            n.add(fulldata.instance(x));
            n.instance(acount).setWeight(iindex[1][x]);
            // move instance's weight in iindex to same index as in n
            iindex[1][acount++] = iindex[1][x];
        }
    }

    // weighted class distribution of the atbop instances that have a known class
    boolean graftPossible = false;
    double[] classDist = new double[n.numClasses()];
    for (int x = 0; x < n.numInstances(); x++) {
        if (iindex[1][x] > 0 && !n.instance(x).classIsMissing())
            classDist[(int) n.instance(x).classValue()] += iindex[1][x];
    }

    // a graft is only possible if some other class beats the leaf's Laplace value
    // and passes the biprob(...) > m_BiProbCrit test
    for (int cVal = 0; cVal < n.numClasses(); cVal++) {
        double theLaplace = (classDist[cVal] + 1.0) / (classDist[cVal] + 2.0);
        if (cVal != leafClass && (theLaplace > leafLaplace)
                && (biprob(classDist[cVal], classDist[cVal], leafLaplace) > m_BiProbCrit)) {
            graftPossible = true;
            break;
        }
    }

    if (!graftPossible) {
        return;
    }

    // 1. Initialize to {} a set of tuples t containing potential tests
    ArrayList t = new ArrayList();

    // go through each attribute
    for (int a = 0; a < n.numAttributes(); a++) {
        if (a == n.classIndex())
            continue; // skip the class

        // sort instances in atbop by $a
        int[] sorted = sortByAttribute(n, a);

        // 2. For each continuous attribute $a:
        if (n.attribute(a).isNumeric()) {

            // find min and max values for this attribute at the leaf
            // (a leaf instance of the leaf class with a missing value prohibits
            // numeric grafts on this attribute)
            boolean prohibited = false;
            double minLeaf = Double.POSITIVE_INFINITY;
            double maxLeaf = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < l.numInstances(); i++) {
                if (l.instance(i).isMissing(a)) {
                    if (l.instance(i).classValue() == leafClass) {
                        prohibited = true;
                        break;
                    }
                }
                // NOTE(review): missing-valued instances of other classes fall through
                // here; their value is presumably NaN (not shown in this file), which
                // never updates minLeaf/maxLeaf.
                double value = l.instance(i).value(a);
                if (!m_relabel || l.instance(i).classValue() == leafClass) {
                    if (value < minLeaf)
                        minLeaf = value;
                    if (value > maxLeaf)
                        maxLeaf = value;
                }
            }
            if (prohibited) {
                continue;
            }

            // (a) find values of
            //    $n: instances in atbop (already have that, actually)
            //    $v: a value for $a that exists for a case in the atbop, where
            //       $v is < the min value for $a for a case at the leaf which
            //       has the class $c, and $v is > the lowerlimit of $a at
            //       the leaf.
            //       (note: error in original paper stated that $v must be
            //       smaller OR EQUAL TO the min value).
            //    $k: $k is a class
            //  that maximize L' = Laplace({$x: $x contained in cases($n)
            //    & value($a,$x) <= $v & value($a,$x) > lowerlim($l,$a)}, $k).
            // NOTE(review): minBestClass is assigned below but never read in this method.
            double minBestClass = Double.NaN;
            double minBestLaplace = leafLaplace;
            double minBestVal = Double.NaN;
            double minBestPos = Double.NaN;
            double minBestTotal = Double.NaN;
            double[][] minBestCounts = null;
            double[][] counts = new double[2][n.numClasses()];
            for (int x = 0; x < n.numInstances(); x++) {
                if (n.instance(sorted[x]).isMissing(a))
                    break; // missing are sorted to end: no more valid vals

                double theval = n.instance(sorted[x]).value(a);
                if (m_Debug)
                    System.out.println("\t " + theval);

                if (theval <= limits[a][0]) {
                    if (m_Debug)
                        System.out.println("\t  <= lowerlim: continuing...");
                    continue;
                }
                // note: error in paper would have this read "theVal > minLeaf)
                if (theval >= minLeaf) {
                    if (m_Debug)
                        System.out.println("\t  >= minLeaf; breaking...");
                    break;
                }
                // NOTE(review): unlike the classDist loop above, instances with a
                // missing class are not skipped when accumulating counts here.
                counts[0][(int) n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]];

                // absorb following instances with the same value, so thresholds
                // only fall between distinct values
                if (x != n.numInstances() - 1) {
                    int z = x + 1;
                    while (z < n.numInstances() && n.instance(sorted[z]).value(a) == theval) {
                        z++;
                        x++;
                        counts[0][(int) n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]];
                    }
                }

                // work out the best laplace/class (for <= theval)
                double total = Utils.sum(counts[0]);
                for (int c = 0; c < n.numClasses(); c++) {
                    double temp = (counts[0][c] + 1.0) / (total + 2.0);
                    if (temp > minBestLaplace) {
                        minBestPos = counts[0][c];
                        minBestTotal = total;
                        minBestLaplace = temp;
                        minBestClass = c;
                        minBestCounts = copyCounts(counts);

                        // threshold: midpoint between theval and the next sorted value
                        minBestVal = (x == n.numInstances() - 1) ? theval
                                : ((theval + n.instance(sorted[x + 1]).value(a)) / 2.0);
                    }
                }
            }

            // (b) add to t tuple <n,a,v,k,L',"<=">
            if (!Double.isNaN(minBestVal) && biprob(minBestPos, minBestTotal, leafLaplace) > m_BiProbCrit) {
                GraftSplit gsplit = null;
                try {
                    gsplit = new GraftSplit(a, minBestVal, 0, leafClass, minBestCounts);
                } catch (Exception e) {
                    // NOTE(review): System.exit terminates the whole JVM; this method
                    // already declares "throws Exception", so rethrowing would be gentler.
                    System.err.println("graftsplit error: " + e.getMessage());
                    System.exit(1);
                }
                t.add(gsplit);
            }
            // free space
            minBestCounts = null;

            // (c) find values of
            //    n: instances in atbop (already have that, actually)
            //    $v: a value for $a that exists for a case in the atbop, where
            //       $v is > the max value for $a for a case at the leaf which
            //       has the class $c, and $v is <= the upperlimit of $a at
            //       the leaf.
            //    k: k is a class
            //   that maximize L' = Laplace({x: x contained in cases(n)
            //       & value(a,x) > v & value(a,x) <= upperlim(l,a)}, k).
            // NOTE(review): maxBestClass is assigned below but never read in this method.
            double maxBestClass = -1;
            double maxBestLaplace = leafLaplace;
            double maxBestVal = Double.NaN;
            double maxBestPos = Double.NaN;
            double maxBestTotal = Double.NaN;
            double[][] maxBestCounts = null;
            for (int c = 0; c < n.numClasses(); c++) { // zero the counts
                counts[0][c] = 0;
                counts[1][c] = 0; // shouldn't need to do this ...
            }

            // check smallest val for a in atbop is < upper limit
            if (n.numInstances() >= 1 && n.instance(sorted[0]).value(a) < limits[a][1]) {
                // scan from the largest value downwards (missing values are skipped)
                for (int x = n.numInstances() - 1; x >= 0; x--) {
                    if (n.instance(sorted[x]).isMissing(a))
                        continue;

                    double theval = n.instance(sorted[x]).value(a);
                    if (m_Debug)
                        System.out.println("\t " + theval);

                    if (theval > limits[a][1]) {
                        if (m_Debug)
                            System.out.println("\t  >= upperlim; continuing...");
                        continue;
                    }
                    if (theval <= maxLeaf) {
                        if (m_Debug)
                            System.out.println("\t  < maxLeaf; breaking...");
                        break;
                    }

                    // increment counts
                    counts[1][(int) n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]];

                    // absorb preceding instances with the same value
                    if (x != 0 && !n.instance(sorted[x - 1]).isMissing(a)) {
                        int z = x - 1;
                        while (z >= 0 && n.instance(sorted[z]).value(a) == theval) {
                            z--;
                            x--;
                            counts[1][(int) n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]];
                        }
                    }

                    // work out best laplace for > theval
                    double total = Utils.sum(counts[1]);
                    for (int c = 0; c < n.numClasses(); c++) {
                        double temp = (counts[1][c] + 1.0) / (total + 2.0);
                        if (temp > maxBestLaplace) {
                            maxBestPos = counts[1][c];
                            maxBestTotal = total;
                            maxBestLaplace = temp;
                            maxBestClass = c;
                            maxBestCounts = copyCounts(counts);
                            // threshold: midpoint between theval and the previous sorted value
                            maxBestVal = (x == 0) ? theval
                                    : ((theval + n.instance(sorted[x - 1]).value(a)) / 2.0);
                        }
                    }
                }

                // (d) add to t tuple <n,a,v,k,L',">">
                if (!Double.isNaN(maxBestVal) && biprob(maxBestPos, maxBestTotal, leafLaplace) > m_BiProbCrit) {
                    GraftSplit gsplit = null;
                    try {
                        gsplit = new GraftSplit(a, maxBestVal, 1, leafClass, maxBestCounts);
                    } catch (Exception e) {
                        // NOTE(review): see above - System.exit terminates the JVM.
                        System.err.println("graftsplit error:" + e.getMessage());
                        System.exit(1);
                    }
                    t.add(gsplit);
                }
            }
        } else { // must be a nominal attribute

            // 3. for each discrete attribute a for which there is no
            //    test at an ancestor of l

            // skip if this attribute has already been used
            if (limits[a][1] == 1) {
                continue;
            }

            // prohibit[v] is set when some leaf instance has value v (or a missing
            // value) for this attribute and, unless m_relabel is false, has the
            // leaf class; no "=" graft is proposed for such values
            boolean[] prohibit = new boolean[l.attribute(a).numValues()];
            for (int aval = 0; aval < n.attribute(a).numValues(); aval++) {
                for (int x = 0; x < l.numInstances(); x++) {
                    if ((l.instance(x).isMissing(a) || l.instance(x).value(a) == aval)
                            && (!m_relabel || (l.instance(x).classValue() == leafClass))) {
                        prohibit[aval] = true;
                        break;
                    }
                }
            }

            // (a) find values of
            //       $n: instances in atbop (already have that, actually)
            //       $v: $v is a value for $a
            //       $k: $k is a class
            //     that maximize L' = Laplace({$x: $x contained in cases($n)
            //           & value($a,$x) = $v}, $k).
            double bestVal = Double.NaN;
            double bestClass = Double.NaN;
            double bestLaplace = leafLaplace;
            double[][] bestCounts = null;
            double[][] counts = new double[2][n.numClasses()];

            for (int x = 0; x < n.numInstances(); x++) {
                if (n.instance(sorted[x]).isMissing(a))
                    continue;

                // zero the counts
                for (int c = 0; c < n.numClasses(); c++)
                    counts[0][c] = 0;

                // accumulate the class counts of all atbop instances with this value
                double theval = n.instance(sorted[x]).value(a);
                counts[0][(int) n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]];

                if (x != n.numInstances() - 1) {
                    int z = x + 1;
                    while (z < n.numInstances() && n.instance(sorted[z]).value(a) == theval) {
                        z++;
                        x++;
                        counts[0][(int) n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]];
                    }
                }

                if (!prohibit[(int) theval]) {
                    // work out best laplace for instances with value == theval
                    double total = Utils.sum(counts[0]);
                    bestLaplace = leafLaplace;
                    bestClass = Double.NaN;
                    for (int c = 0; c < n.numClasses(); c++) {
                        double temp = (counts[0][c] + 1.0) / (total + 2.0);
                        if (temp > bestLaplace && biprob(counts[0][c], total, leafLaplace) > m_BiProbCrit) {
                            bestLaplace = temp;
                            bestClass = c;
                            bestVal = theval;
                            bestCounts = copyCounts(counts);
                        }
                    }
                    // add to graft list
                    if (!Double.isNaN(bestClass)) {
                        GraftSplit gsplit = null;
                        try {
                            gsplit = new GraftSplit(a, bestVal, 2, leafClass, bestCounts);
                        } catch (Exception e) {
                            // NOTE(review): see above - System.exit terminates the JVM.
                            System.err.println("graftsplit error: " + e.getMessage());
                            System.exit(1);
                        }
                        t.add(gsplit);
                    }
                }
            }
            // (b) add to t tuple <n,a,v,k,L',"=">
            // done this already
        }
    }

    // 4. remove from t all tuples <n,a,v,c,L,x> such that L <=
    //    Laplace(cases(l),c) or prob(x,n,Laplace(cases(l),c) <= 0.05
    //      -- checked this constraint prior to adding a tuple --

    // *** step six done before step five for efficiency ***
    // 6. for each <n,a,v,k,L,x> in t ordered on L from highest to lowest
    // order the tuples from highest to lowest laplace
    // (this actually orders lowest to highest)
    // NOTE(review): t is a raw ArrayList; the order comes from GraftSplit's
    // natural ordering (its compareTo is not shown here).
    Collections.sort(t);

    // 5. remove from t all tuples <n,a,v,c,L,x> such that there is
    //    no tuple <n',a',v',k',L',x'> such that k' != c & L' < L.
    for (int x = 0; x < t.size(); x++) {
        GraftSplit gs = (GraftSplit) t.get(x);
        if (gs.maxClassForSubsetOfInterest() != leafClass) {
            break; // reached a graft with class != leafClass, so stop deleting
        } else {
            t.remove(x);
            x--;
        }
    }

    // if no potential grafts were found, do nothing and return
    if (t.size() < 1) {
        return;
    }

    // create the distributions for each graft
    // (last to first; each graft removes the leaf cases it captures from l)
    for (int x = t.size() - 1; x >= 0; x--) {
        GraftSplit gs = (GraftSplit) t.get(x);
        try {
            gs.buildClassifier(l);
            gs.deleteGraftedCases(l); // so they don't go down the other branch
        } catch (Exception e) {
            // NOTE(review): the failure is only reported; the graft stays in t and
            // is still passed to setDescendents below.
            System.err.println("graftsplit build error: " + e.getMessage());
        }
    }

    // add this stuff to the tree
    ((C45PruneableClassifierTreeG) parent).setDescendents(t, this);
}

From source file:j48.ClassifierSplitModel.java

License:Open Source License

/**
 * Splits the given set of instances into one subset per branch of this model.
 *
 * <p>An instance that {@code whichSubset} assigns to a single subset (index > -1)
 * is added to that subset unchanged. Otherwise it is added to every subset whose
 * entry in {@code weights(instance)} is greater than zero. Each such copy gets its
 * weight multiplied by that entry. All subsets are compactified before returning.
 *
 * @param data the instances to split
 * @return an array of {@code m_numSubsets} subsets
 * @exception Exception if something goes wrong
 */
public final Instances[] split(Instances data) throws Exception {

    Instances[] subsets = new Instances[m_numSubsets];
    for (int s = 0; s < m_numSubsets; s++) {
        subsets[s] = new Instances(data, data.numInstances());
    }

    for (int idx = 0; idx < data.numInstances(); idx++) {
        Instance current = data.instance(idx);
        double[] branchWeights = weights(current);
        int target = whichSubset(current);

        if (target > -1) {
            subsets[target].add(current);
            continue;
        }

        // No single branch: share the instance among all branches with positive weight.
        for (int s = 0; s < m_numSubsets; s++) {
            if (!Utils.gr(branchWeights[s], 0)) {
                continue;
            }
            double scaled = branchWeights[s] * current.weight();
            subsets[s].add(current);
            subsets[s].lastInstance().setWeight(scaled);
        }
    }

    for (Instances subset : subsets) {
        subset.compactify();
    }
    return subsets;
}

From source file:j48.GraftSplit.java

License:Open Source License

/**
 * deletes the cases in data that belong to leaf pointed to by
 * the test (i.e. the subset of interest).  this is useful so
 * the instances belonging to that leaf aren't passed down the
 * other branch./*from   ww w .  ja v  a  2 s  . c om*/
 *
 * @param data the instances to delete from
 */
public void deleteGraftedCases(Instances data) {

    int subOfInterest = subsetOfInterest();
    for (int x = 0; x < data.numInstances(); x++) {
        if (whichSubset(data.instance(x)) == subOfInterest) {
            data.delete(x--);
        }
    }
}

From source file:j48.GraftSplit.java

License:Open Source License

/**
 * Builds {@code m_graftdistro}, the two-bag class distribution of this graft,
 * from {@code data}. The caller passes the original leaf's instances; atbop
 * cases are not counted.
 *
 * <p>The distribution is built in three steps:
 * <ol>
 * <li>Each instance with a known value for {@code m_attIndex} is added to the
 *     bag returned by {@code whichSubset}, unless that returns -1.</li>
 * <li>If any instance has a missing value, those instances are revisited. Each
 *     one that {@code whichSubset} still maps to a bag has its weight multiplied
 *     by a factor and is then added. The factor is the fraction of known-value
 *     weight that fell into the subset of interest, or 0.5 when no known-value
 *     weight exists.</li>
 * <li>An empty bag is seeded with a count of 0.01 for its target class, so that
 *     class is chosen for it. The target is {@code m_maxClass} for the subset of
 *     interest and {@code m_otherLeafMaxClass} for the other bag.</li>
 * </ol>
 *
 * <p>NOTE(review): the scaling uses {@code setWeight} on the instances of
 * {@code data} itself, so the caller's instances are modified.
 *
 * @param data the instances to use when creating the distribution
 * @throws Exception if the distribution cannot be built
 */
public void buildClassifier(Instances data) throws Exception {

    // distribution for the graft, not counting cases in atbop, only orig leaf
    m_graftdistro = new Distribution(2, data.numClasses());

    final int subset = subsetOfInterest(); // the bag for m_leaf
    final int otherSubset = (subset == 0) ? 1 : 0;

    double weightInSubset = 0;
    double knownWeight = 0;
    boolean sawMissing = false;

    // Pass 1: instances with a known value for the split attribute.
    for (int idx = 0; idx < data.numInstances(); idx++) {
        Instance inst = data.instance(idx);
        if (inst.isMissing(m_attIndex)) {
            sawMissing = true;
        } else {
            knownWeight += inst.weight();
            int bag = whichSubset(inst);
            if (bag != -1) {
                m_graftdistro.add(bag, inst);
                if (bag == subset) {
                    weightInSubset += inst.weight();
                }
            }
        }
    }

    // Pass 2: instances with a missing value, down-weighted by the known proportion.
    if (sawMissing) {
        double factor = (knownWeight == 0) ? 0.5 : (weightInSubset / knownWeight);
        for (int idx = 0; idx < data.numInstances(); idx++) {
            Instance inst = data.instance(idx);
            if (!inst.isMissing(m_attIndex)) {
                continue;
            }
            int bag = whichSubset(inst);
            if (bag == -1) {
                continue;
            }
            inst.setWeight(inst.weight() * factor);
            m_graftdistro.add(bag, inst);
        }
    }

    // Seed empty bags so the intended class is still chosen for them.
    if (m_graftdistro.perBag(subset) == 0) {
        double[] seed = new double[data.numClasses()];
        seed[m_maxClass] = 0.01;
        m_graftdistro.add(subset, seed);
    }
    if (m_graftdistro.perBag(otherSubset) == 0) {
        double[] seed = new double[data.numClasses()];
        seed[(int) m_otherLeafMaxClass] = 0.01;
        m_graftdistro.add(otherSubset, seed);
    }
}

From source file:j48.NBTreeModelSelection.java

License:Open Source License

/**
 * Selects an NBTree-type split for the given dataset.
 *
 * <p>A global no-split model ({@code NBTreeNoSplit}) is built on all of
 * {@code data} and returned in any of these cases:
 * <ul>
 * <li>there are fewer than 5 instances;</li>
 * <li>the no-split model makes no errors;</li>
 * <li>the total weight is below {@code m_minNoObj};</li>
 * <li>all weight belongs to one class;</li>
 * <li>no attribute yields a valid split;</li>
 * <li>the best split reduces the error by less than 5% relative to the
 *     no-split model.</li>
 * </ul>
 * Otherwise the valid split with the fewest errors is returned.
 *
 * @param data the instances at the current node
 * @return the chosen split model, or {@code null} if an exception occurred
 *         (its stack trace is printed)
 */
public final ClassifierSplitModel selectModel(Instances data) {

    try {
        // Build the global (no-split) model at this node; it is the fallback result.
        NBTreeNoSplit noSplitModel = new NBTreeNoSplit();
        noSplitModel.buildClassifier(data);
        if (data.numInstances() < 5) {
            return noSplitModel;
        }

        // Evaluate it: a node the global model classifies perfectly is not split.
        double globalErrors = noSplitModel.getErrors();
        if (globalErrors == 0) {
            return noSplitModel;
        }

        // Check if all instances belong to one class or if there are not
        // enough instances to split.
        Distribution checkDistribution = new Distribution(data);
        if (Utils.sm(checkDistribution.total(), m_minNoObj)
                || Utils.eq(checkDistribution.total(),
                        checkDistribution.perClass(checkDistribution.maxClass()))) {
            return noSplitModel;
        }

        // Build a candidate split for every attribute apart from the class attribute.
        NBTreeSplit[] currentModel = new NBTreeSplit[data.numAttributes()];
        double sumOfWeights = data.sumOfWeights();
        int validModels = 0;
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i == data.classIndex()) {
                continue; // class attribute: currentModel[i] stays null
            }
            currentModel[i] = new NBTreeSplit(i, m_minNoObj, sumOfWeights);
            currentModel[i].setGlobalModel(noSplitModel);
            currentModel[i].buildClassifier(data);

            // Check if a useful split for the current attribute exists.
            if (currentModel[i].checkModel()) {
                validModels++;
            }
        }

        // Check if any useful split was found.
        if (validModels == 0) {
            return noSplitModel;
        }

        // Find the valid split with the fewest errors; it must beat the global model.
        NBTreeSplit bestModel = null;
        double minResult = globalErrors;
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i == data.classIndex() || !currentModel[i].checkModel()) {
                continue;
            }
            double errors = currentModel[i].getErrors();
            if (errors < minResult) {
                bestModel = currentModel[i];
                minResult = errors;
            }
        }

        // Require a relative error reduction of at least 5%. The explicit null check
        // guarantees no null model escapes (e.g. if globalErrors were NaN).
        if (bestModel == null || ((globalErrors - minResult) / globalErrors) < 0.05) {
            return noSplitModel;
        }

        return bestModel;
    } catch (Exception e) {
        // NOTE(review): best effort - the failure is only printed and null is returned.
        e.printStackTrace();
    }
    return null;
}