Example usage for weka.core Instances classAttribute

List of usage examples for weka.core Instances classAttribute

Introduction

On this page you can find example usages of weka.core Instances classAttribute.

Prototype


public Attribute classAttribute()

Source Link

Document

Returns the class attribute.

Usage

From source file:newdtl.NewID3.java

/**
 * Builds the ID3 subtree rooted at this node from the given training data.
 *
 * <p>A node that receives no instances gets a missing label and an all-zero class
 * distribution. Otherwise the attribute with the largest information gain is chosen;
 * if that gain is zero the node stores the normalized class distribution and the
 * majority class, else one child is built recursively for every value of the chosen
 * attribute.
 *
 * @param data the training data
 * @exception Exception if tree failed to build
 */
private void makeTree(Instances data) throws Exception {

    // No instances reach this node: no split and an unknown label.
    if (data.numInstances() == 0) {
        splitAttribute = null;
        label = DOUBLE_MISSING_VALUE;
        classDistributions = new double[data.numClasses()];
        return;
    }

    // Information gain of every attribute handed out by enumerateAttributes().
    double[] infoGains = new double[data.numAttributes()];
    for (Enumeration attEnum = data.enumerateAttributes(); attEnum.hasMoreElements();) {
        Attribute candidate = (Attribute) attEnum.nextElement();
        infoGains[candidate.index()] = computeInfoGain(data, candidate);
    }

    // Select the attribute with the maximum gain; -1 is treated as a failure.
    int maxIG = maxIndex(infoGains);
    if (maxIG == -1) {
        throw new Exception("array null");
    }
    splitAttribute = data.attribute(maxIG);

    if (Double.compare(infoGains[splitAttribute.index()], 0) != 0) {
        // Positive gain: grow one subtree per value of the split attribute.
        Instances[] partitions = splitData(data, splitAttribute);
        children = new NewID3[splitAttribute.numValues()];
        for (int branch = 0; branch < splitAttribute.numValues(); branch++) {
            children[branch] = new NewID3();
            children[branch].makeTree(partitions[branch]);
        }
        return;
    }

    // Zero gain: turn this node into a leaf labelled with the majority class.
    splitAttribute = null;
    classDistributions = new double[data.numClasses()];
    for (int row = 0; row < data.numInstances(); row++) {
        Instance current = (Instance) data.instance(row);
        classDistributions[(int) current.classValue()]++;
    }
    normalizeClassDistribution();
    label = maxIndex(classDistributions);
    classAttribute = data.classAttribute();
}

From source file:newdtl.NewJ48.java

/**
 * Creates a J48 tree./*from w w w.j  ava 2 s  . c o  m*/
 *
 * @param data the training data
 * @exception Exception if tree failed to build
 */
