Example usage for weka.core Instances numInstances

List of usage examples for weka.core Instances numInstances

Introduction

In this page you can find the example usage for weka.core Instances numInstances.

Prototype


public int numInstances()

Source Link

Document

Returns the number of instances in the dataset.

Usage

From source file:CJWeka.java

License:Open Source License

/**
 * This function sets what the m_numeric flag to represent the passed class
 * it also performs the normalization of the attributes if applicable
 * and sets up the info to normalize the class. (note that regardless of
 * the options it will fill an array with the range and base, set to
 * normalize all attributes and the class to be between -1 and 1)
 * @param inst the instances./*from ww w . j  a  v  a  2s  .c  o  m*/
 * @return The modified instances. This needs to be done. If the attributes
 * are normalized then deep copies will be made of all the instances which
 * will need to be passed back out.
 */
private Instances setClassType(Instances inst) throws Exception {
    if (inst != null) {
        // x bounds
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        double value;
        m_attributeRanges = new double[inst.numAttributes()];
        m_attributeBases = new double[inst.numAttributes()];
        for (int noa = 0; noa < inst.numAttributes(); noa++) {
            min = Double.POSITIVE_INFINITY;
            max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < inst.numInstances(); i++) {
                if (!inst.instance(i).isMissing(noa)) {
                    value = inst.instance(i).value(noa);
                    if (value < min) {
                        min = value;
                    }
                    if (value > max) {
                        max = value;
                    }
                }
            }

            m_attributeRanges[noa] = (max - min) / 2;
            m_attributeBases[noa] = (max + min) / 2;
            if (noa != inst.classIndex() && m_normalizeAttributes) {
                for (int i = 0; i < inst.numInstances(); i++) {
                    if (m_attributeRanges[noa] != 0) {
                        inst.instance(i).setValue(noa,
                                (inst.instance(i).value(noa) - m_attributeBases[noa]) / m_attributeRanges[noa]);
                    } else {
                        inst.instance(i).setValue(noa, inst.instance(i).value(noa) - m_attributeBases[noa]);
                    }
                }
            }
        }
        if (inst.classAttribute().isNumeric()) {
            m_numeric = true;
        } else {
            m_numeric = false;
        }
    }
    return inst;
}

From source file:MultiClassClassifier.java

License:Open Source License

/**
 * Builds the classifiers.
 *
 * The training data is copied and instances with a missing class are
 * removed from the copy. A ZeroR model is always built on that data. What
 * happens next depends on the number of classes and on m_Method:
 * with at most two classes a single copy of the base classifier is trained
 * directly; with METHOD_1_AGAINST_1 one classifier is trained per pair of
 * classes; otherwise one classifier is trained per column of the selected
 * code (exhaustive, random or standard).
 *
 * @param insts the training data.
 * @throws Exception if a classifier can't be built
 */
