regression.gui.RegressionChart.java Source code

Java tutorial

Introduction

Here is the source code for regression.gui.RegressionChart.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package regression.gui;

import java.awt.Color;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.util.ShapeUtilities;
import regression.Function;
import regression.Point;

/**
 *
 * @author Cyga
 */
public class RegressionChart extends javax.swing.JFrame {

    /** Y value that separates class 1 from class 0 in the logistic chart. */
    private static final double LOGISTIC_THRESHOLD = 0.5;

    /**
     * Builds a window showing a linear-regression chart: the regression line plus the input
     * points, split into two series by whether they lie above the line.
     *
     * <p>The chart panel is installed as the content pane; this constructor does not call
     * {@code pack()} or {@code setVisible()}.
     *
     * @param function the fitted regression function
     * @param points   the points to plot; the list is not modified
     */
    public RegressionChart(Function function, List<Point> points) {
        installChart(createDataset(function, points));
    }

    /**
     * Builds a window showing a logistic-regression chart.
     *
     * <p>The chart panel is installed as the content pane; this constructor does not call
     * {@code pack()} or {@code setVisible()}.
     *
     * @param countedPoints   model outputs paired index-by-index with {@code points}; a point goes
     *                        to class 1 when its counterpart's Y is above 0.5. Must have at least
     *                        as many elements as {@code points}.
     * @param points          the points to plot; the list is not modified
     * @param finalProbPoints points of the regression curve to draw
     */
    public RegressionChart(List<Point> countedPoints, List<Point> points, List<Point> finalProbPoints) {
        installChart(createLogisticDataset(countedPoints, points, finalProbPoints));
    }

    /**
     * Wraps the dataset in a chart panel and sets that panel as this frame's content pane.
     */
    private void installChart(XYDataset dataset) {
        final ChartPanel chartPanel = new ChartPanel(createChart(dataset));
        chartPanel.setPreferredSize(new java.awt.Dimension(500, 270));
        setContentPane(chartPanel);
    }

    /**
     * Builds the dataset for the linear-regression chart.
     *
     * <p>Series order matters: {@link #createChart} draws series 0 and 2 as point markers only and
     * series 1 as a line without markers.
     *
     * @param function the regression function
     * @param points   the points to classify and plot; the list is not modified
     * @return series 0: points above the line, series 1: the sampled regression line,
     *         series 2: the remaining points
     */
    private XYDataset createDataset(Function function, List<Point> points) {
        final XYSeries regressionLine = new XYSeries("Funkcja regresji");
        for (Point point : createFunction(function, points)) {
            regressionLine.add(point.getX(), point.getY());
        }

        final XYSeries classA = new XYSeries("Punkty Klasy A");
        final XYSeries classB = new XYSeries("Punkty Klasy B");
        for (Point point : points) {
            XYSeries target = checkIfPointAboveLine(function, point) ? classA : classB;
            target.add(point.getX(), point.getY());
        }

        final XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(classA);
        dataset.addSeries(regressionLine);
        dataset.addSeries(classB);
        return dataset;
    }

    /**
     * Builds the dataset for the logistic-regression chart.
     *
     * <p>Series order matters: {@link #createChart} draws series 0 and 2 as point markers only and
     * series 1 and 3 as lines without markers.
     *
     * @param countedPoints   model outputs paired index-by-index with {@code points}
     * @param points          the points to classify and plot; the list is not modified
     * @param finalProbPoints points of the regression curve
     * @return series 0: class-1 points, series 1: the regression curve, series 2: class-0 points,
     *         series 3: the horizontal line y = 0.5
     */
    private XYDataset createLogisticDataset(List<Point> countedPoints, List<Point> points,
            List<Point> finalProbPoints) {
        final XYSeries regressionCurve = new XYSeries("Funkcja regresji");
        for (Point point : finalProbPoints) {
            regressionCurve.add(point.getX(), point.getY());
        }

        final XYSeries classOne = new XYSeries("Punkty Klasy 1");
        final XYSeries classZero = new XYSeries("Punkty Klasy 0");
        // Track the X extent while classifying, instead of sorting the caller's list: sorting
        // it in place would break its index pairing with countedPoints for the caller.
        double lowX = 0.0;
        double highX = 0.0;
        int index = 0;
        for (Point point : points) {
            boolean isClassOne = countedPoints.get(index).getY() > LOGISTIC_THRESHOLD;
            (isClassOne ? classOne : classZero).add(point.getX(), point.getY());
            lowX = Math.min(lowX, point.getX());
            highX = Math.max(highX, point.getX());
            index++;
        }

        // Horizontal decision line at y = 0.5. It spans from the origin to the farthest point
        // on each side, so it also covers points with negative X.
        final XYSeries threshold = new XYSeries("Funkcja podzialu y=0,5");
        if (!points.isEmpty()) {
            threshold.add(lowX, LOGISTIC_THRESHOLD);
            threshold.add(highX, LOGISTIC_THRESHOLD);
        }

        final XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(classOne);
        dataset.addSeries(regressionCurve);
        dataset.addSeries(classZero);
        dataset.addSeries(threshold);
        return dataset;
    }