private void makeTree(Instances data) throws Exception {

    // Mengecek apakah tidak terdapat instance dalam node ini
    if (data.numInstances() == 0) {
        splitAttribute = null;
        label = DOUBLE_MISSING_VALUE;
        classDistributions = new double[data.numClasses()];
        isLeaf = true;
    } else {
        // Mencari Gain Ratio maksimum
        double[] gainRatios = new double[data.numAttributes()];
        double[] thresholds = new double[data.numAttributes()];

        Enumeration attEnum = data.enumerateAttributes();
        while (attEnum.hasMoreElements()) {
            Attribute att = (Attribute) attEnum.nextElement();
            double[] result = computeGainRatio(data, att);
            gainRatios[att.index()] = result[0];
            thresholds[att.index()] = result[1];
        }

        splitAttribute = data.attribute(maxIndex(gainRatios));

        if (splitAttribute.isNumeric()) {
            splitThreshold = thresholds[maxIndex(gainRatios)];
        } else {
            splitThreshold = Double.NaN;
        }

        classDistributions = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = (Instance) data.instance(i);
            classDistributions[(int) inst.classValue()]++;
        }

        // Membuat daun jika Gain Ratio-nya 0
        if (Double.compare(gainRatios[splitAttribute.index()], 0) == 0) {
            splitAttribute = null;

            label = maxIndex(classDistributions);
            classAttribute = data.classAttribute();
            isLeaf = true;
        } else {
            // Mengecek jika ada missing value
            if (isMissing(data, splitAttribute)) {
                // cari modus
                int index = modusIndex(data, splitAttribute);

                // ubah data yang punya missing value
                Enumeration dataEnum = data.enumerateInstances();
                while (dataEnum.hasMoreElements()) {
                    Instance inst = (Instance) dataEnum.nextElement();
                    if (inst.isMissing(splitAttribute)) {
                        inst.setValue(splitAttribute, splitAttribute.value(index));
                    }
                }
            }

            // Membuat tree baru di bawah node ini
            Instances[] splitData;
            if (splitAttribute.isNumeric()) {
                splitData = splitData(data, splitAttribute, splitThreshold);
                children = new NewJ48[2];
                for (int j = 0; j < 2; j++) {
                    children[j] = new NewJ48();
                    children[j].makeTree(splitData[j]);
                }
            } else {
                splitData = splitData(data, splitAttribute);
                children = new NewJ48[splitAttribute.numValues()];
                for (int j = 0; j < splitAttribute.numValues(); j++) {
                    children[j] = new NewJ48();
                    children[j].makeTree(splitData[j]);
                }
            }
            isLeaf = false;
        }
    }
}

From source file:newdtl.NewJ48.java

/**
 * Prunes the J48 subtree rooted at this node using expected error pruning.
 *
 * <p>The static error estimate of this node is compared with the backed-up error of its
 * children, each child weighted by its share of this node's instances. When the static
 * error is strictly smaller, the node is collapsed into a leaf labelled with its
 * majority class.
 *
 * @param data the training data; only its class attribute is read by this method
 * @return the static error if this node is (or becomes) a leaf, otherwise the
 *         backed-up error of its children
 * @exception Exception propagated from the error estimate or from pruning a child
 */
private double pruneTree(Instances data) throws Exception {

    double nodeTotal = DoubleStream.of(classDistributions).sum();
    int majorityCount = (int) classDistributions[maxIndex(classDistributions)];
    double staticError = staticErrorEstimate((int) nodeTotal, majorityCount, classDistributions.length);

    if (isLeaf) {
        return staticError;
    }

    // Weighted sum of the children's (already pruned) error estimates.
    double backupError = 0;
    for (NewJ48 child : children) {
        double childTotal = DoubleStream.of(child.classDistributions).sum();
        backupError += childTotal / nodeTotal * child.pruneTree(data);
    }

    if (staticError < backupError) {
        // Collapsing is estimated to be better than keeping the subtree.
        splitAttribute = null;
        label = maxIndex(classDistributions);
        classAttribute = data.classAttribute();
        isLeaf = true;
        children = null;
        return staticError;
    }
    return backupError;
}

From source file:nlp.NLP.java

/**
 * Rates a review by classifying each of its sentences as positive or negative.
 *
 * <p>The review is split on periods. Every non-empty fragment becomes one row of
 * {@code test.arff}, which is then classified with the model read from
 * {@code movieReview.model}. The percentage of positive sentences is stored in
 * {@code rate}, mapped to a 0.5-5 scale in half-point steps per started 10% band.
 * {@code story} and {@code direction} are set to 1 or 0 according to the class of the
 * sentences mentioning those words.
 *
 * @param review the review text to rate
 * @throws IOException if {@code test.arff} cannot be written
 * @throws Exception if the data set or the model cannot be loaded, or classification fails
 */
