Example usage for weka.core Instance classIndex

List of usage examples for weka.core Instance classIndex

Introduction

On this page you can find example usages of weka.core Instance classIndex.

Prototype

public int classIndex();

Source Link

Document

Returns the class attribute's index.

Usage

From source file:lu.lippmann.cdb.lab.mds.MDSViewBuilder.java

License:Open Source License

/**
 * Fills the dataset of {@code xyPlot} with the 2-D MDS coordinates of {@code mdsResult} and
 * configures the plot's renderer, tooltips and optional point labels.
 *
 * <p>Points are grouped into one {@link XYSeries} per class value of the source instances. There
 * is a single "Serie #1" series when there is no class attribute or it has at most one value.
 * When the MDS result is collapsed (each plotted point stands for a cluster), a custom renderer
 * scales every point's ellipse by its cluster size relative to the largest cluster. If
 * {@code ENABLE_PIE_SHART} is set, it also overlays a pie of the cluster's class repartition, and
 * the tooltip lists the centroid and its per-class counts. Otherwise the tooltip is a short HTML
 * rendering of the plotted instance.
 *
 * @param mdsResult MDS output providing the (possibly collapsed) instances and the coordinate
 *     matrix; row {@code i} is plotted as (column 0, column 1) for collapsed instance {@code i}
 * @param xyPlot target plot; its dataset is cast to {@link XYSeriesCollection} and cleared, and
 *     its renderer is cast to {@link XYLineAndShapeRenderer}
 * @param attrNameToUseAsPointTitle optional; only the first name is used, and only if the source
 *     instances have an attribute with that name, as the item label of every point
 * @throws Exception as declared; NOTE(review): which callee can actually throw is not visible here
 */