public void buildClassifier(Instances insts) throws Exception {

    Instances newInsts;

    // can classifier handle the data?
    getCapabilities().testWithFail(insts);

    // remove instances with missing class (on a copy, the argument is left alone)
    insts = new Instances(insts);
    insts.deleteWithMissingClass();

    if (m_Classifier == null) {
        throw new Exception("No base classifier has been set!");
    }
    m_ZeroR = new ZeroR();
    m_ZeroR.buildClassifier(insts);

    // only populated by the 1-against-1 branch below
    m_TwoClassDataset = null;

    int numClassifiers = insts.numClasses();
    if (numClassifiers <= 2) {

        // nothing to decompose: train one copy of the base classifier as is
        m_Classifiers = Classifier.makeCopies(m_Classifier, 1);
        m_Classifiers[0].buildClassifier(insts);

        m_ClassFilters = null;

    } else if (m_Method == METHOD_1_AGAINST_1) {
        // generate fastvector of pairs: every {i, j} with i < j
        FastVector pairs = new FastVector();
        for (int i = 0; i < insts.numClasses(); i++) {
            for (int j = 0; j < insts.numClasses(); j++) {
                if (j <= i)
                    continue;
                int[] pair = new int[2];
                pair[0] = i;
                pair[1] = j;
                pairs.addElement(pair);
            }
        }

        numClassifiers = pairs.size();
        m_Classifiers = Classifier.makeCopies(m_Classifier, numClassifiers);
        m_ClassFilters = new Filter[numClassifiers];
        m_SumOfWeights = new double[numClassifiers];

        // generate the classifiers
        for (int i = 0; i < numClassifiers; i++) {
            // The filter works on the class attribute (1-based index) with
            // the selection inverted and the pair's two class indices as the
            // nominal values; presumably this keeps only instances of those
            // two classes -- confirm against RemoveWithValues.
            RemoveWithValues classFilter = new RemoveWithValues();
            classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
            classFilter.setModifyHeader(true);
            classFilter.setInvertSelection(true);
            classFilter.setNominalIndicesArr((int[]) pairs.elementAt(i));
            // the input format is an empty header copy with the class index unset
            Instances tempInstances = new Instances(insts, 0);
            tempInstances.setClassIndex(-1);
            classFilter.setInputFormat(tempInstances);
            newInsts = Filter.useFilter(insts, classFilter);
            if (newInsts.numInstances() > 0) {
                // restore the class index on the filtered data before training
                newInsts.setClassIndex(insts.classIndex());
                m_Classifiers[i].buildClassifier(newInsts);
                m_ClassFilters[i] = classFilter;
                m_SumOfWeights[i] = newInsts.sumOfWeights();
            } else {
                // no instance survived the filter for this pair: leave the slot empty
                m_Classifiers[i] = null;
                m_ClassFilters[i] = null;
            }
        }

        // construct a two-class header version of the dataset: the class
        // attribute is replaced, at the same position, by a nominal
        // attribute "class" with the labels "class0" and "class1"
        m_TwoClassDataset = new Instances(insts, 0);
        int classIndex = m_TwoClassDataset.classIndex();
        m_TwoClassDataset.setClassIndex(-1);
        m_TwoClassDataset.deleteAttributeAt(classIndex);
        FastVector classLabels = new FastVector();
        classLabels.addElement("class0");
        classLabels.addElement("class1");
        m_TwoClassDataset.insertAttributeAt(new Attribute("class", classLabels), classIndex);
        m_TwoClassDataset.setClassIndex(classIndex);

    } else {
        // use error correcting code style methods
        Code code = null;
        switch (m_Method) {
        case METHOD_ERROR_EXHAUSTIVE:
            code = new ExhaustiveCode(numClassifiers);
            break;
        case METHOD_ERROR_RANDOM:
            code = new RandomCode(numClassifiers, (int) (numClassifiers * m_RandomWidthFactor), insts);
            break;
        case METHOD_1_AGAINST_ALL:
            code = new StandardCode(numClassifiers);
            break;
        default:
            throw new Exception("Unrecognized correction code type");
        }
        // from here on numClassifiers is the size of the code, not the number of classes
        numClassifiers = code.size();
        m_Classifiers = Classifier.makeCopies(m_Classifier, numClassifiers);
        m_ClassFilters = new MakeIndicator[numClassifiers];
        for (int i = 0; i < m_Classifiers.length; i++) {
            // one MakeIndicator per classifier, configured on the class
            // attribute with the value indices the code supplies for i
            m_ClassFilters[i] = new MakeIndicator();
            MakeIndicator classFilter = (MakeIndicator) m_ClassFilters[i];
            classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
            classFilter.setValueIndices(code.getIndices(i));
            classFilter.setNumeric(false);
            classFilter.setInputFormat(insts);
            newInsts = Filter.useFilter(insts, m_ClassFilters[i]);
            m_Classifiers[i].buildClassifier(newInsts);
        }
    }
    // remember the original (multi-class) class attribute
    m_ClassAttribute = insts.classAttribute();
}

From source file:MultiClassClassifier.java