public void calculateRate(String review) throws IOException, Exception {
    double positiveSentences = 0;
    double allSentences = 0;

    // Fragments actually written to the ARFF file, in row order. Empty fragments are
    // skipped below, so instance i must be matched against this list rather than
    // against the raw split result.
    java.util.List<String> sentences = new java.util.ArrayList<String>();

    File writeFile = new File("test.arff");
    PrintWriter pw = new PrintWriter(writeFile);
    try {
        pw.println("@relation movie_review");
        pw.println("@attribute 'positive_words' numeric");
        pw.println("@attribute 'negative_words' numeric");
        pw.println("@attribute 'positive_score' numeric");
        pw.println("@attribute 'negative_score' numeric");
        pw.println("@attribute 'strongPositive' numeric");
        pw.println("@attribute 'strongNegative' numeric");
        pw.println("@attribute 'subjective_words' numeric");
        pw.println("@attribute 'neutral_words' numeric");
        pw.println("@attribute 'adj_words' numeric");
        pw.println("@attribute 'adv_words' numeric");
        pw.println("@attribute 'class' {negative, positive}");
        pw.println("@data");

        String[] splitByPoint = review.split("\\.");
        for (int j = 0; j < splitByPoint.length; j++) {
            if (splitByPoint[j] == null || splitByPoint[j].isEmpty()) {
                continue;
            }
            sentences.add(splitByPoint[j]);
            System.out.println("your review : " + splitByPoint[j]);
            WekaFileGenerator wk = new WekaFileGenerator(splitByPoint[j], pipeline, ra);
            pw.print(wk.getSentence().getPositiveWords() + "," + wk.getSentence().getNegativeWords() + ","
                    + wk.getSentence().getSumOfPositiveScore() + "," + wk.getSentence().getSumOfNegativeScore()
                    + "," + wk.getSentence().getStrongPositive() + "," + wk.getSentence().getStrongNegative() + ","
                    + wk.getSentence().getSubjectiveWords() + "," + wk.getSentence().getNeutralWords() + ","
                    + wk.getSentence().getNumOfAdjective() + "," + wk.getSentence().getNumOfAdverb() + ", ? \n");
        }
    } finally {
        // Close the file even when feature generation fails part-way through.
        pw.close();
    }

    DataSource test = new DataSource("test.arff");
    Instances testData = test.getDataSet();
    testData.setClassIndex(testData.numAttributes() - 1);

    Classifier model = (Classifier) weka.core.SerializationHelper.read("movieReview.model");

    for (int i = 0; i < testData.numInstances(); i++) {
        Instance inst = testData.instance(i);
        String sentence = sentences.get(i);
        double predictNum = model.classifyInstance(inst);
        String predictedClass = testData.classAttribute().value((int) predictNum);
        System.out.println("Class Predicted: " + predictedClass);
        if (predictedClass.equals("positive")) {
            positiveSentences++;
            System.out.println("positiveSentences = " + positiveSentences);
            if (sentence.contains("story")) {
                story = 1;
            }
            if (sentence.contains("direction")) {
                direction = 1;
            }
        } else {
            if (sentence.contains("story")) {
                story = 0;
            }
            if (sentence.contains("direction")) {
                direction = 0;
            }
        }
        allSentences++;
        System.out.println("allSentences = " + allSentences);
    }

    // Fixed symbols keep the formatted text parseable by Double.parseDouble regardless
    // of the default locale's decimal separator.
    DecimalFormat format = new DecimalFormat("#0.000",
            new java.text.DecimalFormatSymbols(java.util.Locale.ROOT));
    rate = (positiveSentences / allSentences) * 100;
    // A comparison against NaN is always true, so NaN (no sentences) needs isNaN().
    if (!Double.isNaN(rate)) {
        rate = Double.parseDouble(format.format(rate));
        // (0,10] -> 0.5, (10,20] -> 1.0, ... (90,100] -> 5.0; a rate of 0 stays 0.
        if (rate > 0 && rate <= 100) {
            rate = Math.ceil(rate / 10) * 0.5;
        }
    }
    System.out.println("rate: " + rate);

}

From source file:org.mcennis.graphrat.algorithm.machinelearning.ClassifySingleAttribute.java