private static void buildFilteredSeries(final MDSResult mdsResult, final XYPlot xyPlot,
        final String... attrNameToUseAsPointTitle) throws Exception {

    final CollapsedInstances distMdsRes = mdsResult.getCInstances();
    final Instances instances = distMdsRes.getInstances();

    final SimpleMatrix coordinates = mdsResult.getCoordinates();

    final Instances collapsedInstances = mdsResult.getCollapsedInstances();
    // Size of the largest cluster. Only computed when the result is collapsed; the custom
    // renderer below scales point sizes relative to it.
    int maxSize = 0;
    if (distMdsRes.isCollapsed()) {
        final List<Instances> clusters = distMdsRes.getCentroidMap().getClusters();
        final int nbCentroids = clusters.size();
        maxSize = clusters.get(0).size();
        for (int i = 1; i < nbCentroids; i++) {
            final int currentSize = clusters.get(i).size();
            if (currentSize > maxSize) {
                maxSize = currentSize;
            }
        }
    }

    // Class information is taken from the source instances, not from the collapsed ones.
    Attribute clsAttribute = null;
    int nbClass = 1;
    if (instances.classIndex() != -1) {
        clsAttribute = instances.classAttribute();
        nbClass = clsAttribute.numValues();
    }
    final XYSeriesCollection dataset = (XYSeriesCollection) xyPlot.getDataset();
    // Final copy so the anonymous renderer class below can capture it.
    final int fMaxSize = maxSize;

    final List<XYSeries> lseries = new ArrayList<XYSeries>();

    // No class (or a single class value): add one dummy series.
    if (nbClass <= 1) {
        lseries.add(new XYSeries("Serie #1", false));
    } else {
        // Several class values: add one series per class value, named after the value.
        for (int i = 0; i < nbClass; i++) {
            lseries.add(new XYSeries(clsAttribute.value(i), false));
        }
    }
    dataset.removeAllSeries();

    // One empty copy of the collapsed instances per series. Instances are appended in the same
    // order as points, so filteredInstances.get(s).instance(k) backs item k of series s.
    final List<Instances> filteredInstances = new ArrayList<Instances>();
    for (int i = 0; i < lseries.size(); i++) {
        filteredInstances.add(new Instances(collapsedInstances, 0));
    }

    // (series index, item index within that series) -> index i of the collapsed instance, which is
    // also used as the centroid/cluster index. Only filled when the result is collapsed.
    final Map<Tuple<Integer, Integer>, Integer> correspondanceMap = new HashMap<Tuple<Integer, Integer>, Integer>();
    for (int i = 0; i < collapsedInstances.numInstances(); i++) {
        final Instance oInst = collapsedInstances.instance(i);
        // Choose the target series. When collapsed, use getStrongestClass (presumably the
        // dominant class of cluster i; confirm). Otherwise use the instance's own class value.
        // NOTE(review): if nbClass <= 1 but the instance still has a class whose value is > 0,
        // this index would exceed the single dummy series.
        int indexOfSerie = 0;
        if (oInst.classIndex() != -1) {
            if (distMdsRes.isCollapsed()) {
                indexOfSerie = getStrongestClass(i, distMdsRes);
            } else {
                indexOfSerie = (int) oInst.value(oInst.classAttribute());
            }
        }
        lseries.get(indexOfSerie).add(coordinates.get(i, 0), coordinates.get(i, 1));

        filteredInstances.get(indexOfSerie).add(oInst);
        if (distMdsRes.isCollapsed()) {
            correspondanceMap.put(new Tuple<Integer, Integer>(indexOfSerie,
                    filteredInstances.get(indexOfSerie).numInstances() - 1), i);
        }
    }

    // Filled near the end of this method with one paint per class series. It is passed to
    // createPieChart by the collapsed renderer.
    final List<Paint> colors = new ArrayList<Paint>();

    for (final XYSeries series : lseries) {
        dataset.addSeries(series);
    }

    if (distMdsRes.isCollapsed()) {
        final XYLineAndShapeRenderer xyRenderer = new XYLineAndShapeRenderer(false, true) {
            private static final long serialVersionUID = -6019883886470934528L;

            @Override
            public void drawItem(Graphics2D g2, XYItemRendererState state, java.awt.geom.Rectangle2D dataArea,
                    PlotRenderingInfo info, XYPlot plot, ValueAxis domainAxis, ValueAxis rangeAxis,
                    XYDataset dataset, int series, int item, CrosshairState crosshairState, int pass) {

                // Always true here: this renderer is only installed when the result is collapsed.
                if (distMdsRes.isCollapsed()) {

                    final Integer centroidIndex = correspondanceMap
                            .get(new Tuple<Integer, Integer>(series, item));
                    final Instances cluster = distMdsRes.getCentroidMap().getClusters().get(centroidIndex);
                    int size = cluster.size();

                    // Point diameter proportional to cluster size relative to the largest cluster.
                    final int shapeSize = (int) (MAX_POINT_SIZE * size / fMaxSize + 1);

                    final double x1 = plot.getDataset().getX(series, item).doubleValue();
                    final double y1 = plot.getDataset().getY(series, item).doubleValue();

                    // NOTE(review): unlike the tooltip generator, this does not copy
                    // instances.classIndex() onto the cluster first. The pie may therefore show
                    // "No class" until a tooltip has been generated for this cluster; confirm.
                    Map<Object, Integer> mapRepartition = new HashMap<Object, Integer>();
                    mapRepartition.put("No class", size);
                    if (cluster.classIndex() != -1) {
                        mapRepartition = WekaDataStatsUtil.getClassRepartition(cluster);
                    }

                    // Screen (Java2D) position of the data point, used to place the pie.
                    final RectangleEdge xAxisLocation = plot.getDomainAxisEdge();
                    final RectangleEdge yAxisLocation = plot.getRangeAxisEdge();
                    final double fx = domainAxis.valueToJava2D(x1, dataArea, xAxisLocation);
                    final double fy = rangeAxis.valueToJava2D(y1, dataArea, yAxisLocation);

                    // The series shape is overwritten just before each item is drawn, so every item
                    // gets its own size.
                    setSeriesShape(series,
                            new Ellipse2D.Double(-shapeSize / 2, -shapeSize / 2, shapeSize, shapeSize));

                    super.drawItem(g2, state, dataArea, info, plot, domainAxis, rangeAxis, dataset, series,
                            item, crosshairState, pass);

                    //Draw pie
                    if (ENABLE_PIE_SHART) {
                        createPieChart(g2, (int) (fx - shapeSize / 2), (int) (fy - shapeSize / 2), shapeSize,
                                mapRepartition, size, colors);
                    }

                } else {

                    super.drawItem(g2, state, dataArea, info, plot, domainAxis, rangeAxis, dataset, series,
                            item, crosshairState, pass);

                }

            }

        };

        xyPlot.setRenderer(xyRenderer);
    }

    final XYToolTipGenerator gen = new XYToolTipGenerator() {
        @Override
        public String generateToolTip(XYDataset dataset, int series, int item) {
            if (distMdsRes.isCollapsed()) {
                // Collapsed: show the centroid plus the non-zero per-class counts of its cluster.
                final StringBuilder res = new StringBuilder("<html>");
                final Integer centroidIndex = correspondanceMap.get(new Tuple<Integer, Integer>(series, item));
                final Instance centroid = distMdsRes.getCentroidMap().getCentroids().get(centroidIndex);
                final Instances cluster = distMdsRes.getCentroidMap().getClusters().get(centroidIndex);

                // Give the cluster the same class index as the original instances.
                // Side effect: this mutates the cluster every time a tooltip is generated.
                cluster.setClassIndex(instances.classIndex());

                Map<Object, Integer> mapRepartition = new HashMap<Object, Integer>();
                mapRepartition.put("No class", cluster.size());
                if (cluster.classIndex() != -1) {
                    mapRepartition = WekaDataStatsUtil.getClassRepartition(cluster);
                }
                res.append(InstanceFormatter.htmlFormat(centroid, false)).append("<br/>");
                for (final Map.Entry<Object, Integer> entry : mapRepartition.entrySet()) {
                    // Skip classes absent from the cluster.
                    if (entry.getValue() != 0) {
                        res.append("Class :<b>'" + StringEscapeUtils.escapeHtml(entry.getKey().toString())
                                + "</b>' -> " + entry.getValue()).append("<br/>");
                    }
                }
                res.append("</html>");
                return res.toString();
            } else {
                // Not collapsed: short HTML rendering of the instance behind this point.
                return InstanceFormatter.shortHtmlFormat(filteredInstances.get(series).instance(item));
            }
        }
    };

    // Fixed-size legend shape, independent of the per-item sizes used by the collapsed renderer.
    final Shape shape = new Ellipse2D.Float(0f, 0f, MAX_POINT_SIZE, MAX_POINT_SIZE);

    // Assumes the plot's renderer is an XYLineAndShapeRenderer (cast fails otherwise).
    ((XYLineAndShapeRenderer) xyPlot.getRenderer()).setUseOutlinePaint(true);

    // NOTE(review): these loops are bounded by nbClass, not lseries.size(). If the class attribute
    // reports 0 values, nbClass is 0 while one dummy series exists, and that series gets no
    // tooltip, legend shape or outline paint. Confirm whether lseries.size() is intended.
    for (int p = 0; p < nbClass; p++) {
        xyPlot.getRenderer().setSeriesToolTipGenerator(p, gen);
        ((XYLineAndShapeRenderer) xyPlot.getRenderer()).setLegendShape(p, shape);
        xyPlot.getRenderer().setSeriesOutlinePaint(p, Color.BLACK);
    }

    // Record the paint used for each series, presumably so pie slices match series colours (confirm).
    for (int ii = 0; ii < nbClass; ii++) {
        colors.add(xyPlot.getRenderer().getItemPaint(ii, 0));
    }

    if (attrNameToUseAsPointTitle.length > 0) {
        final Attribute attrToUseAsPointTitle = instances.attribute(attrNameToUseAsPointTitle[0]);
        if (attrToUseAsPointTitle != null) {
            // Label every point with the string value of the chosen attribute.
            final XYItemLabelGenerator lg = new XYItemLabelGenerator() {
                @Override
                public String generateLabel(final XYDataset dataset, final int series, final int item) {
                    return filteredInstances.get(series).instance(item).stringValue(attrToUseAsPointTitle);
                }
            };
            xyPlot.getRenderer().setBaseItemLabelGenerator(lg);
            xyPlot.getRenderer().setBaseItemLabelsVisible(true);
        }
    }
}

