utils.ChartUtils.java Source code

Java tutorial

Introduction

Here is the source code for utils.ChartUtils.java

Source

/*
 * This file is part of the MLDA.
 *
 * (c)  Jose Maria Moyano Murillo
 *      Eva Lucrecia Gibaja Galindo
 *      Sebastian Ventura Soto <sventura@uco.es>
 *
 * For the full copyright and license information, please view the LICENSE
 * file that was distributed with this source code.
 */

package utils;

import java.awt.Color;
import java.awt.Font;
import java.util.ArrayList;
import java.util.HashMap;
import mulan.data.LabelsPair;
import mulan.data.MultiLabelInstances;
import mulan.data.Statistics;
import mulan.data.UnconditionalChiSquareIdentifier;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.annotations.XYTextAnnotation;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.Marker;
import org.jfree.chart.plot.ValueMarker;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.category.DefaultCategoryDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import static utils.Utils.maxKey;
import weka.core.Instance;
import weka.core.Instances;

/**
 * This class implements some utilities for charts
 * 
 * @author Jose Maria Moyano Murillo
 */
public class ChartUtils {

    /**
     * Update values of a bar chart
     * <p>
     * The labels are passed through {@code MetricUtils.sortByFrequency} and
     * one bar is created per label whose value is the label's number of
     * appearances divided by {@code nInstances}. The resulting dataset
     * replaces the dataset of the plot, and a red range marker labelled with
     * the mean of all bar values is added to the plot.
     * 
     * @param labelsByFreq Labels ordered by frequency
     * @param nInstances Number of instances, used as divisor for every bar
     * @param cp CategoryPlot that receives the new dataset and the mean marker
     */
    public static void updateValuesBarChart(ImbalancedFeature[] labelsByFreq, int nInstances, CategoryPlot cp) {
        ImbalancedFeature[] sortedLabels = MetricUtils.sortByFrequency(labelsByFreq);
        DefaultCategoryDataset dataset = new DefaultCategoryDataset();

        double total = 0.0;
        for (ImbalancedFeature label : sortedLabels) {
            // "* 1.0" forces floating-point division
            double frequency = label.getAppearances() * 1.0 / nInstances;
            total += frequency;

            // Every bar is its own row (label name) within a single blank column
            dataset.setValue(frequency, label.getName(), " ");
        }

        cp.setDataset(dataset);

        // Horizontal red line at the mean frequency; the leading blanks of
        // the label shift the text to the right of the marker's anchor
        double mean = total / sortedLabels.length;
        Marker meanMarker = new ValueMarker(mean);
        meanMarker.setPaint(Color.red);
        meanMarker.setLabelFont(new Font("SansSerif", Font.BOLD, 12));
        meanMarker.setLabel("                        Mean: " + MetricUtils.truncateValue(mean, 3));
        cp.addRangeMarker(meanMarker);
    }

    /**
     * Update IR bar chart
     * <p>
     * The labels are passed through {@code MetricUtils.sortByFrequency} and
     * one bar per label is created, walking the arrays from the last index to
     * the first and taking the bar value from {@code IR} at the same index.
     * Two range markers are added to the plot: a red one at the mean of the
     * values and a black one at 1.5. The black marker only gets a text label
     * when the mean is below 1.3 or above 1.7.
     * 
     * @param labelsByFrequency Labels ordered by frequency
     * @param IR Imbalance Ratio values
     * @param cp CategoryPlot that receives the new dataset and the markers
     */
    public static void updateIRBarChart(ImbalancedFeature[] labelsByFrequency, double[] IR, CategoryPlot cp) {
        ImbalancedFeature[] sortedLabels = MetricUtils.sortByFrequency(labelsByFrequency);
        DefaultCategoryDataset dataset = new DefaultCategoryDataset();

        // NOTE(review): IR is indexed with the position of the label AFTER
        // sorting, so it presumably has to be ordered like the sorted labels
        // already - confirm against the callers.
        double total = 0.0;
        int idx = sortedLabels.length - 1;
        while (idx >= 0) {
            double ratio = IR[idx];
            total += ratio;
            dataset.setValue(ratio, sortedLabels[idx].getName(), " ");
            idx--;
        }

        cp.setDataset(dataset);

        Font markerFont = new Font("SansSerif", Font.BOLD, 12);

        // Red line at the mean of the plotted values
        double mean = total / sortedLabels.length;
        Marker meanMarker = new ValueMarker(mean);
        meanMarker.setPaint(Color.red);
        meanMarker.setLabelFont(markerFont);
        meanMarker.setLabel("                        Mean: " + MetricUtils.truncateValue(mean, 3));
        cp.addRangeMarker(meanMarker);

        // Black line at the fixed value 1.5
        Marker limitMarker = new ValueMarker(1.5);
        limitMarker.setPaint(Color.black);
        limitMarker.setLabelFont(markerFont);

        // The limit line is only labelled when the mean line is not close to it
        boolean meanAwayFromLimit = (mean < 1.3) || (mean > 1.7);
        if (meanAwayFromLimit) {
            limitMarker.setLabel("                                                Imbalance limit (IR=1.5)");
        }
        cp.addRangeMarker(limitMarker);
    }