License:Open Source License

/**
 * Classifies every ground-mode actor with the classifier stored on the graph and adds a
 * link between that actor and the predicted target actor.
 *
 * <p>The classifier is read from the graph property named by the
 * {@code ClassifierProperty} parameter; each actor's feature {@link Instance} is read
 * from the property named by {@code SourceProperty}. Nothing happens when the
 * classifier property holds no value.
 *
 * @param g the graph whose actors are classified and to which links are added
 */
public void execute(Graph g) {
    // construct the queries to be used

    ActorByMode groundMode = (ActorByMode) ActorQueryFactory.newInstance().create("ActorByMode");
    groundMode.buildQuery((String) parameter.get("GroundMode").get(), ".*", false);

    ActorByMode targetMode = (ActorByMode) ActorQueryFactory.newInstance().create("ActorByMode");
    targetMode.buildQuery((String) parameter.get("TargetMode").get(), ".*", false);

    // NOTE(review): groundTruth and artists are not read below; they are kept because
    // the factory/query calls may have side effects - confirm before removing.
    LinkByRelation groundTruth = (LinkByRelation) LinkQueryFactory.newInstance().create("LinkByRelation");
    groundTruth.buildQuery((String) parameter.get("Relation").get(), false);

    // build a list of new artists
    TreeSet<Actor> artists = new TreeSet<Actor>();
    artists.addAll(AlgorithmMacros.filterActor(parameter, g, targetMode.execute(g, artists, null)));

    Property classifierProperty = g.getProperty(
            AlgorithmMacros.getSourceID(parameter, g, (String) parameter.get("ClassifierProperty").get()));
    if (classifierProperty.getValue().isEmpty()) {
        return;
    }
    Classifier classifier = (Classifier) classifierProperty.getValue().get(0);

    Iterator<Actor> users = AlgorithmMacros.filterActor(parameter, g, groundMode, null, null);
    // Header (feature attributes + nominal "TargetID" class), built from the first instance seen.
    Instances dataSet = null;
    while (users.hasNext()) {
        Actor user = users.next();
        Property property = user.getProperty(
                AlgorithmMacros.getSourceID(parameter, g, (String) parameter.get("SourceProperty").get()));
        if (!property.getPropertyClass().getName().contentEquals(Instance.class.getName())) {
            continue;
        }
        List<?> values = property.getValue();
        if (values.isEmpty()) {
            continue;
        }

        // get the existing instance
        Instance object = (Instance) values.get(0);
        if (dataSet == null) {
            Instances current = object.dataset();
            FastVector attributes = new FastVector();
            for (int j = 0; j < current.numAttributes(); ++j) {
                attributes.addElement(current.attribute(j));
            }
            FastVector targetNames = new FastVector();
            Iterator<Actor> artistIt = targetMode.executeIterator(g, null, null);
            while (artistIt.hasNext()) {
                targetNames.addElement(artistIt.next().getID());
            }
            Attribute classValue = new Attribute("TargetID", targetNames);
            attributes.addElement(classValue);
            dataSet = new Instances("Training", attributes, 1000);
            dataSet.setClassIndex(dataSet.numAttributes() - 1);
        }

        // Copy the actor's features and append one slot for the class value. Only the
        // existing attributes are read; indexing one past them would run off the end
        // of the source instance.
        int featureCount = object.numAttributes();
        double[] content = new double[featureCount + 1];
        for (int j = 0; j < featureCount; ++j) {
            content[j] = object.value(j);
        }
        // The class is what is being predicted, so it is left missing (NaN).
        content[featureCount] = Double.NaN;

        Instance base = new Instance(1.0, content);
        // Attach the header so the classifier can resolve the attributes and the class.
        base.setDataset(dataSet);
        try {
            double strength = classifier.classifyInstance(base);
            if (!Double.isNaN(strength)) {
                String id = dataSet.classAttribute().value((int) strength);
                Actor target = g.getActor((String) parameter.get("TargetMode").get(), id);
                Link link = LinkFactory.newInstance()
                        .create((String) parameter.get("Relation").get());
                if ((LinkEnd) parameter.get("LinkEnd").get() == LinkEnd.SOURCE) {
                    link.set(user, strength, target);
                } else {
                    link.set(target, strength, user);
                }
                g.add(link);
            }
        } catch (Exception ex) {
            Logger.getLogger(ClassifySingleAttribute.class.getName()).log(Level.SEVERE, null, ex);
        }
    }
}