From source file:lu.lippmann.cdb.lab.mds.UniversalMDS.java

License:Open Source License

/**
 * Builds a panel showing a scatter plot of the MDS projection of {@code ds}.
 *
 * <p>There is one series per class value, or a single "Serie #1" series when {@code ds} has no
 * class attribute or it has at most one value. Each point's tooltip is the HTML rendering of the
 * instance it represents. The chart title is {@code type.name() + " MDS"}.
 *
 * <p>NOTE(review): {@code coordinates} is not declared in this method. It is presumably a field
 * holding one {x, y} row per instance of {@code ds}, in the same order; confirm.
 *
 * @param ds the instances to plot; the class attribute, if set, drives the series split
 * @param type MDS variant, used only for the chart title here
 * @return a {@link JXPanel} (BorderLayout) with the chart panel, preferred size 1200x900, centred
 * @throws Exception as declared; NOTE(review): no callee visible here obviously throws
 */
public JXPanel buildMDSViewFromDataSet(Instances ds, MDSTypeEnum type) throws Exception {

    final XYSeriesCollection dataset = new XYSeriesCollection();

    final JFreeChart chart = ChartFactory.createScatterPlot("", // title
            "X", "Y", // axis labels
            dataset, // dataset
            PlotOrientation.VERTICAL, true, // legend? yes
            true, // tooltips? yes
            false // URLs? no
    );

    final XYPlot xyPlot = (XYPlot) chart.getPlot();

    chart.setTitle(type.name() + " MDS");

    Attribute clsAttribute = null;
    int nbClass = 1;
    if (ds.classIndex() != -1) {
        clsAttribute = ds.classAttribute();
        nbClass = clsAttribute.numValues();
    }

    // One series per class value, named after the value; otherwise a single dummy series.
    final List<XYSeries> lseries = new ArrayList<XYSeries>();
    if (nbClass <= 1) {
        lseries.add(new XYSeries("Serie #1", false));
    } else {
        for (int i = 0; i < nbClass; i++) {
            lseries.add(new XYSeries(clsAttribute.value(i), false));
        }
    }
    // No-op on the freshly created collection above; kept as-is.
    dataset.removeAllSeries();

    // One empty copy of ds per series. Instances are appended in point order, so
    // filteredInstances.get(s).instance(k) backs item k of series s.
    final List<Instances> filteredInstances = new ArrayList<Instances>();
    for (int i = 0; i < lseries.size(); i++) {
        filteredInstances.add(new Instances(ds, 0));
    }

    for (int i = 0; i < ds.numInstances(); i++) {
        final Instance oInst = ds.instance(i);
        // Series = class value index when a class is set, else the single dummy series.
        int indexOfSerie = 0;
        if (oInst.classIndex() != -1) {
            indexOfSerie = (int) oInst.value(oInst.classAttribute());
        }
        lseries.get(indexOfSerie).add(coordinates[i][0], coordinates[i][1]);
        filteredInstances.get(indexOfSerie).add(oInst);
    }

    // NOTE(review): filled below but never read in this method.
    final List<Paint> colors = new ArrayList<Paint>();

    for (final XYSeries series : lseries) {
        dataset.addSeries(series);
    }

    final XYToolTipGenerator gen = new XYToolTipGenerator() {
        @Override
        public String generateToolTip(XYDataset dataset, int series, int item) {
            return InstanceFormatter.htmlFormat(filteredInstances.get(series).instance(item), true);
        }
    };

    final Shape shape = new Ellipse2D.Float(0f, 0f, 5f, 5f);

    // Assumes the scatter plot's renderer is an XYLineAndShapeRenderer (cast fails otherwise).
    ((XYLineAndShapeRenderer) xyPlot.getRenderer()).setUseOutlinePaint(true);

    // NOTE(review): bounded by nbClass, not lseries.size(). If the class attribute reports 0
    // values, the single dummy series gets no tooltip generator; confirm intent.
    for (int p = 0; p < nbClass; p++) {
        xyPlot.getRenderer().setSeriesToolTipGenerator(p, gen);
        ((XYLineAndShapeRenderer) xyPlot.getRenderer()).setLegendShape(p, shape);
        xyPlot.getRenderer().setSeriesOutlinePaint(p, Color.BLACK);
    }

    // NOTE(review): the result is unused here, but getItemPaint may auto-assign series paints as a
    // side effect in JFreeChart; verify before removing this loop.
    for (int ii = 0; ii < nbClass; ii++) {
        colors.add(xyPlot.getRenderer().getItemPaint(ii, 0));
    }

    final ChartPanel chartPanel = new ChartPanel(chart);
    chartPanel.setMouseWheelEnabled(true);
    chartPanel.setPreferredSize(new Dimension(1200, 900));
    chartPanel.setBorder(new TitledBorder("MDS Projection"));
    chartPanel.setBackground(Color.WHITE);

    final JXPanel allPanel = new JXPanel();
    allPanel.setLayout(new BorderLayout());
    allPanel.add(chartPanel, BorderLayout.CENTER);

    return allPanel;
}