License:Open Source License

/**
 * Computes calibrated scores for a whole batch of test instances.
 *
 * For every binary classifier the raw score of an instance is element [1]
 * of the classifier's distribution for that instance. The raw scores of one
 * classifier, together with the boolean targets (class value == 1.0) and
 * the counts of true/false targets, are handed to sigTraining, whose result
 * becomes that classifier's row of the returned matrix (presumably a
 * sigmoid calibration -- confirm against sigTraining).
 *
 * With a single classifier the calibrated row is returned as is. With
 * several classifiers each instance's column is additionally rescaled so
 * that it sums to 1 across the classifiers; a column whose sum is exactly 0
 * is left untouched rather than turned into NaN by a division by zero.
 *
 * @param test the test instances; their class values are read to derive
 *             the calibration targets.
 * @return a matrix indexed as [classifier][instance].
 * @throws Exception if m_Method is METHOD_1_AGAINST_1 while more than one
 *                   classifier is in use, or if a filter or classifier
 *                   fails.
 */
public double[][] calibratedDistributionForTestInstances(Instances test) throws Exception {
    final int numInstances = test.numInstances();
    final int numModels = m_Classifiers.length;
    double[][] binProbs = new double[numModels][numInstances];
    double[][] calibratedProbs = new double[numModels][numInstances];
    // Reused for every classifier; fully overwritten on each pass.
    boolean[] target = new boolean[numInstances];

    if (numModels == 1) {
        // Single classifier: the instances are scored without any filtering.
        int prior1 = 0;
        int prior0 = 0;
        for (int i = 0; i < numInstances; i++) {
            Instance inst = test.instance(i);
            binProbs[0][i] = m_Classifiers[0].distributionForInstance(inst)[1];
            target[i] = inst.classValue() == 1.0;
            if (target[i]) {
                prior1++;
            } else {
                prior0++;
            }
        }
        calibratedProbs[0] = sigTraining(binProbs[0], target, prior1, prior0);
        return calibratedProbs;
    }

    if (m_Method == METHOD_1_AGAINST_1) {
        throw new Exception("Not implemented for Method 1 against 1");
    }

    // Error correcting style methods: each classifier sees the instance
    // after it has been pushed through that classifier's own filter.
    for (int i = 0; i < m_ClassFilters.length; i++) {
        int prior1 = 0;
        int prior0 = 0;
        for (int k = 0; k < numInstances; k++) {
            Instance inst = test.instance(k);
            m_ClassFilters[i].input(inst);
            m_ClassFilters[i].batchFinished();
            Instance filteredInst = m_ClassFilters[i].output();
            binProbs[i][k] = m_Classifiers[i].distributionForInstance(filteredInst)[1];

            // The target comes from the filtered instance's class value.
            target[k] = filteredInst.classValue() == 1.0;
            if (target[k]) {
                prior1++;
            } else {
                prior0++;
            }
        }
        calibratedProbs[i] = sigTraining(binProbs[i], target, prior1, prior0);
    }

    // Rescale every instance's scores so they sum to 1 across classifiers.
    for (int i = 0; i < numInstances; i++) {
        double sum = 0;
        for (int j = 0; j < numModels; j++) {
            sum += calibratedProbs[j][i];
        }
        if (sum == 0) {
            // Dividing would only produce NaN; keep the zeros.
            continue;
        }
        for (int j = 0; j < numModels; j++) {
            calibratedProbs[j][i] /= sum;
        }
    }
    return calibratedProbs;
}

From source file:GrowTree.java

/**
 * Picks the attribute to split on.
 *
 * Every non-class attribute of D is used in turn to partition D by
 * attribute value. imp is evaluated on each resulting subset (presumably an
 * impurity measure -- confirm against imp), and the attribute owning the
 * subset with the smallest value seen so far is remembered. Only values
 * strictly below 1.0 are considered.
 *
 * Instances whose value for the attribute is missing are left out of the
 * partition for that attribute, and attributes that declare no values are
 * skipped because they cannot be used to index a partition.
 *
 * @param D the instances to split.
 * @return the selected attribute, or null if no subset scored below 1.0.
 */