From source file:org.mcennis.graphrat.algorithm.machinelearning.WekaClassifierMultiAttribute.java

License:Open Source License

/**
 * Cross-validates one binary classifier per destination actor and adds the predicted
 * links to the graph.
 *
 * <p>The feature {@link Instance} of every source actor is collected into one data set,
 * which is extended with a true/false attribute per destination actor (the ground
 * truth, derived from existing links) and a nominal {@code sourceID} attribute that
 * records which actor each row belongs to.
 *
 * @param g the graph providing actors, properties and links, and receiving derived links
 */
@Override
public void execute(Graph g) {
    Actor[] source = g.getActor((String) parameter[1].getValue());
    if (source == null) {
        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                "Source mode '" + (String) parameter[1].getValue() + "' has no actors");
        return;
    }
    Actor[] dest = g.getActor((String) parameter[3].getValue());
    if (dest == null) {
        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                "Ground truth mode '" + (String) parameter[3].getValue() + "' has no actors");
        return;
    }

    // Collect one row per source actor that owns an Instance. rowIds[k] is the ID of
    // the actor behind row k of masterSet; sourceTypes holds each ID exactly once.
    FastVector sourceTypes = new FastVector();
    String[] rowIds = new String[source.length];
    int rows = 0;
    Instances masterSet = null;
    for (int i = 0; i < source.length; ++i) {
        String propertyId = (Boolean) parameter[10].getValue()
                ? (String) parameter[2].getValue() + g.getID()
                : (String) parameter[2].getValue();
        Property p = source[i].getProperty(propertyId);
        if (p == null) {
            // p is null here, so the requested ID is reported instead of p.getType().
            Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                    "Actor " + source[i].getType() + ":" + source[i].getID()
                            + " does not have a property of ID " + propertyId);
            continue;
        }
        Object[] values = p.getValue();
        if (values.length == 0) {
            Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                    "Actor " + source[i].getType() + ":" + source[i].getID()
                            + " does not have an Instance value of property ID " + p.getType());
            continue;
        }
        Instance row = (Instance) ((Instance) values[0]).copy();
        // assume that this Instance has a backing dataset
        // that contains all Instance objects to be tested
        if (masterSet == null) {
            masterSet = new Instances(row.dataset(), source.length);
        }
        masterSet.add(row);
        sourceTypes.addElement(source[i].getID());
        rowIds[rows++] = source[i].getID();
    }
    if (masterSet == null) {
        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.WARNING,
                "No actor of source mode '" + (String) parameter[1].getValue()
                        + "' provides an Instance to classify");
        return;
    }

    // Append one true/false attribute per destination actor, then the sourceID
    // attribute. The Attribute objects are read back from masterSet afterwards so that
    // they carry the indices they have inside this data set.
    int firstDestIndex = masterSet.numAttributes();
    for (int i = 0; i < dest.length; ++i) {
        FastVector type = new FastVector();
        type.addElement("false");
        type.addElement("true");
        masterSet.insertAttributeAt(new Attribute(dest[i].getID(), type), masterSet.numAttributes());
    }
    masterSet.insertAttributeAt(new Attribute("sourceID", sourceTypes), masterSet.numAttributes());
    Attribute[] destAttributes = new Attribute[dest.length];
    for (int i = 0; i < dest.length; ++i) {
        destAttributes[i] = masterSet.attribute(firstDestIndex + i);
    }
    Attribute sourceID = masterSet.attribute(firstDestIndex + dest.length);

    // set ground truth for evaluation
    for (int i = 0; i < masterSet.numInstances(); ++i) {
        Instance inst = masterSet.instance(i);
        // Record which actor this row belongs to; it is read back after prediction.
        inst.setValue(sourceID, rowIds[i]);
        Actor user = g.getActor((String) parameter[1].getValue(), rowIds[i]);
        if (user != null) {
            for (int j = 0; j < dest.length; ++j) {
                if (g.getLink((String) parameter[4].getValue(), user, dest[j]) != null) {
                    inst.setValue(destAttributes[j], "true");
                } else if ((Boolean) parameter[9].getValue()) {
                    inst.setValue(destAttributes[j], "false");
                } else {
                    inst.setValue(destAttributes[j], Double.NaN);
                }
            }
        } else {
            Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE,
                    "Actor " + rowIds[i] + " does not exist in graph");
        }
    }

    // perform cross fold evaluation of each classifier in turn
    // NOTE(review): parameter[9] is cast to Boolean in the ground-truth loop above and
    // to String here; one of the two indices is presumably wrong - confirm against the
    // parameter table of this algorithm.
    String[] opts = ((String) parameter[9].getValue()).split("\\s+");
    Properties props = new Properties();
    if ((Boolean) parameter[11].getValue()) {
        props.setProperty("LinkType", (String) parameter[5].getValue() + g.getID());
    } else {
        props.setProperty("LinkType", (String) parameter[5].getValue());
    }
    props.setProperty("LinkClass", "Basic");
    try {
        int folds = (Integer) parameter[8].getValue();
        for (int destCount = 0; destCount < dest.length; ++destCount) {
            masterSet.setClass(destAttributes[destCount]);
            for (int i = 0; i < folds; ++i) {
                Instances test = masterSet.testCV(folds, i);
                // Train on the complement of the test fold, not on the test fold itself.
                Instances train = masterSet.trainCV(folds, i);
                Classifier classifier = (Classifier) ((Class) parameter[7].getValue()).newInstance();
                classifier.setOptions(opts);
                classifier.buildClassifier(train);
                for (int j = 0; j < test.numInstances(); ++j) {
                    String sourceName = sourceID.value((int) test.instance(j).value(sourceID));
                    double result = classifier.classifyInstance(test.instance(j));
                    // NOTE(review): the class attribute is "false"/"true", so `predicted`
                    // is not an actor ID; presumably a "true" prediction should link to
                    // dest[destCount] - confirm the intended behavior.
                    String predicted = masterSet.classAttribute().value((int) result);
                    Link derived = LinkFactory.newInstance().create(props);
                    derived.set(g.getActor((String) parameter[1].getValue(), sourceName), 1.0,
                            g.getActor((String) parameter[3].getValue(), predicted));
                    g.add(derived);
                }
            }
        }
    } catch (Exception ex) {
        // Covers InstantiationException and IllegalAccessException as well; all were
        // logged identically.
        Logger.getLogger(WekaClassifierMultiAttribute.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:org.mcennis.graphrat.algorithm.machinelearning.WekaClassifierOneAttribute.java

License:Open Source License

/**
 * Cross-validates a single multi-class classifier that predicts, for every source
 * actor, the destination actor it links to, and adds the predicted links to the graph.
 *
 * <p>The feature {@link Instance} of every source actor is collected into one data set,
 * which is extended with a nominal {@code sourceID} attribute (the owning actor) and a
 * nominal class attribute whose values are the destination actor IDs.
 *
 * @param g the graph providing actors, properties and links, and receiving derived links
 */
@Override
public void execute(Graph g) {
    Actor[] source = g.getActor((String) parameter[1].getValue());
    if (source == null) {
        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                "Source mode '" + (String) parameter[1].getValue() + "' has no actors");
        return;
    }
    Actor[] dest = g.getActor((String) parameter[3].getValue());
    if (dest == null) {
        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                "Ground truth mode '" + (String) parameter[3].getValue() + "' has no actors");
        return;
    }

    // The class values are the IDs of the destination actors.
    FastVector classTypes = new FastVector();
    for (int i = 0; i < dest.length; ++i) {
        classTypes.addElement(dest[i].getID());
    }

    // Collect one row per source actor that owns an Instance; rowActors[k] is the
    // actor behind row k of masterSet.
    FastVector sourceTypes = new FastVector();
    Actor[] rowActors = new Actor[source.length];
    int rows = 0;
    Instances masterSet = null;
    for (int i = 0; i < source.length; ++i) {
        String propertyId = (Boolean) parameter[9].getValue()
                ? (String) parameter[2].getValue() + g.getID()
                : (String) parameter[2].getValue();
        Property p = source[i].getProperty(propertyId);
        if (p == null) {
            // p is null here, so the requested ID is reported instead of p.getType().
            Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                    "Actor " + source[i].getType() + ":" + source[i].getID()
                            + " does not have a property of ID " + propertyId);
            continue;
        }
        Object[] values = p.getValue();
        if (values.length == 0) {
            Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                    "Actor " + source[i].getType() + ":" + source[i].getID()
                            + " does not have an Instance value of property ID " + p.getType());
            continue;
        }
        Instance row = (Instance) ((Instance) values[0]).copy();
        // assume that this Instance has a backing dataset
        // that contains all Instance objects to be tested
        if (masterSet == null) {
            masterSet = new Instances(row.dataset(), source.length);
        }
        masterSet.add(row);
        sourceTypes.addElement(source[i].getID());
        rowActors[rows++] = source[i];
    }
    if (masterSet == null) {
        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.WARNING,
                "No actor of source mode '" + (String) parameter[1].getValue()
                        + "' provides an Instance to classify");
        return;
    }

    // Append the sourceID and class attributes, then read sourceID back from masterSet
    // so that it carries the index it has inside this data set.
    int sourceIdIndex = masterSet.numAttributes();
    masterSet.insertAttributeAt(new Attribute("sourceID", sourceTypes), sourceIdIndex);
    masterSet.insertAttributeAt(new Attribute((String) parameter[5].getValue(), classTypes),
            sourceIdIndex + 1);
    masterSet.setClassIndex(sourceIdIndex + 1);
    Attribute sourceID = masterSet.attribute(sourceIdIndex);

    // Fill in owner and ground-truth class on the rows stored in masterSet (not on
    // detached copies, which the cross-validation below would never see).
    for (int row = 0; row < rows; ++row) {
        Instance inst = masterSet.instance(row);
        inst.setValue(sourceID, rowActors[row].getID());
        Link[] link = g.getLinkBySource((String) parameter[4].getValue(), rowActors[row]);
        if (link == null || link.length == 0) {
            inst.setClassValue(Double.NaN);
        } else {
            inst.setClassValue(link[0].getDestination().getID());
        }
    }

    String[] opts = ((String) parameter[7].getValue()).split("\\s+");
    Properties props = new Properties();
    if ((Boolean) parameter[10].getValue()) {
        props.setProperty("LinkType", (String) parameter[5].getValue() + g.getID());
    } else {
        props.setProperty("LinkType", (String) parameter[5].getValue());
    }
    props.setProperty("LinkClass", "Basic");
    try {
        int folds = (Integer) parameter[8].getValue();
        for (int i = 0; i < folds; ++i) {
            Instances test = masterSet.testCV(folds, i);
            // Train on the complement of the test fold, not on the test fold itself.
            Instances train = masterSet.trainCV(folds, i);
            Classifier classifier = (Classifier) ((Class) parameter[6].getValue()).newInstance();
            classifier.setOptions(opts);
            classifier.buildClassifier(train);
            for (int j = 0; j < test.numInstances(); ++j) {
                String sourceName = sourceID.value((int) test.instance(j).value(sourceID));
                double result = classifier.classifyInstance(test.instance(j));
                String predicted = masterSet.classAttribute().value((int) result);
                Link derived = LinkFactory.newInstance().create(props);
                derived.set(g.getActor((String) parameter[1].getValue(), sourceName), 1.0,
                        g.getActor((String) parameter[3].getValue(), predicted));
                g.add(derived);
            }
        }
    } catch (Exception ex) {
        // Covers InstantiationException and IllegalAccessException as well; all were
        // logged identically.
        Logger.getLogger(WekaClassifierOneAttribute.class.getName()).log(Level.SEVERE, null, ex);
    }
}