From source file:machinelearninglabs.OENaiveBayesClassifier.java

/**
 * Computes the naive Bayes posterior distribution over the class values for {@code instance}.
 *
 * <p>For each class {@code c}, the unnormalised joint probability is
 * {@code classProbs[c]} multiplied by {@code conditionalProbabilities[att][c][value(att)]} for
 * every attribute {@code att} in {@code 0 .. numAttributes() - 2}. The joint probabilities are then
 * divided by their sum. If every joint probability is 0, the sum is 0 and the result contains NaN.
 *
 * @param instance the instance to classify; its class attribute determines the result length
 * @return one probability per class value, in class-value order
 * @throws Exception as declared by the classifier interface
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {
    // One slot per possible value of the class attribute.
    final int numClasses = instance.attribute(instance.classIndex()).numValues();
    final double[] jointProbabilities = new double[numClasses];

    // Un-normalised joint probability P(c) * prod_att P(x_att | c) for each class.
    // NOTE(review): stopping at numAttributes() - 1 assumes the class is the last attribute, and
    // a missing value ((int) NaN == 0) is read as value index 0.
    for (int cls = 0; cls < numClasses; cls++) {
        double p = classProbs[cls];
        for (int att = 0; att < instance.numAttributes() - 1; att++) {
            final int value = (int) instance.value(att);
            p *= conditionalProbabilities[att][cls][value];
        }
        jointProbabilities[cls] = p;
    }

    // The normalising constant is the same for every class, so sum it once. The original code
    // re-summed the whole array for each class (O(k^2)). The summation order is unchanged.
    double denominator = 0;
    for (final double joint : jointProbabilities) {
        denominator += joint;
    }

    final double[] result = new double[numClasses];
    for (int i = 0; i < numClasses; i++) {
        result[i] = jointProbabilities[i] / denominator;
    }
    return result;
}

From source file:machinelearninglabs.OENaiveBayesClassifier.java

/***************************** HELPER METHODS  ************************************/