Attribute bestSplit(Instances D) {
    double imin = 1.0;
    Attribute fbest = null;
    Enumeration enat = D.enumerateAttributes();
    while (enat.hasMoreElements()) {
        Attribute a = (Attribute) enat.nextElement();
        final int numValues = a.numValues();
        if (numValues <= 0) {
            // No declared values (e.g. a numeric attribute): indexing the
            // partition by value would run out of bounds.
            continue;
        }

        // split D into one subset per value of the attribute
        Instances[] split = new Instances[numValues];
        for (int i = 0; i < numValues; i++) {
            split[i] = new Instances(D, D.numInstances());
        }
        Enumeration x = D.enumerateInstances();
        while (x.hasMoreElements()) {
            Instance row = (Instance) x.nextElement();
            if (row.isMissing(a)) {
                // A missing value is NaN, and (int) NaN is 0: without this
                // check the instance would be counted as the first value.
                continue;
            }
            split[(int) row.value(a)].add(row);
        }
        for (int i = 0; i < split.length; i++) {
            split[i].compactify();
        }

        for (int i = 0; i < split.length; i++) {
            double impurity = imp(split[i]); // evaluated once per subset
            if (impurity < imin) {
                imin = impurity;
                fbest = a; // best feature so far
            }
        }
    }
    return fbest;
}

From source file:classifyfromimage1.java

/**
 * Button handler: pre-processes the selected image through a fixed sequence
 * of ImageJ commands, saves the Results table as CSV, converts the CSV to
 * ARFF, appends a nominal attribute used as class, classifies every row
 * with a serialized model read from the path in txtpath, reports how many
 * rows fell into each predicted label, and finally closes the ImageJ
 * windows and this frame.
 *
 * @param evt the action event (not used).
 */
