Java tutorial
/** * Copyright 2012, Wisdom Omuya. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.deafgoat.ml.prognosticator; // Java import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Map.Entry; // Log4j import org.apache.log4j.Logger; // JFreeChart import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartFrame; import org.jfree.chart.ChartUtilities; import org.jfree.chart.JFreeChart; import org.jfree.chart.plot.PlotOrientation; import org.jfree.data.category.DefaultCategoryDataset; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; import org.jfree.ui.RefineryUtilities; /** * Create graphs from prediction files. */ public class Charter { /** * Creates data set containing categorical attributes along with prediction * confidence * * @param files * List of files containing predictions to chart * @return the series collection to chart */ private DefaultCategoryDataset createCategoricalDataset(String[] files) { _logger.info("Collating data"); BufferedReader br = null; // final XYSeriesCollection dataset = new XYSeriesCollection(); final DefaultCategoryDataset dataset = new DefaultCategoryDataset(); XYSeries prediction = null; for (String dataFile : files) { try { String sCurrentLine; prediction = new XYSeries(dataFile); br = new BufferedReader(new FileReader(dataFile)); HashMap<String, Double> avgConfidence = new HashMap<String, Double>(); HashMap<String, Integer> valueCount = new HashMap<String, Integer>(); while ((sCurrentLine = br.readLine()) != null) { String[] data = sCurrentLine.split("\t"); try { if (avgConfidence.containsKey(data[0])) { avgConfidence.put(data[0], avgConfidence.get(data[0]) + Double.parseDouble(data[1])); valueCount.put(data[0], valueCount.get(data[0]) + 1); } else { avgConfidence.put(data[0], Double.parseDouble(data[1])); valueCount.put(data[0], 1); } } catch (NumberFormatException e) { continue; } } for (Entry<String, Double> entry : avgConfidence.entrySet()) { dataset.addValue(entry.getValue() / valueCount.get(entry.getKey()), entry.getKey(), dataFile); } } catch (IOException e) { _logger.error(e.toString()); } finally { try { if (br != null) { br.close(); } } catch (IOException e) { _logger.error(e.toString()); } } if (prediction != null) { // dataset.addSeries(prediction); } } return dataset; } /** * Creates data set containing numeric attributes along with prediction * confidence * * @param files * List of files containing predictions to chart * @return the series collection to chart */ private XYSeriesCollection createNumericDataset(String[] files) { _logger.info("Collating data"); BufferedReader br = null; XYSeries prediction = null; final XYSeriesCollection dataset = new XYSeriesCollection(); for (String dataFile : files) { try { String sCurrentLine; br = new BufferedReader(new FileReader(dataFile)); prediction = new XYSeries(dataFile); while ((sCurrentLine = br.readLine()) != null) { String[] data = sCurrentLine.split("\t"); try { prediction.add(Double.parseDouble(data[1]), Double.parseDouble(data[2])); } catch (NumberFormatException e) { continue; } } } catch (IOException e) { _logger.error(e.toString()); } finally { try { if (br != null) { br.close(); } } catch (IOException e) { _logger.error(e.toString()); } } if (prediction != null) { dataset.addSeries(prediction); } } return dataset; } /** * Shows the given chart * * @param name * The name to save the chart as * @param chart * The chart to draw */ private void drawChart(String name, JFreeChart chart) { _logger.info("Plotting p.d. chart for " + name); ChartFrame frame = new ChartFrame(name, chart); RefineryUtilities.centerFrameOnScreen(frame); frame.pack(); frame.setVisible(true); } /** * Charts data set containing categorical attributes against with prediction * confidence * * @param files * List of files containing predictions to chart * @throws IOException * If list of files can not be read * @return chart The chart to be drawn */ public JFreeChart getCategoricalChart(String[] files) throws IOException { DefaultCategoryDataset dataset = createCategoricalDataset(files); JFreeChart chart = ChartFactory.createBarChart3D(_chartName, // chart // title "Attribute", // domain axis label "Average Confidence", // range axis label dataset, // data PlotOrientation.VERTICAL, // orientation true, // include legend true, // tooltips? false // URLs? ); return chart; } /** * Charts data set containing numeric attributes against with prediction * confidence * * @param files * List of files containing predictions to chart * @throws IOException * If the list of files can not be read * @return chart The chart to be drawn */ public JFreeChart getNumericChart(String[] files) throws IOException { XYSeriesCollection dataset = createNumericDataset(files); JFreeChart chart = ChartFactory.createScatterPlot(_chartName, // chart // title "Values", // domain axis label "Confidence", // range axis label dataset, // data PlotOrientation.VERTICAL, // orientation true, // include legend true, // tooltips? false // URLs? ); return chart; } /** * Saves a category chart to file * * @param category * The category to chart * @param files * List of files containing predictions to chart * @throws Exception * If chart can not be saved */ public void saveCategorical(String category, String[] files) throws Exception { // read the test ARFF file _experimenter.readARFF("test"); // initialize classifier AppClassifier sc = new AppClassifier(_experimenter.filterData(_experimenter._testSet), _experimenter._testSet, _experimenter._config); sc.errorAnalysis(category); Charter pd = new Charter(category); JFreeChart chart = pd.getCategoricalChart(files); pd.saveChart(category, chart); } /** * Generates and saves confidence chart for nominal attributes * * @param mode * Flag indicating if charts should be drawn on screen as well * @param files * List of files containing predictions to chart * @throws Exception * If chart can not be generate */ public void saveCategoricals(boolean mode, String[] files) throws Exception { // read the test ARFF file _experimenter.readARFF("test"); // initialize classifier AppClassifier sc = new AppClassifier(_experimenter.filterData(_experimenter._testSet), _experimenter._testSet, _experimenter._config); ArrayList<String> categoryList = new ArrayList<String>(); for (Attributes attribute : _experimenter._config._attributes.get(_experimenter._config._dumpFile)) { if (attribute.isInclude() && attribute.getAttributeType().equals("nominal")) { categoryList.add(attribute.getRawAttributeName()); } } String[] categories = categoryList.toArray(new String[categoryList.size()]); Charter pd = null; JFreeChart chart = null; for (String category : categories) { sc.errorAnalysis(category); pd = new Charter(category); chart = pd.getCategoricalChart(files); if (mode) { pd.drawChart(_chartName, chart); Thread.sleep(5000); } pd.saveChart(_chartName, chart); } } /** * Saves the given chart * * @param name * The name to save the chart as * @param chart * The chart to save * @throws IOException * If the chart can not be saved */ private void saveChart(String name, JFreeChart chart) throws IOException { _logger.info("Saving chart for " + name); ChartUtilities.saveChartAsPNG(new File(name + ".png"), chart, 2000, 1500); } /** * Saves a numeric chart to file * * @param numeric * The numeric to chart * @param files * List of files containing predictions to chart * @throws Exception * If chart can not be generate */ public void saveNumeric(String numeric, String[] files) throws Exception { // read the test ARFF file _experimenter.readARFF("test"); // initialize classifier AppClassifier sc = new AppClassifier(_experimenter.filterData(_experimenter._testSet), _experimenter._testSet, _experimenter._config); Charter pd = new Charter(numeric); sc.errorAnalysis(numeric); JFreeChart chart = pd.getNumericChart(files); pd.saveChart(numeric, chart); } /** * Generates and saves confidence charts for numeric attributes * * @param mode * Flag indicating if charts should be drawn on screen * @param files * List of files containing predictions to chart * @throws Exception */ public void saveNumerics(boolean mode, String[] files) throws Exception { // read the test ARFF file _experimenter.readARFF("test"); // initialize classifier AppClassifier sc = new AppClassifier(_experimenter.filterData(_experimenter._testSet), _experimenter._testSet, _experimenter._config); ArrayList<String> numericsList = new ArrayList<String>(); for (Attributes attribute : _experimenter._config._attributes.get(_experimenter._config._dumpFile)) { if (attribute.isInclude() && attribute.getAttributeType().equals("numeric")) { numericsList.add(attribute.getRawAttributeName()); } } String[] numerics = numericsList.toArray(new String[numericsList.size()]); Charter pd = null; JFreeChart chart = null; for (String numeric : numerics) { sc.errorAnalysis(numeric); pd = new Charter(numeric); chart = pd.getNumericChart(files); if (mode) { pd.drawChart(_chartName, chart); Thread.sleep(5000); } pd.saveChart(_chartName, chart); } } /** * handle to chart name */ private String _chartName; /** * handle to experimenter object */ private Experimenter _experimenter; /** * handle to logger object */ private Logger _logger; /** * Public constructor * * @param name * Experimenter handle to get charts with * @throws IOException */ public Charter(Experimenter name) throws IOException { _logger = AppLogger.getLogger(); _experimenter = name; } /** * Private constructor * * @param name * The name of the chart * @throws IOException */ private Charter(String name) throws IOException { _logger = AppLogger.getLogger(); _chartName = name; } }