/**
 * Computes per-class counts and prior probabilities for {@code data}.
 *
 * <p>The results are stored in the {@code classCount} and {@code classProbs} fields. The method
 * then prints the counts, the first attribute value of the first instance, and the priors.
 * {@code data} must be non-empty, because the class count is read from its first instance.
 *
 * @param data training instances with the class index set
 */
public void classDistribution(Instances data) {
    final Instance first = data.firstInstance();
    final int numClasses = first.numClasses();
    classCount = new int[numClasses];
    classProbs = new double[numClasses];

    // Tally how many instances carry each class value.
    for (int row = 0; row < data.numInstances(); row++) {
        final Instance current = data.instance(row);
        classCount[(int) current.value(current.classIndex())]++;
    }

    // Prior probability of each class = its share of all instances.
    final double total = data.numInstances();
    for (int c = 0; c < numClasses; c++) {
        classProbs[c] = classCount[c] / total;
    }

    // Debug output.
    printIntArray(classCount);
    System.out.println(first.value(0));
    printDoubleArray(classProbs);
}

From source file:machinelearninglabs.OENaiveBayesClassifier.java

/**
 * Builds a contingency table of class value against the value of attribute {@code att}.
 *
 * @param data instances with the class index set; must be non-empty (the number of attribute
 *     values is read from the first instance)
 * @param att index of the attribute to tabulate
 * @return {@code counts[c][v]} = number of instances with class index {@code c} and attribute
 *     value index {@code v}
 */