    /**
     * Update XY chart so that it shows a horizontal box plot of the values.
     * <p>
     * The primary dataset of the plot is replaced by black line segments:
     * short vertical caps at the minimum and maximum, a box between Q1 and Q3
     * (drawn from y = 0.1 to y = 0.9), a vertical line at the median and two
     * whiskers at y = 0.5. Text annotations ("Min", "Q1", "Median", "Q3",
     * "Max") are added below the corresponding lines. A second dataset
     * (index 1) with its own shapes-only renderer shows every value of the
     * array as a light gray point; the points cycle through four y positions
     * around 0.5.
     * <p>
     * Annotations are only added here, never removed. NOTE(review):
     * presumably the caller clears old annotations before refreshing the
     * chart - confirm.
     * 
     * @param plot ChartPanel Plot; its chart must contain an XY plot
     * @param sortedArray Sorted array of values; the first and last elements
     *            are used as minimum and maximum. When it is {@code null} or
     *            empty there is nothing to draw and the chart is left
     *            untouched.
     */
    public static void updateXYChart(ChartPanel plot, double[] sortedArray) {
        // Without values there is no minimum/maximum to read
        if (sortedArray == null || sortedArray.length == 0) {
            return;
        }

        XYPlot xyplot = plot.getChart().getXYPlot();

        double min = sortedArray[0];
        double max = sortedArray[sortedArray.length - 1];
        double median = Utils.getMedian(sortedArray);
        double q1 = Utils.getQ1(sortedArray);
        double q3 = Utils.getQ3(sortedArray);

        addBoxPlotAnnotation(xyplot, "Min", min, 0.40);
        addBoxPlotAnnotation(xyplot, "Q1", q1, 0.08);
        addBoxPlotAnnotation(xyplot, "Median", median, 0.04);
        addBoxPlotAnnotation(xyplot, "Q3", q3, 0.08);
        addBoxPlotAnnotation(xyplot, "Max", max, 0.4);

        // Series keys and insertion order are kept exactly as before
        XYSeriesCollection boxSeries = new XYSeriesCollection();
        boxSeries.addSeries(lineSeries("0", min, 0.45, min, 0.55)); // min vertical
        boxSeries.addSeries(lineSeries("1", min, 0.5, q1, 0.5)); // min-q1 horizontal
        boxSeries.addSeries(lineSeries("2", q1, 0.1, q1, 0.9)); // q1 vertical
        boxSeries.addSeries(lineSeries("3", q1, 0.9, q3, 0.9)); // q1-q3 horizontal sup
        boxSeries.addSeries(lineSeries("4", q1, 0.1, q3, 0.1)); // q1-q3 horizontal inf
        boxSeries.addSeries(lineSeries("5", q3, 0.1, q3, 0.9)); // q3 vertical
        boxSeries.addSeries(lineSeries("6", q3, 0.5, max, 0.5)); // q3-max horizontal
        boxSeries.addSeries(lineSeries("7", max, 0.45, max, 0.55)); // max vertical
        boxSeries.addSeries(pointSeries("15", min, 0.5)); // min-lowlimit
        boxSeries.addSeries(pointSeries("16", max, 0.5)); // max-toplimit
        boxSeries.addSeries(lineSeries("11", median, 0.1, median, 0.9)); // median

        // Paint exactly the series that exist in the box dataset
        for (int series = 0; series < boxSeries.getSeriesCount(); series++) {
            xyplot.getRenderer().setSeriesPaint(series, Color.black);
        }

        xyplot.setDataset(boxSeries);

        // Second dataset: the individual values, spread over four heights so
        // that neighbouring points are easier to tell apart
        double[] yValue = { 0.47, 0.49, 0.51, 0.53 };
        XYSeries valuePoints = new XYSeries("21");
        for (int i = 0; i < sortedArray.length; i++) {
            valuePoints.add(sortedArray[i], yValue[i % yValue.length]);
        }

        XYSeriesCollection pointDataset = new XYSeriesCollection();
        pointDataset.addSeries(valuePoints);

        // Shapes only, no connecting lines
        XYLineAndShapeRenderer pointRenderer = new XYLineAndShapeRenderer(false, true);
        pointRenderer.setSeriesPaint(0, Color.lightGray);

        xyplot.setDataset(1, pointDataset);
        xyplot.setRenderer(1, pointRenderer);
    }