From source file:org.openml.webapplication.algorithm.InstancesHelper.java

License:Open Source License

/**
 * Counts the instances of the dataset per class value.
 *
 * @param dataset the dataset whose class attribute and instances are read
 * @return an array with one slot per value of the class attribute, holding the number
 *         of instances whose class value casts to that index
 */
public static int[] classCounts(Instances dataset) {
    final int[] perClass = new int[dataset.classAttribute().numValues()];
    int row = 0;
    while (row < dataset.numInstances()) {
        int classIndex = (int) dataset.instance(row).classValue();
        perClass[classIndex]++;
        ++row;
    }
    return perClass;
}

From source file:org.openml.webapplication.algorithm.InstancesHelper.java

License:Open Source License

/**
 * Computes, for every class value, the fraction of instances that carry it.
 *
 * @param dataset the dataset whose class attribute and instances are read
 * @return an array with one slot per value of the class attribute, holding that class's
 *         count divided by the number of instances
 */
public static double[] classRatios(Instances dataset) {
    final double[] ratios = new double[dataset.classAttribute().numValues()];
    final int[] occurrences = classCounts(dataset);

    for (int c = 0; c < ratios.length; ++c) {
        double classCount = occurrences[c];
        ratios[c] = classCount / dataset.numInstances();
    }

    return ratios;
}