public int[][] attributeCounts(Instances data, int att) {
    final int numValues = data.firstInstance().attribute(att).numValues();
    final int[][] counts = new int[data.numClasses()][numValues];

    for (int row = 0; row < data.numInstances(); row++) {
        final Instance inst = data.instance(row);
        final int cls = (int) inst.value(inst.classIndex());
        final int val = (int) inst.value(att);
        counts[cls][val]++;
    }
    return counts;
}

From source file:machinelearninglabs.OENaiveBayesClassifier.java

/**
 * Computes the conditional probabilities {@code P(attribute att = v | class = c)} from
 * {@code data}.
 *
 * <p>Each count is divided by {@code classCount[c]}, so {@link #classDistribution(Instances)} must
 * have been called first on the same data. A class with no training instances gets a row of 0.0.
 * Previously such a row was 0/0 = NaN, which made every posterior in
 * {@code distributionForInstance} NaN.
 *
 * @param data instances with the class index set; must be non-empty (the number of attribute
 *     values is read from the first instance)
 * @param att index of the attribute to model
 * @return {@code result[c][v]} = fraction of class-{@code c} instances whose attribute {@code att}
 *     has value index {@code v}
 */
public double[][] attributeProbs(Instances data, int att) {
    final int numberOfPossibleValuesForAttribute = data.firstInstance().attribute(att).numValues();
    final double[][] result = new double[data.numClasses()][numberOfPossibleValuesForAttribute];

    // Count co-occurrences of (class value, attribute value).
    for (Instance eachInstance : data) {
        final double classValue = eachInstance.value(eachInstance.classIndex());
        result[(int) classValue][(int) eachInstance.value(att)]++;
    }

    // Turn counts into probabilities given the class. Guard empty classes, which would otherwise
    // yield 0/0 = NaN and poison the naive Bayes product.
    for (int i = 0; i < result.length; i++) {
        final int count = classCount[i];
        for (int j = 0; j < result[i].length; j++) {
            result[i][j] = count == 0 ? 0.0 : result[i][j] / count;
        }
    }
    return result;
}

From source file:machinelearningproject.DecisionTree.java

/**
 * Classifies {@code instance} by walking the decision tree and mapping the returned label to its
 * index among the class attribute's values.
 *
 * @param instance the instance to classify
 * @return the index of the predicted label (the highest index if several match), or {@code 0.0}
 *     when the label is not one of the class values
 * @throws Exception as declared by the classifier interface
 */