    /**
     * Adds a plain 11pt SansSerif text annotation to the plot.
     */
    private static void addBoxPlotAnnotation(XYPlot xyplot, String text, double x, double y) {
        XYTextAnnotation annotation = new XYTextAnnotation(text, x, y);
        annotation.setFont(new Font("SansSerif", Font.PLAIN, 11));
        xyplot.addAnnotation(annotation);
    }

    /**
     * Creates a series with the two given points, added in the given order.
     */
    private static XYSeries lineSeries(String key, double x1, double y1, double x2, double y2) {
        XYSeries series = new XYSeries(key);
        series.add(x1, y1);
        series.add(x2, y2);
        return series;
    }

    /**
     * Creates a series that holds a single point.
     */
    private static XYSeries pointSeries(String key, double x, double y) {
        XYSeries series = new XYSeries(key);
        series.add(x, y);
        return series;
    }

    /**
     * Update line chart
     * <p>
     * One category is created for every integer from 0 up to the value
     * returned by {@code maxKey} for the map. The value of a category is the
     * count stored under that integer divided by {@code nInstances}; integers
     * without an entry in the map are plotted as 0. Tick labels of the domain
     * axis are hidden when that maximum is greater than 30 and shown
     * otherwise.
     * 
     * @param nInstances Number of instances, used as divisor
     * @param cp CategoryPlot that receives the new dataset
     * @param labelsetsByFrequency Labelsets ordered by frequency
     */
    public static void updateLineChart(int nInstances, CategoryPlot cp,
            HashMap<Integer, Integer> labelsetsByFrequency) {
        int largestKey = maxKey(labelsetsByFrequency);
        DefaultCategoryDataset dataset = new DefaultCategoryDataset();

        for (int key = 0; key <= largestKey; key++) {
            // Missing keys count as zero occurrences
            Integer stored = labelsetsByFrequency.get(key);
            int count = (stored == null) ? 0 : stored;

            double ratio = count * 1.0 / nInstances;

            // A ratio that compares equal to zero is written as plain 0
            double plotted = (ratio == 0.0) ? 0 : ratio;
            dataset.setValue(plotted, "Label-Combination: ", Integer.toString(key));
        }
        cp.setDataset(dataset);

        // Labels are only shown while the largest key does not exceed 30
        cp.getDomainAxis().setTickLabelsVisible(largestKey <= 30);
    }

    /**
     * Get co-occurrences matrix
     * <p>
     * Delegates to {@link #calculateCoocurrences(MultiLabelInstances)} and
     * returns its result unchanged.
     * 
     * @param dataset Multi-label dataset
     * @return Matrix of doubles with co-occurrence values
     */
    public static double[][] getCoocurrences(MultiLabelInstances dataset) {
        return calculateCoocurrences(dataset);
    }