private void jButton1ActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_jButton1ActionPerformed
    // --- image pre-processing; name3 is the image title, name4 the title without extension ---
    selectWindow(this.name3);
    this.name3 = IJ.getImage().getTitle();
    this.name4 = this.name3.replaceFirst("[.][^.]+$", "");
    // NOTE(review): rm is never used below -- confirm whether getInstance() is needed for a side effect.
    RoiManager rm = RoiManager.getInstance();
    IJ.run("Duplicate...", this.name4);
    IJ.run("Set Measurements...", "area perimeter fit shape limit scientific redirect=None decimal=5");
    selectWindow(this.name3);
    IJ.run("Subtract Background...", "rolling=1.5");
    IJ.run("Enhance Contrast...", "saturated=25 equalize");
    IJ.run("Subtract Background...", "rolling=1.5");
    // 5x5 convolution kernel, passed to ImageJ as text
    IJ.run("Convolve...",
            "text1=[-1 -3 -4 -3 -1\n-3 0 6 0 -3\n-4 6 50 6 -4\n-3 0 6 0 -3\n-1 -3 -4 -3 -1\n] normalize");
    IJ.run("8-bit", "");
    IJ.run("Restore Selection", "");
    IJ.run("Make Binary", "");
    Prefs.blackBackground = false;
    IJ.run("Convert to Mask", "");
    IJ.run("Restore Selection", "");
    // size interval typed by the user
    this.valor1 = this.interval3.getText();
    this.valor2 = this.interval4.getText();
    // NOTE(review): this.text is assembled here but not passed to any IJ.run call in this method.
    this.text = "size=" + this.valor1 + "-" + this.valor2
            + " pixel show=Outlines display include summarize add";
    IJ.saveAs("tif", this.name3 + "_processed");
    // NOTE(review): dest_filename1 and full are never assigned; dest_filename2 is assigned but never read.
    String dest_filename1, dest_filename2, full;
    selectWindow("Results");
    //dest_filename1 = this.name2 + "_complete.txt";
    dest_filename2 = this.name3 + "_complete.csv";
    //IJ.saveAs("Results", prova + File.separator + dest_filename1);
    IJ.run("Input/Output...", "jpeg=85 gif=-1 file=.csv copy_row save_column save_row");
    //IJ.saveAs("Results", dir + File.separator + dest_filename2);
    IJ.saveAs("Results", this.name3 + "_complete.csv");
    IJ.run("Restore Selection");
    IJ.run("Clear Results");

    // --- convert the saved CSV to ARFF ---
    try {
        CSVLoader loader = new CSVLoader();
        loader.setSource(new File(this.name3 + "_complete.csv"));
        Instances data = loader.getDataSet();
        System.out.println(data);

        // save ARFF
        String arffile = this.name3 + ".arff";
        System.out.println(arffile);
        ArffSaver saver = new ArffSaver();
        saver.setInstances(data);
        saver.setFile(new File(arffile));
        saver.writeBatch();
    } catch (IOException ex) {
        // a failed conversion is only logged; the code below still tries to read the ARFF file
        Logger.getLogger(MachinLearningInterface.class.getName()).log(Level.SEVERE, null, ex);
    }

    // --- load the ARFF, add the class attribute and classify every row ---
    Instances data;
    try {
        // NOTE(review): the FileReader/BufferedReader opened here is never closed.
        data = new Instances(new BufferedReader(new FileReader(this.name3 + ".arff")));
        Instances newData = null;
        Add filter;
        newData = new Instances(data);
        // append a nominal attribute (labels from txtlabel, name from txtpath2) and make it the class
        filter = new Add();
        filter.setAttributeIndex("last");
        filter.setNominalLabels(txtlabel.getText());
        filter.setAttributeName(txtpath2.getText());
        filter.setInputFormat(newData);
        newData = Filter.useFilter(newData, filter);
        System.out.print(newData);
        // NOTE(review): vec is filled in the loop below but never read.
        Vector vec = new Vector();
        newData.setClassIndex(newData.numAttributes() - 1);

        // NOTE(review): newData is compared with itself, so this check can never fail -- presumably
        // the training header was meant to be compared here; confirm.
        if (!newData.equalHeaders(newData)) {
            throw new IllegalArgumentException("Train and test are not compatible!");
        }

        // the model is deserialized from the path typed in txtpath
        Classifier cls = (Classifier) weka.core.SerializationHelper.read(txtpath.getText());
        System.out.println("PROVANT MODEL.classifyInstance");
        for (int i = 0; i < newData.numInstances(); i++) {
            double pred = cls.classifyInstance(newData.instance(i));
            double[] dist = cls.distributionForInstance(newData.instance(i));
            System.out.print((i + 1) + " - ");
            System.out.print(newData.classAttribute().value((int) pred) + " - ");
            //txtarea2.setText(Utils.arrayToString(dist));

            System.out.println(Utils.arrayToString(dist));

            vec.add(newData.classAttribute().value((int) pred));
            //txtarea2.append(Utils.arrayToString(dist));
            // classif is not declared in this method; it collects the predicted labels
            classif.add(newData.classAttribute().value((int) pred));
        }

        // drop empty and null entries before counting
        classif.removeAll(Arrays.asList("", null));
        System.out.println(classif);
        String vecstring = "";

        for (Object s : classif) {
            vecstring += s + ",";
            System.out.println("Hola " + vecstring);
        }
        // number of occurrences of each predicted label
        Map<String, Integer> seussCount = new HashMap<String, Integer>();
        for (String t : classif) {
            Integer i = seussCount.get(t);
            if (i == null) {
                i = 0;
            }
            seussCount.put(t, i + 1);
        }
        // NOTE(review): counts '$' characters in the joined label string; the result is only printed.
        String s = vecstring;
        int counter = 0;
        for (int i = 0; i < s.length(); i++) {
            if (s.charAt(i) == '$') {
                counter++;
            }
        }
        System.out.println(seussCount);
        System.out.println("hola " + counter++);
        // NOTE(review): the message concatenates name3 and "arff" without a dot -- confirm intent.
        IJ.showMessage("Your file:" + this.name3 + "arff" + "\n is composed by" + seussCount);
        txtpath2.setText("Your file:" + this.name3 + "arff" + "\n is composed by" + seussCount);
        // show the summary in a new A_MachineLearning window
        A_MachineLearning nf2 = new A_MachineLearning();
        A_MachineLearning.txtresult2.append(this.txtpath2.getText());
        nf2.setVisible(true);

    } catch (Exception ex) {
        // any failure while loading, filtering or classifying is only logged
        Logger.getLogger(MachinLearningInterface.class.getName()).log(Level.SEVERE, null, ex);
    }

    // --- clean up: close the images and the auxiliary ImageJ windows that are open ---
    IJ.run("Close All", "");

    if (WindowManager.getFrame("Results") != null) {
        IJ.selectWindow("Results");
        IJ.run("Close");
    }
    if (WindowManager.getFrame("Summary") != null) {
        IJ.selectWindow("Summary");
        IJ.run("Close");
    }
    // NOTE(review): second check for the "Results" window; duplicates the one above.
    if (WindowManager.getFrame("Results") != null) {
        IJ.selectWindow("Results");
        IJ.run("Close");
    }
    if (WindowManager.getFrame("ROI Manager") != null) {
        IJ.selectWindow("ROI Manager");
        IJ.run("Close");
    }

    // hide and release this frame
    setVisible(false);
    dispose();
}