@Override
public double classifyInstance(Instance instance) throws Exception {
    final String predictedLabel = mainTree.traverseTree(instance);

    // Scan from the top so the first hit equals the last match of a forward scan.
    for (int idx = instance.numClasses() - 1; idx >= 0; idx--) {
        if (predictedLabel.equals(instance.attribute(instance.classIndex()).value(idx))) {
            return (double) idx;
        }
    }
    return 0.0;
}

From source file:machinelearningproject.RandomForest.java

/**
 * Classifies {@code instance} by majority vote over the trees in {@code dtrees}.
 *
 * <p>Each tree's {@code traverseTree(instance)} label counts as one vote. The label with the most
 * votes wins. On a tie, the label reached first while iterating the vote map wins (HashMap order,
 * so effectively arbitrary). The winning label is then mapped to its index among the class
 * attribute's values.
 *
 * @param instance the instance to classify
 * @return the index of the winning label among the class values, or {@code 0.0} when there are no
 *     trees or the winning label is not one of those values
 * @throws Exception as declared by the classifier interface
 */
@Override
public double classifyInstance(Instance instance) throws Exception {
    // label -> number of trees that predicted it
    HashMap<String, Integer> classMap = new HashMap<>();
    for (int i = 0; i < dtrees.size(); i++) {
        String key = dtrees.get(i).traverseTree(instance);
        Integer votes = classMap.get(key);
        classMap.put(key, votes == null ? 1 : votes + 1);
    }

    // Most-voted label; the strict '>' keeps the first label reached on ties.
    String modeClass = "";
    int count = 0;
    for (String key : classMap.keySet()) {
        int votes = classMap.get(key);
        if (votes > count) {
            modeClass = key;
            count = votes;
        }
    }

    // Map the label back to its class-value index (the last match wins, 0.0 if absent).
    double result = 0.0;
    for (int i = 0; i < instance.numClasses(); i++) {
        if (modeClass.equals(instance.attribute(instance.classIndex()).value(i))) {
            result = (double) i;
        }
    }

    return result;
}

From source file:meka.classifiers.multilabel.BRq.java

License:Open Source License

/**
 * Predicts every label of a multi-label instance with its own per-label classifier.
 *
 * <p>The instance's class index is used as the number of labels {@code L}.
 * {@code convertInstance} produces one input per label, and classifier {@code j} predicts label
 * {@code j}.
 *
 * @param instance the multi-label instance
 * @return array of length {@code L} whose entry {@code j} is
 *     {@code m_MultiClassifiers[j].classifyInstance(...)}
 * @throws Exception propagated from the per-label classifiers or the conversion
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {
    final int numLabels = instance.classIndex();

    final double[] predictions = new double[numLabels];
    final Instance[] perLabel = convertInstance(instance, numLabels);

    int j = 0;
    while (j < numLabels) {
        predictions[j] = m_MultiClassifiers[j].classifyInstance(perLabel[j]);
        j++;
    }
    return predictions;
}

From source file:meka.classifiers.multilabel.cc.CNode.java

License:Open Source License

/**
 * Transform - turn [y1,y2,y3,x1,x2] into [y1,y2,x1,x2].
 * @return transformed Instance/*from  w  w w. ja  v a  2s  .c  o  m*/
 */
public Instance transform(Instance x, double ypred[]) throws Exception {
    x = (Instance) x.copy();
    int L = x.classIndex();
    int L_c = (paY.length + 1);
    x.setDataset(null);
    for (int j = 0; j < (L - L_c); j++) {
        x.deleteAttributeAt(0);
    }
    for (int pa : paY) {
        //System.out.println("x_["+map[pa]+"] <- "+ypred[pa]);
        x.setValue(map[pa], ypred[pa]);
    }
    x.setDataset(T);
    x.setClassMissing();
    return x;
}