From source file:org.openml.webapplication.algorithm.InstancesHelper.java

License:Open Source License

/**
 * Reorders the instances of the dataset in place so that classes are interleaved
 * according to their overall ratios.
 *
 * <p>All instances are moved into one list per class and the dataset is emptied. It is
 * then refilled one instance at a time, each time taking the next instance of the class
 * selected by {@code biggestDifference} from the target ratios and the ratios of what
 * has been re-added so far.
 *
 * @param dataset the dataset to reorder; it is modified in place
 */
@SuppressWarnings("unchecked")
public static void stratify(Instances dataset) {
    final int classTotal = dataset.classAttribute().numValues();
    final int instanceTotal = dataset.numInstances();
    final double[] targetRatios = classRatios(dataset);
    final double[] runningRatios = new double[classTotal];
    final int[] runningCounts = new int[classTotal];

    // One bucket per class value.
    final List<Instance>[] buckets = new LinkedList[classTotal];
    for (int c = 0; c < classTotal; ++c) {
        buckets[c] = new LinkedList<Instance>();
    }

    // Distribute every instance into the bucket of its class.
    for (int row = 0; row < instanceTotal; ++row) {
        Instance current = dataset.instance(row);
        int classIndex = (int) current.classValue();
        buckets[classIndex].add(current);
    }

    // Empty the dataset from the back; the instances live on in the buckets.
    for (int last = instanceTotal - 1; last >= 0; last--) {
        dataset.delete(last);
    }

    // Refill, always drawing from the class chosen by biggestDifference.
    for (int placed = 1; placed <= instanceTotal; ++placed) {
        int idx = biggestDifference(targetRatios, runningRatios);
        Instance next = buckets[idx].remove(0);
        dataset.add(next);
        runningCounts[idx]++;

        for (int c = 0; c < runningRatios.length; ++c) {
            runningRatios[c] = (runningCounts[c] * 1.0) / placed;
        }
    }
}