From source file:SMO.java

License:Open Source License

/**
 * Trains the classifier. Multi-class problems are handled with a
 * one-against-one wrapper: one binary SMO is built for every pair of
 * classes.
 *
 * @param insts the set of training instances
 * @throws Exception if the classifier can't be built successfully
 */
public void buildClassifier(Instances insts) throws Exception {

    final boolean checksEnabled = !m_checksTurnedOff;

    if (checksEnabled) {
        // Refuse data the classifier cannot handle.
        getCapabilities().testWithFail(insts);

        // Work on a copy that has no instance with a missing class.
        insts = new Instances(insts);
        insts.deleteWithMissingClass();

        // Keep only instances with a strictly positive weight. This is
        // mandatory: condition (8) of Keerthi's paper is stated under the
        // assertion Ci > 0 (see equation (3a)).
        Instances weighted = new Instances(insts, insts.numInstances());
        for (int idx = 0; idx < insts.numInstances(); idx++) {
            Instance candidate = insts.instance(idx);
            if (candidate.weight() > 0) {
                weighted.add(candidate);
            }
        }
        if (weighted.numInstances() == 0) {
            throw new Exception("No training instances left after removing " + "instances with weight 0!");
        }
        insts = weighted;

        // Missing values are replaced only when checks are enabled.
        m_Missing = new ReplaceMissingValues();
        m_Missing.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Missing);
    } else {
        m_Missing = null;
    }

    // Nominal attributes are binarized only if the classifier handles
    // numeric attributes, checks are enabled, and at least one non-class
    // attribute is not numeric.
    boolean foundNonNumeric = false;
    if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES) && checksEnabled) {
        for (int att = 0; att < insts.numAttributes(); att++) {
            if (att != insts.classIndex() && !insts.attribute(att).isNumeric()) {
                foundNonNumeric = true;
                break;
            }
        }
    }
    if (foundNonNumeric) {
        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_NominalToBinary);
    } else {
        m_NominalToBinary = null;
    }

    // Optional rescaling of the attributes.
    if (m_filterType == FILTER_STANDARDIZE) {
        m_Filter = new Standardize();
    } else if (m_filterType == FILTER_NORMALIZE) {
        m_Filter = new Normalize();
    } else {
        m_Filter = null;
    }
    if (m_Filter != null) {
        m_Filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_Filter);
    }

    m_classIndex = insts.classIndex();
    m_classAttribute = insts.classAttribute();
    m_KernelIsLinear = (m_kernel instanceof PolyKernel) && (((PolyKernel) m_kernel).getExponent() == 1.0);

    final int numClasses = insts.numClasses();
    final int total = insts.numInstances();

    // One bucket of instances per class value.
    Instances[] perClass = new Instances[numClasses];
    for (int c = 0; c < numClasses; c++) {
        perClass[c] = new Instances(insts, total);
    }
    for (int row = 0; row < total; row++) {
        Instance current = insts.instance(row);
        perClass[(int) current.classValue()].add(current);
    }
    for (int c = 0; c < numClasses; c++) {
        perClass[c].compactify();
    }

    // One binary machine per unordered pair of classes; only the upper
    // triangle of the matrix is filled.
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarySMO[numClasses][numClasses];
    for (int first = 0; first < numClasses; first++) {
        for (int second = first + 1; second < numClasses; second++) {
            m_classifiers[first][second] = new BinarySMO();
            m_classifiers[first][second].setKernel(Kernel.makeCopy(getKernel()));

            // Training set of the pair: all of class 'first', then all of
            // class 'second', shuffled afterwards.
            Instances pairData = new Instances(insts, total);
            for (int k = 0; k < perClass[first].numInstances(); k++) {
                pairData.add(perClass[first].instance(k));
            }
            for (int k = 0; k < perClass[second].numInstances(); k++) {
                pairData.add(perClass[second].instance(k));
            }
            pairData.compactify();
            pairData.randomize(rand);
            m_classifiers[first][second].buildClassifier(pairData, first, second, m_fitLogisticModels,
                    m_numFolds, m_randomSeed);
        }
    }
}