    /**
     * Get Chi and Phi coefficients
     * <p>
     * Both coefficients share one square matrix. For every label pair
     * {@code (a, b)} returned by the chi-square dependence identifier, the
     * chi score is stored at {@code [a][b]} and the phi value read from
     * {@code phiMatrix[a][b]} is stored at {@code [b][a]}. Cells that belong
     * to no returned pair stay 0.
     * <p>
     * Any exception raised during the calculation is printed to the standard
     * error stream and the matrix is returned as filled so far.
     * 
     * @param dataset Multi-label dataset
     * @return Matrix of doubles with Chi and Phi values
     */
    public static double[][] getChiPhiCoefficients(MultiLabelInstances dataset) {
        int nLabels = dataset.getNumLabels();
        double[][] coefficients = new double[nLabels][nLabels];

        try {
            LabelsPair[] pairs = new UnconditionalChiSquareIdentifier().calculateDependence(dataset);
            double[][] phiMatrix = new Statistics().calculatePhi(dataset);

            for (int p = 0; p < pairs.length; p++) {
                LabelsPair current = pairs[p];

                // Read everything first, write afterwards
                double chi = current.getScore();
                int first = current.getPair()[0];
                int second = current.getPair()[1];
                double phi = phiMatrix[first][second];

                coefficients[first][second] = chi;
                coefficients[second][first] = phi;
            }
        } catch (Exception e) {
            // Best effort: report the problem and return what was computed
            e.printStackTrace();
        }

        return coefficients;
    }

    /**
     * Get vertices of co-occurrence graph
     * <p>
     * Walks the pairs in list order and, for every pair that contains
     * {@code labelName}, collects the name of the other attribute. The first
     * attribute name of a pair is checked before the second one.
     * 
     * @param labelName Label name
     * @param list List of attribute pairs
     * @return List of vertices, in the order of the pairs they were found in
     */
    public static ArrayList<String> getVertices(String labelName, ArrayList<AttributesPair> list) {
        ArrayList<String> neighbours = new ArrayList<>();

        for (AttributesPair pair : list) {
            if (pair.getAttributeName1().equals(labelName)) {
                neighbours.add(pair.getAttributeName2());
                continue;
            }

            if (pair.getAttributeName2().equals(labelName)) {
                neighbours.add(pair.getAttributeName1());
            }
        }

        return neighbours;
    }

    /**
     * Get border strength for a node in a co-occurrence graph
     * <p>
     * The range from {@code min} to {@code max} is divided into {@code n}
     * intervals of equal width {@code w}. The strength is the number of
     * interval lower bounds ({@code min}, {@code min + w}, ...,
     * {@code min + (n - 1) * w}) that are less than or equal to
     * {@code edgeValue}, so the result always lies between 0 and {@code n}.
     * 
     * @param min Minimum value
     * @param max Maximum value
     * @param n Number of intervals; a value of 0 or less yields 0
     * @param edgeValue Edge value
     * @return Strength; 0 when {@code edgeValue} is below {@code min} or when
     *         {@code max} is not greater than {@code min}
     */
    public static int getBorderStrength(int min, int max, int n, double edgeValue) {
        // No intervals or an empty range: there is no bound to count
        if (n <= 0 || max <= min) {
            return 0;
        }

        // Widen before subtracting so that max - min cannot overflow int
        double interval = ((double) max - min) / n;

        int strength = 0;

        // Each bound is derived from its index. Adding the interval
        // repeatedly accumulates rounding error and can produce an extra
        // bound just below max (0.1 added ten times is less than 1.0).
        for (int step = 0; step < n; step++) {
            double lowerBound = min + step * interval;
            if (edgeValue < lowerBound) {
                break;
            }
            strength++;
        }
        return strength;
    }

    /**
     * Calculate co-occurrences of labels
     * <p>
     * For every instance and every pair of labels {@code (i, j)} with
     * {@code i < j}, the cell {@code [i][j]} is incremented when both label
     * attributes have the value 1.0 in that instance. Only cells above the
     * main diagonal are written; the diagonal and the cells below it stay 0.
     * 
     * @param mldata Multi-label dataset
     * @return Matrix of double with co-occurrence values
     */
    public static double[][] calculateCoocurrences(MultiLabelInstances mldata) {
        Instances instances = mldata.getDataSet();
        int[] labelIndices = mldata.getLabelIndices();
        int nLabels = mldata.getNumLabels();

        double[][] counts = new double[nLabels][nLabels];

        int nInstances = instances.numInstances();
        for (int k = 0; k < nInstances; k++) {
            Instance current = instances.instance(k);

            for (int i = 0; i < nLabels; i++) {
                for (int j = i + 1; j < nLabels; j++) {
                    // The second label is only inspected when the first one is set
                    boolean bothPresent = current.value(labelIndices[i]) == 1.0
                            && current.value(labelIndices[j]) == 1.0;
                    if (bothPresent) {
                        counts[i][j] += 1;
                    }
                }
            }
        }

        return counts;
    }
}