    /**
     * Creates a scatter chart for the dataset.
     *
     * <p>Series 0 and 2 are drawn as cross markers without connecting lines. Series 1 and 3 are
     * drawn as lines without markers. Range-axis ticks are restricted to integers.
     *
     * @param dataset the data for the chart
     * @return the configured chart
     */
    private JFreeChart createChart(final XYDataset dataset) {
        final JFreeChart chart = ChartFactory.createScatterPlot("Wykres funkcji regresji", // chart title
                "X", // x axis label
                "Y", // y axis label
                dataset, // data
                PlotOrientation.VERTICAL, true, // include legend
                true, // tooltips
                false // urls
        );

        chart.setBackgroundPaint(Color.white);

        final XYPlot plot = chart.getXYPlot();
        plot.setBackgroundPaint(Color.lightGray);
        plot.setDomainGridlinePaint(Color.white);
        plot.setRangeGridlinePaint(Color.white);

        final XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer();
        // Point series: markers only.
        renderer.setSeriesShape(0, ShapeUtilities.createRegularCross(3, 3));
        renderer.setSeriesShape(2, ShapeUtilities.createRegularCross(3, 3));
        renderer.setSeriesLinesVisible(0, false);
        renderer.setSeriesLinesVisible(2, false);
        // Line series (regression line/curve and, for the logistic chart, the y = 0.5 line):
        // lines only.
        renderer.setSeriesShapesVisible(1, false);
        renderer.setSeriesShapesVisible(3, false);
        plot.setRenderer(renderer);

        final NumberAxis rangeAxis = (NumberAxis) plot.getRangeAxis();
        rangeAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());

        return chart;
    }

    /**
     * Samples {@code function} at the X coordinates of {@code testingPoints} so it can be drawn
     * as a line.
     *
     * <p>The X values are taken in ascending order, and the largest one is replaced by 1.5 times
     * itself so that the line extends past the last point. For every X, one point is produced per
     * coefficient in {@code function.getFactor()}, with {@code y = x * factor + freeFactor}.
     *
     * @param function      the regression function to sample
     * @param testingPoints points whose X coordinates are sampled; the list is not modified
     * @return the sampled points in ascending X order, or an empty list when
     *         {@code testingPoints} is empty
     */
    public List<Point> createFunction(Function function, List<Point> testingPoints) {
        List<Point> points = new ArrayList<>();
        if (testingPoints.isEmpty()) {
            return points;
        }

        // Sort a copy so that the caller's list keeps its order. Double.compare replaces the old
        // "(int) (a - b)" comparison, which truncated: X values less than 1 apart were treated
        // as equal, and the Comparator contract was violated.
        List<Point> sorted = new ArrayList<>(testingPoints);
        Collections.sort(sorted, new Comparator<Point>() {
            @Override
            public int compare(Point o1, Point o2) {
                return Double.compare(o1.getX(), o2.getX());
            }
        });

        double[] xValues = new double[sorted.size()];
        for (int i = 0; i < xValues.length; i++) {
            xValues[i] = sorted.get(i).getX();
        }
        double lastX = xValues[xValues.length - 1];
        xValues[xValues.length - 1] = lastX + (lastX / 2);

        for (double x : xValues) {
            for (Double xFactor : function.getFactor()) {
                double y = x * xFactor + function.getFreeFactor().doubleValue();
                points.add(new Point(x, y));
            }
        }
        return points;
    }

    /**
     * Returns whether {@code point} lies strictly above the line
     * {@code y = factor[0] * x + freeFactor}.
     */
    private boolean checkIfPointAboveLine(Function function, Point point) {
        return point.getY() > ((point.getX() * function.getFactor().get(0)) + function.getFreeFactor());
    }

    /**
     * Returns whether {@code pointToCalsify} lies strictly above the logarithmic curve
     * {@code y = log_a(x)}, where {@code a = x0^(1 / y0)} is derived from
     * {@code pointFromfunction = (x0, y0)}.
     *
     * <p>Not used by this class at present.
     */
    private boolean checkIfPointAboveLogistic(Point pointFromfunction, Point pointToCalsify) {
        double base = Math.pow(pointFromfunction.getX(), 1 / pointFromfunction.getY());
        return pointToCalsify.getY() > logOfX(base, pointToCalsify.getX());
    }

    /**
     * Returns the logarithm of {@code num} in base {@code x}.
     */
    private double logOfX(Double x, Double num) {
        return Math.log(num) / Math.log(x);
    }

}