From source file:SpectralClusterer.java

License:Open Source License

/**
 * Generates a clusterer by the mean of spectral clustering algorithm.
 * /*ww w  .  j ava2 s . c  om*/
 * @param data
 *            set of instances serving as training data
 * @exception Exception
 *                if the clusterer has not been generated successfully
 */
public void buildClusterer(Instances data) throws java.lang.Exception {
    int n = data.numInstances();
    int k = data.numAttributes();
    DoubleMatrix2D w;
    if (useSparseMatrix)
        w = DoubleFactory2D.sparse.make(n, n);
    else
        w = DoubleFactory2D.dense.make(n, n);
    double[][] v1 = new double[n][];
    for (int i = 0; i < n; i++)
        v1[i] = data.instance(i).toDoubleArray();
    v = DoubleFactory2D.dense.make(v1);
    double sigma_sq = sigma * sigma;
    // Sets up similarity matrix
    for (int i = 0; i < n; i++)
        for (int j = i; j < n; j++) {
            double dist = distnorm2(v.viewRow(i), v.viewRow(j));
            if ((r == -1) || (dist < r)) {
                double sim = Math.exp(-(dist * dist) / (2 * sigma_sq));
                w.set(i, j, sim);
                w.set(j, i, sim);
            }
        }

    // Partitions points
    int[][] p = partition(w, alpha_star);

    // Deploys results
    numOfClusters = p.length;
    cluster = new int[n];
    for (int i = 0; i < p.length; i++)
        for (int j = 0; j < p[i].length; j++)
            cluster[p[i][j]] = i;
}

From source file:ID3Chi.java

License:Open Source License

/**
 * Method for building an ID3Chi tree.
 *
 * The attribute with the largest information gain is selected. The node
 * becomes a leaf when that gain is zero, or when the chi-square test can be
 * applied to the split and the chi-square statistic does not exceed the
 * threshold for m_confidenceLevel. Otherwise one successor is built
 * recursively for every value of the selected attribute.
 *
 * @param data
 *            the training data
 * @exception Exception
 *                if decision tree can't be built successfully
 */
