Java tutorial
/* * To change this license header, choose License Headers in Project Properties. * To change this template file, choose Tools | Templates * and open the template in the editor. */ package mainGUI; import java.awt.BasicStroke; import java.awt.FlowLayout; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.logging.Level; import java.util.logging.Logger; import javax.swing.SwingWorker; import mnistDataset.MNISTDatasetLoader; import mnistDataset.OnDataitemLoadListener; import neuralNetworkLibrary.*; import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartPanel; import org.jfree.chart.JFreeChart; import org.jfree.chart.plot.PlotOrientation; import org.jfree.chart.plot.XYPlot; import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; /** * * @author Ali */ public class TrainingJFrame extends javax.swing.JFrame { class TrainingPublishable { NeuralNet neuralNet; int trainingDataItemIndex; double MSE; int epochNum; public TrainingPublishable(int trainingDataItemIndex) { this.trainingDataItemIndex = trainingDataItemIndex; neuralNet = null; } public TrainingPublishable(NeuralNet neuralNet, int epochNum, double MSE) { this.neuralNet = neuralNet; this.epochNum = epochNum; this.MSE = MSE; trainingDataItemIndex = -1; } } private ChartPanel chartPanel; private XYSeries series; private JFreeChart chart; private XYSeriesCollection seriesCollection; private TrainingDataSet trainingDataSet; private NeuralNet neuralNet; private TrainingHelper trainingHelper; private LearningParameters learningParameters; int i = 0; int epochNumber = 0; private void initLineChartPanel() { series = new XYSeries("MSE"); seriesCollection = new XYSeriesCollection(series); chart = ChartFactory.createXYLineChart("Training", "Epoch Number", "MSE", seriesCollection, PlotOrientation.VERTICAL, true, true, true); XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer(); renderer.setBaseShapesVisible(true); renderer.setBaseShapesFilled(true); renderer.setBaseStroke(new BasicStroke(3)); XYPlot plot = (XYPlot) chart.getPlot(); plot.setRenderer(renderer); chartPanel = new ChartPanel(chart); chartPanel.setSize(chartContainerPanel.getWidth(), chartContainerPanel.getHeight()); chartContainerPanel.add(chartPanel); } private void add(int i, double val) { series.add(i, val); } private void updateLearningParameters() { learningParameters.setLearningRate(Double.parseDouble(learningRateText.getText())); learningParameters.setMomentum(Double.parseDouble(momentumText.getText())); learningParameters.setWeightDecay(Double.parseDouble(weightDecayText.getText())); } private void saveNetwork(final NeuralNet net, final String filename) { SwingWorker<Boolean, Void> worker = new SwingWorker<Boolean, Void>() { @Override protected Boolean doInBackground() throws Exception { try { // Serialize data object to a file ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filename)); out.writeObject(net); out.close(); } catch (IOException e) { e.printStackTrace(); } return true; } @Override protected void done() { System.out.println(filename + " saved"); } }; worker.execute(); System.out.println("Saving file " + filename); } private NeuralNet loadNetwork(String filename) { NeuralNet net = null; try { FileInputStream door = new FileInputStream(filename); ObjectInputStream reader = new ObjectInputStream(door); net = (NeuralNet) reader.readObject(); reader.close(); } catch (Exception e) { e.printStackTrace(); } return net; } private void addMSE(double mse, int epochNumber) { epochNumLabel.setText(epochNumber + ""); lastMSELabel.setText(mse + ""); series.add(epochNumber++, mse); } private void saveNetwork(NeuralNet net, int epochNum) { String netName = netFileNameText.getText() + epochNum; saveNetwork(net, netName); } private void epochFinished(NeuralNet net, int epochNum, double mse) { if (autoSaveNetCheck.isSelected()) { addMSE(mse, epochNum); saveNetwork(net, epochNum); } } private void logHeap() { long maxMemory = Runtime.getRuntime().maxMemory(); System.out.println("maxMemory = " + maxMemory / (1024 * 1024)); long totalMemory = Runtime.getRuntime().totalMemory(); System.out.println("totalMemory = " + totalMemory / (1024 * 1024)); long freeMemory = Runtime.getRuntime().freeMemory(); System.out.println("freeMemory = " + freeMemory / (1024 * 1024)); } /** * Creates new form TrainingJFrame */ public TrainingJFrame() { initComponents(); learningParameters = new LearningParameters(0.001); initLineChartPanel(); } /** * This method is called from within the constructor to initialize the form. * WARNING: Do NOT modify this code. The content of this method is always * regenerated by the Form Editor. */ @SuppressWarnings("unchecked") // <editor-fold defaultstate="collapsed" desc="Generated Code">//GEN-BEGIN:initComponents private void initComponents() { bindingGroup = new org.jdesktop.beansbinding.BindingGroup(); chartContainerPanel = new javax.swing.JPanel(); loadNetPanel = new javax.swing.JPanel(); netFileNameText = new javax.swing.JTextField(); jLabel1 = new javax.swing.JLabel(); autoSaveNetCheck = new javax.swing.JCheckBox(); saveNetButton = new javax.swing.JButton(); loadNetButton = new javax.swing.JButton(); newNetButton = new javax.swing.JButton(); learningParamsPanel = new javax.swing.JPanel(); learningRateText = new javax.swing.JTextField(); jLabel2 = new javax.swing.JLabel(); momentumText = new javax.swing.JTextField(); jLabel3 = new javax.swing.JLabel(); weightDecayText = new javax.swing.JTextField(); jLabel4 = new javax.swing.JLabel(); changeLearningParamsButton = new javax.swing.JButton(); statusPanel = new javax.swing.JPanel(); jLabel5 = new javax.swing.JLabel(); epochNumLabel = new javax.swing.JLabel(); jLabel8 = new javax.swing.JLabel(); lastMSELabel = new javax.swing.JLabel(); epochProgressBar = new javax.swing.JProgressBar(); jLabel10 = new javax.swing.JLabel(); startTrainButton = new javax.swing.JButton(); stopTrainButton = new javax.swing.JButton(); mnistPanel = new javax.swing.JPanel(); loadMNISTButton = new javax.swing.JButton(); mnistProgressBar = new javax.swing.JProgressBar(); jLabel7 = new javax.swing.JLabel(); setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE); setTitle("MNIST Training"); chartContainerPanel.setBorder(javax.swing.BorderFactory.createLineBorder(new java.awt.Color(0, 0, 0))); chartContainerPanel.addComponentListener(new java.awt.event.ComponentAdapter() { public void componentResized(java.awt.event.ComponentEvent evt) { chartPanelResizeHandler(evt); } }); javax.swing.GroupLayout chartContainerPanelLayout = new javax.swing.GroupLayout(chartContainerPanel); chartContainerPanel.setLayout(chartContainerPanelLayout); chartContainerPanelLayout.setHorizontalGroup(chartContainerPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING).addGap(0, 0, Short.MAX_VALUE)); chartContainerPanelLayout.setVerticalGroup(chartContainerPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING).addGap(0, 259, Short.MAX_VALUE)); loadNetPanel.setBorder(javax.swing.BorderFactory.createLineBorder(new java.awt.Color(0, 0, 0))); loadNetPanel.setEnabled(false); loadNetPanel.setPreferredSize(new java.awt.Dimension(180, 114)); netFileNameText.setText("nets/net"); org.jdesktop.beansbinding.Binding binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, loadNetPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), netFileNameText, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); netFileNameText.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { netFileNameTextActionPerformed(evt); } }); jLabel1.setText("Network File Path"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, loadNetPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel1, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); autoSaveNetCheck.setSelected(true); autoSaveNetCheck.setText("Auto Save"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, loadNetPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), autoSaveNetCheck, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); autoSaveNetCheck.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { autoSaveNetCheckActionPerformed(evt); } }); saveNetButton.setText("Save Network"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, loadNetPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), saveNetButton, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); saveNetButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { saveNetButtonActionPerformed(evt); } }); loadNetButton.setText("Load Network"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, loadNetPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), loadNetButton, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); loadNetButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { loadNetButtonActionPerformed(evt); } }); newNetButton.setText("New Network"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, loadNetPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), newNetButton, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); newNetButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { newNetButtonActionPerformed(evt); } }); javax.swing.GroupLayout loadNetPanelLayout = new javax.swing.GroupLayout(loadNetPanel); loadNetPanel.setLayout(loadNetPanelLayout); loadNetPanelLayout.setHorizontalGroup(loadNetPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(loadNetPanelLayout.createSequentialGroup().addContainerGap().addGroup(loadNetPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(loadNetPanelLayout.createSequentialGroup() .addComponent(jLabel1, javax.swing.GroupLayout.PREFERRED_SIZE, 98, javax.swing.GroupLayout.PREFERRED_SIZE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED).addComponent( netFileNameText, javax.swing.GroupLayout.DEFAULT_SIZE, 83, Short.MAX_VALUE)) .addGroup(loadNetPanelLayout.createSequentialGroup().addComponent(saveNetButton) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED) .addComponent(autoSaveNetCheck, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)) .addComponent(loadNetButton, javax.swing.GroupLayout.Alignment.TRAILING, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addComponent(newNetButton, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)) .addContainerGap())); loadNetPanelLayout.setVerticalGroup(loadNetPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(loadNetPanelLayout.createSequentialGroup().addContainerGap() .addGroup(loadNetPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(netFileNameText, javax.swing.GroupLayout.PREFERRED_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE) .addComponent(jLabel1)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addComponent(loadNetButton) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addComponent(newNetButton) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(loadNetPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addComponent(autoSaveNetCheck).addComponent(saveNetButton)) .addGap(31, 31, 31))); learningParamsPanel.setBorder(javax.swing.BorderFactory.createLineBorder(new java.awt.Color(0, 0, 0))); learningParamsPanel.setEnabled(false); learningRateText.setText("0.001"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), learningRateText, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); learningRateText.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { learningRateTextActionPerformed(evt); } }); jLabel2.setText("Learning Rate"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel2, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); momentumText.setText("0"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), momentumText, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); jLabel3.setText("Momentum"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel3, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); weightDecayText.setText("0"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), weightDecayText, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); weightDecayText.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { weightDecayTextActionPerformed(evt); } }); jLabel4.setText("Weight Decay"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel4, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); changeLearningParamsButton.setText("Change Learning Parameters"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, learningParamsPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), changeLearningParamsButton, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); changeLearningParamsButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { changeLearningParamsButtonActionPerformed(evt); } }); javax.swing.GroupLayout learningParamsPanelLayout = new javax.swing.GroupLayout(learningParamsPanel); learningParamsPanel.setLayout(learningParamsPanelLayout); learningParamsPanelLayout.setHorizontalGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(javax.swing.GroupLayout.Alignment.TRAILING, learningParamsPanelLayout .createSequentialGroup().addContainerGap() .addGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.TRAILING) .addComponent(changeLearningParamsButton, javax.swing.GroupLayout.PREFERRED_SIZE, 0, Short.MAX_VALUE) .addGroup(learningParamsPanelLayout.createSequentialGroup() .addGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(learningParamsPanelLayout .createParallelGroup( javax.swing.GroupLayout.Alignment.LEADING, false) .addComponent(jLabel2, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addComponent(jLabel3, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)) .addComponent(jLabel4)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.TRAILING) .addComponent(momentumText, javax.swing.GroupLayout.Alignment.LEADING) .addComponent(learningRateText, javax.swing.GroupLayout.DEFAULT_SIZE, 90, Short.MAX_VALUE) .addComponent(weightDecayText)))) .addContainerGap())); learningParamsPanelLayout.setVerticalGroup( learningParamsPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(learningParamsPanelLayout.createSequentialGroup().addContainerGap() .addGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(learningRateText, javax.swing.GroupLayout.PREFERRED_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE) .addComponent(jLabel2)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(momentumText, javax.swing.GroupLayout.PREFERRED_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE) .addComponent(jLabel3)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(learningParamsPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(weightDecayText, javax.swing.GroupLayout.PREFERRED_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE) .addComponent(jLabel4)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addComponent(changeLearningParamsButton).addContainerGap())); statusPanel.setBorder(javax.swing.BorderFactory.createLineBorder(new java.awt.Color(0, 0, 0))); statusPanel.setEnabled(false); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, statusPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), statusPanel, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); jLabel5.setText("Epoch Number: "); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, statusPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel5, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); epochNumLabel.setText("0"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, statusPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), epochNumLabel, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); jLabel8.setText("Last MSE: "); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, statusPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel8, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); lastMSELabel.setText("0"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, statusPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), lastMSELabel, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); epochProgressBar.setStringPainted(true); jLabel10.setText("Epoch Progress"); binding = org.jdesktop.beansbinding.Bindings.createAutoBinding( org.jdesktop.beansbinding.AutoBinding.UpdateStrategy.READ_WRITE, statusPanel, org.jdesktop.beansbinding.ELProperty.create("${enabled}"), jLabel10, org.jdesktop.beansbinding.BeanProperty.create("enabled")); bindingGroup.addBinding(binding); startTrainButton.setText("Start Training"); startTrainButton.setEnabled(false); startTrainButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { startTrainButtonActionPerformed(evt); } }); stopTrainButton.setText("Stop Training"); stopTrainButton.setEnabled(false); stopTrainButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { stopTrainButtonActionPerformed(evt); } }); javax.swing.GroupLayout statusPanelLayout = new javax.swing.GroupLayout(statusPanel); statusPanel.setLayout(statusPanelLayout); statusPanelLayout.setHorizontalGroup(statusPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(statusPanelLayout.createSequentialGroup().addContainerGap().addGroup(statusPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addComponent(epochProgressBar, javax.swing.GroupLayout.DEFAULT_SIZE, 202, Short.MAX_VALUE) .addGroup(statusPanelLayout.createSequentialGroup().addGroup(statusPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(statusPanelLayout.createSequentialGroup() .addGroup(statusPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addComponent(jLabel5).addComponent(jLabel8)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(statusPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addComponent(lastMSELabel).addComponent(epochNumLabel))) .addComponent(jLabel10)).addGap(0, 0, Short.MAX_VALUE)) .addGroup(statusPanelLayout.createSequentialGroup() .addComponent(startTrainButton, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addComponent(stopTrainButton, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))) .addContainerGap())); statusPanelLayout.setVerticalGroup(statusPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(statusPanelLayout.createSequentialGroup().addContainerGap() .addGroup(statusPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(jLabel5).addComponent(epochNumLabel)) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(statusPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(jLabel8).addComponent(lastMSELabel)) .addGap(4, 4, 4).addComponent(jLabel10) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addComponent(epochProgressBar, javax.swing.GroupLayout.PREFERRED_SIZE, 19, javax.swing.GroupLayout.PREFERRED_SIZE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.RELATED) .addGroup(statusPanelLayout.createParallelGroup(javax.swing.GroupLayout.Alignment.BASELINE) .addComponent(startTrainButton).addComponent(stopTrainButton)) .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))); mnistPanel.setBorder(javax.swing.BorderFactory.createLineBorder(new java.awt.Color(0, 0, 0))); loadMNISTButton.setText("Load MNIST"); loadMNISTButton.addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { loadMNISTButtonActionPerformed(evt); } }); mnistProgressBar.setStringPainted(true); jLabel7.setText("Progress"); javax.swing.GroupLayout mnistPanelLayout = new javax.swing.GroupLayout(mnistPanel); mnistPanel.setLayout(mnistPanelLayout); mnistPanelLayout.setHorizontalGroup(mnistPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(mnistPanelLayout.createSequentialGroup().addContainerGap().addGroup(mnistPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addComponent(loadMNISTButton, javax.swing.GroupLayout.Alignment.TRAILING, javax.swing.GroupLayout.DEFAULT_SIZE, 112, Short.MAX_VALUE) .addComponent(mnistProgressBar, javax.swing.GroupLayout.PREFERRED_SIZE, 0, Short.MAX_VALUE) .addGroup(mnistPanelLayout.createSequentialGroup().addComponent(jLabel7).addGap(0, 0, Short.MAX_VALUE))) .addContainerGap())); mnistPanelLayout.setVerticalGroup(mnistPanelLayout .createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(mnistPanelLayout.createSequentialGroup().addContainerGap().addComponent(loadMNISTButton) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED).addComponent(jLabel7) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED) .addComponent(mnistProgressBar, javax.swing.GroupLayout.PREFERRED_SIZE, 23, javax.swing.GroupLayout.PREFERRED_SIZE) .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))); javax.swing.GroupLayout layout = new javax.swing.GroupLayout(getContentPane()); getContentPane().setLayout(layout); layout.setHorizontalGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING).addGroup( javax.swing.GroupLayout.Alignment.TRAILING, layout.createSequentialGroup().addContainerGap() .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.TRAILING) .addComponent(chartContainerPanel, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addGroup(layout.createSequentialGroup() .addComponent(mnistPanel, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED) .addComponent(loadNetPanel, javax.swing.GroupLayout.PREFERRED_SIZE, 207, javax.swing.GroupLayout.PREFERRED_SIZE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED) .addComponent(learningParamsPanel, javax.swing.GroupLayout.PREFERRED_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.PREFERRED_SIZE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED) .addComponent(statusPanel, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))) .addContainerGap())); layout.setVerticalGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.LEADING) .addGroup(layout.createSequentialGroup().addContainerGap() .addComponent(chartContainerPanel, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addPreferredGap(javax.swing.LayoutStyle.ComponentPlacement.UNRELATED) .addGroup(layout.createParallelGroup(javax.swing.GroupLayout.Alignment.TRAILING, false) .addComponent(learningParamsPanel, javax.swing.GroupLayout.Alignment.LEADING, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addComponent(loadNetPanel, javax.swing.GroupLayout.Alignment.LEADING, javax.swing.GroupLayout.PREFERRED_SIZE, 0, Short.MAX_VALUE) .addComponent(statusPanel, javax.swing.GroupLayout.Alignment.LEADING, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE) .addComponent(mnistPanel, javax.swing.GroupLayout.DEFAULT_SIZE, javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE)) .addContainerGap(javax.swing.GroupLayout.DEFAULT_SIZE, Short.MAX_VALUE))); bindingGroup.bind(); pack(); }// </editor-fold>//GEN-END:initComponents private void autoSaveNetCheckActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_autoSaveNetCheckActionPerformed // TODO add your handling code here: }//GEN-LAST:event_autoSaveNetCheckActionPerformed private void netFileNameTextActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_netFileNameTextActionPerformed // TODO add your handling code here: }//GEN-LAST:event_netFileNameTextActionPerformed private void saveNetButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_saveNetButtonActionPerformed saveNetwork(neuralNet, trainingHelper.getEpoch()); }//GEN-LAST:event_saveNetButtonActionPerformed private void learningRateTextActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_learningRateTextActionPerformed // TODO add your handling code here: }//GEN-LAST:event_learningRateTextActionPerformed private void weightDecayTextActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_weightDecayTextActionPerformed // TODO add your handling code here: }//GEN-LAST:event_weightDecayTextActionPerformed private void startTrainButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_startTrainButtonActionPerformed series.clear(); stopTrainButton.setEnabled(true); startTrainButton.setEnabled(false); trainingHelper = new TrainingHelper(neuralNet, trainingDataSet, trainingDataSet.subDataSet(0, 700), learningParameters); epochProgressBar.setValue(0); epochProgressBar.setMaximum(trainingDataSet.size()); SwingWorker<Boolean, TrainingPublishable> worker = new SwingWorker<Boolean, TrainingPublishable>() { @Override protected Boolean doInBackground() throws Exception { trainingHelper.setOnItemTrainListener(new OnItemTrainListener() { @Override public void onItemTrain(int trainingDataItemIndex) { publish(new TrainingPublishable(trainingDataItemIndex)); } }); trainingHelper.setOnEpochFinishListner(new OnEpochFinishListener() { @Override public void onEpochFinish() { publish(new TrainingPublishable(neuralNet.deepClone(), trainingHelper.getEpoch(), trainingHelper.getEpochMSE())); } }); trainingHelper.train(); return true; } @Override protected void done() { System.out.println("Training Done"); } @Override protected void process(List<TrainingPublishable> chunks) { for (TrainingPublishable chunk : chunks) { if (chunk.neuralNet == null) { if (chunk.trainingDataItemIndex != -1) { epochProgressBar.setValue(chunk.trainingDataItemIndex); epochProgressBar.setString(chunk.trainingDataItemIndex + " of 60000"); } } else { System.gc(); System.out.println("Calling GC"); logHeap(); epochFinished(chunk.neuralNet, chunk.epochNum, chunk.MSE); } } } }; worker.execute(); }//GEN-LAST:event_startTrainButtonActionPerformed private void changeLearningParamsButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_changeLearningParamsButtonActionPerformed updateLearningParameters(); }//GEN-LAST:event_changeLearningParamsButtonActionPerformed private void chartPanelResizeHandler(java.awt.event.ComponentEvent evt) {//GEN-FIRST:event_chartPanelResizeHandler chartPanel.setSize(chartContainerPanel.getWidth(), chartContainerPanel.getHeight()); }//GEN-LAST:event_chartPanelResizeHandler private void loadMNISTButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_loadMNISTButtonActionPerformed mnistProgressBar.setMaximum(60000); SwingWorker<TrainingDataSet, Integer> worker = new SwingWorker<TrainingDataSet, Integer>() { @Override protected TrainingDataSet doInBackground() throws Exception { MNISTDatasetLoader.setOnDataitemLoadListener(new OnDataitemLoadListener() { @Override public void onDataitemLoadListener(int currentProgress) { publish(currentProgress); } }); return MNISTDatasetLoader.loadMNISTDataset("MNIST/train-images.idx3-ubyte", "MNIST/train-labels.idx1-ubyte"); } @Override protected void process(List<Integer> chunks) { mnistProgressBar.setValue(chunks.get(chunks.size() - 1)); mnistProgressBar.setString(chunks.get(chunks.size() - 1) + " of 60000"); } protected void done() { try { trainingDataSet = get(); loadNetPanel.setEnabled(true); System.out.println("MNIST Dataset Loaded"); } catch (InterruptedException ex) { ex.printStackTrace(); Logger.getLogger(TrainingJFrame.class.getName()).log(Level.SEVERE, null, ex); } catch (ExecutionException ex) { ex.printStackTrace(); Logger.getLogger(TrainingJFrame.class.getName()).log(Level.SEVERE, null, ex); } } }; worker.execute(); }//GEN-LAST:event_loadMNISTButtonActionPerformed private void newNetButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_newNetButtonActionPerformed neuralNet = new NeuralNet(learningParameters); neuralNet.addLayer(new InputLayer(784)); neuralNet.addLayer(new TanhLayer(330)); neuralNet.addLayer(new TanhLayer(10)); neuralNet.connectAllLayers(); learningParamsPanel.setEnabled(true); System.out.println("Created New Network"); statusPanel.setEnabled(true); startTrainButton.setEnabled(true); }//GEN-LAST:event_newNetButtonActionPerformed private void stopTrainButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_stopTrainButtonActionPerformed trainingHelper.stopTraining(); startTrainButton.setEnabled(true); stopTrainButton.setEnabled(false); }//GEN-LAST:event_stopTrainButtonActionPerformed private void loadNetButtonActionPerformed(java.awt.event.ActionEvent evt) {//GEN-FIRST:event_loadNetButtonActionPerformed String netName = netFileNameText.getText(); neuralNet = loadNetwork(netName); System.out.println("Network Loaded"); learningParamsPanel.setEnabled(true); statusPanel.setEnabled(true); startTrainButton.setEnabled(true); }//GEN-LAST:event_loadNetButtonActionPerformed /** * @param args the command line arguments */ public static void main(String args[]) { /* Set the Nimbus look and feel */ //<editor-fold defaultstate="collapsed" desc=" Look and feel setting code (optional) "> /* If Nimbus (introduced in Java SE 6) is not available, stay with the default look and feel. * For details see http://download.oracle.com/javase/tutorial/uiswing/lookandfeel/plaf.html */ try { for (javax.swing.UIManager.LookAndFeelInfo info : javax.swing.UIManager.getInstalledLookAndFeels()) { if ("Nimbus".equals(info.getName())) { javax.swing.UIManager.setLookAndFeel(info.getClassName()); break; } } } catch (ClassNotFoundException ex) { java.util.logging.Logger.getLogger(TrainingJFrame.class.getName()).log(java.util.logging.Level.SEVERE, null, ex); } catch (InstantiationException ex) { java.util.logging.Logger.getLogger(TrainingJFrame.class.getName()).log(java.util.logging.Level.SEVERE, null, ex); } catch (IllegalAccessException ex) { java.util.logging.Logger.getLogger(TrainingJFrame.class.getName()).log(java.util.logging.Level.SEVERE, null, ex); } catch (javax.swing.UnsupportedLookAndFeelException ex) { java.util.logging.Logger.getLogger(TrainingJFrame.class.getName()).log(java.util.logging.Level.SEVERE, null, ex); } //</editor-fold> /* Create and display the form */ java.awt.EventQueue.invokeLater(new Runnable() { public void run() { new TrainingJFrame().setVisible(true); } }); } // Variables declaration - do not modify//GEN-BEGIN:variables private javax.swing.JCheckBox autoSaveNetCheck; private javax.swing.JButton changeLearningParamsButton; private javax.swing.JPanel chartContainerPanel; private javax.swing.JLabel epochNumLabel; private javax.swing.JProgressBar epochProgressBar; private javax.swing.JLabel jLabel1; private javax.swing.JLabel jLabel10; private javax.swing.JLabel jLabel2; private javax.swing.JLabel jLabel3; private javax.swing.JLabel jLabel4; private javax.swing.JLabel jLabel5; private javax.swing.JLabel jLabel7; private javax.swing.JLabel jLabel8; private javax.swing.JLabel lastMSELabel; private javax.swing.JPanel learningParamsPanel; private javax.swing.JTextField learningRateText; private javax.swing.JButton loadMNISTButton; private javax.swing.JButton loadNetButton; private javax.swing.JPanel loadNetPanel; private javax.swing.JPanel mnistPanel; private javax.swing.JProgressBar mnistProgressBar; private javax.swing.JTextField momentumText; private javax.swing.JTextField netFileNameText; private javax.swing.JButton newNetButton; private javax.swing.JButton saveNetButton; private javax.swing.JButton startTrainButton; private javax.swing.JPanel statusPanel; private javax.swing.JButton stopTrainButton; private javax.swing.JTextField weightDecayText; private org.jdesktop.beansbinding.BindingGroup bindingGroup; // End of variables declaration//GEN-END:variables }