private void makeTree(Instances data) throws Exception {

    // No instance has reached this node: record the null distribution and
    // stop. Without stopping here, the code below would compute gains on
    // empty data and recurse into subsets that are empty again.
    if (data.numInstances() == 0) {
        SetNullDistribution(data);
        return;
    }

    // Compute attribute with maximum information gain.
    double[] infoGains = new double[data.numAttributes()];
    Enumeration attEnum = data.enumerateAttributes();
    double entropyOfAllData = computeEntropy(data);

    while (attEnum.hasMoreElements()) {
        Attribute att = (Attribute) attEnum.nextElement();
        infoGains[att.index()] = computeInfoGain(data, att, entropyOfAllData);
    }
    m_Attribute = data.attribute(Utils.maxIndex(infoGains));

    // Make leaf if information gain is zero.
    if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
        MakeALeaf(data);
        return;
    }

    // The chi-square quantities are only needed for a real split. They are
    // computed after the zero-gain check so that an attribute with a single
    // value (zero degrees of freedom) simply yields a leaf above instead of
    // reaching the distribution constructor.
    double chiSquare = computeChiSquare(data, m_Attribute);
    int degreesOfFreedom = m_Attribute.numValues() - 1;
    ChiSquaredDistribution chi = new ChiSquaredDistribution(degreesOfFreedom);
    double threshold = chi.inverseCumulativeProbability(m_confidenceLevel);

    Instances[] subset = splitData(data, m_Attribute);

    // Split not significant: make a leaf instead.
    if (CheckIfCanApplyChiSquare(subset) && (chiSquare <= threshold)) {
        MakeALeaf(data);
        return;
    }

    // Otherwise create successors, one per attribute value.
    m_Successors = new ID3Chi[m_Attribute.numValues()];
    for (int j = 0; j < m_Attribute.numValues(); j++) {
        m_Successors[j] = new ID3Chi(this.m_confidenceLevel);
        // fraction of this node's instances that go to successor j
        m_Successors[j].m_Ratio = (double) subset[j].numInstances() / (double) data.numInstances();
        m_Successors[j].makeTree(subset[j]);
    }
}

From source file:ID3Chi.java

License:Open Source License

/**
 * Turns this node into a leaf.
 *
 * Instances with a missing value for the current attribute are first
 * removed from the given data. If nothing is left, the null distribution
 * is set. Otherwise the normalized class counts become the node's
 * distribution and the index of the largest entry becomes its class value.
 *
 * @param data the instances that reached this node; instances are deleted
 *             from it by this method.
 */
private void MakeALeaf(Instances data) {

    data.deleteWithMissing(m_Attribute);

    final int remaining = data.numInstances();
    if (remaining == 0) {
        SetNullDistribution(data);
        return;
    }

    // Count how many instances carry each class value.
    m_Distribution = new double[data.numClasses()];
    for (int row = 0; row < remaining; row++) {
        int label = (int) data.instance(row).classValue();
        m_Distribution[label]++;
    }

    Utils.normalize(m_Distribution);
    m_ClassValue = Utils.maxIndex(m_Distribution);
    m_ClassAttribute = data.classAttribute();

    // A null attribute is what marks a node as a leaf.
    m_Attribute = null;
}

From source file:ID3Chi.java

License:Open Source License

/**
 * Computes the chi-square statistic of an attribute.
 *
 * The data is split by the attribute, and the contributions of all
 * non-empty subsets, as returned by computeChiSquareForSubset, are summed.
 *
 * @param data
 *            the data for which the chi-square is to be computed
 * @param att
 *            the attribute
 * @return the chi-square for the given attribute and data
 * @throws Exception
 *             if computation fails
 */
private double computeChiSquare(Instances data, Attribute att) throws Exception {

    final double[] countsPerClass = GetClassCounts(data);
    final Instances[] partitions = splitData(data, att);

    double total = 0;
    for (int v = 0; v < att.numValues(); v++) {
        Instances part = partitions[v];
        if (part.numInstances() <= 0) {
            // empty subsets contribute nothing
            continue;
        }
        total += computeChiSquareForSubset(part, att, countsPerClass, data.numInstances());
    }
    return total;
}