// Java tutorial
/* * Encog(tm) Workbench v3.0 * http://www.heatonresearch.com/encog/ * http://code.google.com/p/encog-java/ * Copyright 2008-2011 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.workbench.tabs.visualize.structure; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.Paint; import java.awt.Point; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.geom.Point2D; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.swing.BorderFactory; import javax.swing.JButton; import javax.swing.JPanel; import javax.swing.border.Border; import org.apache.commons.collections15.Transformer; import org.encog.ml.MLMethod; import org.encog.neural.flat.FlatNetwork; import org.encog.neural.neat.NEATLink; import org.encog.neural.neat.NEATNetwork; import org.encog.neural.neat.NEATNeuron; import org.encog.neural.networks.BasicNetwork; import org.encog.workbench.WorkBenchError; import org.encog.workbench.tabs.EncogCommonTab; import edu.uci.ics.jung.algorithms.layout.StaticLayout; import edu.uci.ics.jung.graph.Graph; import edu.uci.ics.jung.graph.SparseMultigraph; import edu.uci.ics.jung.graph.util.EdgeType; import edu.uci.ics.jung.visualization.GraphZoomScrollPane; import 
edu.uci.ics.jung.visualization.Layer;
import edu.uci.ics.jung.visualization.VisualizationViewer;
import edu.uci.ics.jung.visualization.control.AbstractModalGraphMouse;
import edu.uci.ics.jung.visualization.control.CrossoverScalingControl;
import edu.uci.ics.jung.visualization.control.DefaultModalGraphMouse;
import edu.uci.ics.jung.visualization.control.ScalingControl;
import edu.uci.ics.jung.visualization.decorators.ToStringLabeller;
import edu.uci.ics.jung.visualization.renderers.Renderer;

/**
 * Workbench tab that draws the structure of a neural network as a JUNG graph:
 * neurons are vertices and connections are directed edges.
 * <p>
 * {@link BasicNetwork} instances are drawn from their underlying
 * {@link FlatNetwork}, and {@link NEATNetwork} instances are drawn from their
 * neuron/link lists. Any other {@link MLMethod}, or {@code null}, makes the
 * constructor throw a {@link WorkBenchError}.
 */
public class StructureTab extends EncogCommonTab {

	/** Scale applied to a neuron's x coordinate when it is placed in the layout. */
	private static final int LAYOUT_WIDTH = 600;

	/** Scale applied to a neuron's y coordinate when it is placed in the layout. */
	private static final int LAYOUT_HEIGHT = 300;

	/** Horizontal shift added to every neuron position. */
	private static final int LAYOUT_LEFT_MARGIN = 32;

	/** Size of the static layout space. */
	private static final Dimension LAYOUT_SIZE = new Dimension(5000, 5000);

	/** Factor used by the "+" button (and its inverse by the "-" button). */
	private static final float ZOOM_STEP = 1.1f;

	private VisualizationViewer<DrawnNeuron, DrawnConnection> vv;

	/**
	 * Builds the tab for the given network.
	 *
	 * @param method the network to visualize; must be a {@link BasicNetwork}
	 *            or a {@link NEATNetwork}
	 * @throws WorkBenchError if {@code method} is null or of an unsupported type
	 */
	public StructureTab(MLMethod method) {
		// NOTE(review): no encog object is handed to the base tab here;
		// getName() guards against it being absent.
		super(null);

		Graph<DrawnNeuron, DrawnConnection> g = buildGraphFor(method);
		this.vv = createViewer(g);

		this.setLayout(new BorderLayout());
		add(new GraphZoomScrollPane(this.vv), BorderLayout.CENTER);
		add(createControls(), BorderLayout.NORTH);
	}

	/** Selects the graph builder that matches the concrete network type. */
	private Graph<DrawnNeuron, DrawnConnection> buildGraphFor(MLMethod method) {
		if (method instanceof BasicNetwork) {
			BasicNetwork network = (BasicNetwork) method;
			return buildGraph(network.getStructure().getFlat());
		}
		if (method instanceof NEATNetwork) {
			return buildGraph((NEATNetwork) method);
		}
		// Report null explicitly instead of failing with a NullPointerException.
		String typeName = (method == null) ? "null" : method.getClass().getSimpleName();
		throw new WorkBenchError("Can't visualize network: " + typeName);
	}

	/** Creates the viewer: static layout, colours, labels, tooltips and mouse handling. */
	private VisualizationViewer<DrawnNeuron, DrawnConnection> createViewer(
			Graph<DrawnNeuron, DrawnConnection> g) {
		// Neurons carry their own coordinates, so the layout just scales them.
		Transformer<DrawnNeuron, Point2D> staticTransformer = new Transformer<DrawnNeuron, Point2D>() {
			public Point2D transform(DrawnNeuron n) {
				int x = (int) (n.getX() * LAYOUT_WIDTH);
				int y = (int) (n.getY() * LAYOUT_HEIGHT);
				return new Point(x + LAYOUT_LEFT_MARGIN, y);
			}
		};

		Transformer<DrawnNeuron, Paint> vertexPaint = new Transformer<DrawnNeuron, Paint>() {
			public Paint transform(DrawnNeuron neuron) {
				switch (neuron.getType()) {
				case Bias:
					return Color.yellow;
				case Input:
					return Color.white;
				case Output:
					return Color.green;
				case Context:
					return Color.cyan;
				default:
					return Color.red;
				}
			}
		};

		// Context (recurrent) links are drawn light gray so they stand out from feed links.
		Transformer<DrawnConnection, Paint> edgePaint = new Transformer<DrawnConnection, Paint>() {
			public Paint transform(DrawnConnection connection) {
				return connection.isContext() ? Color.lightGray : Color.black;
			}
		};

		StaticLayout<DrawnNeuron, DrawnConnection> layout = new StaticLayout<DrawnNeuron, DrawnConnection>(
				g, staticTransformer);
		layout.setSize(LAYOUT_SIZE);

		VisualizationViewer<DrawnNeuron, DrawnConnection> viewer = new VisualizationViewer<DrawnNeuron, DrawnConnection>(
				layout);
		viewer.getRenderer().getVertexLabelRenderer().setPosition(Renderer.VertexLabel.Position.CNTR);
		viewer.getRenderContext().setVertexLabelTransformer(new ToStringLabeller<DrawnNeuron>());
		viewer.getRenderContext().setVertexFillPaintTransformer(vertexPaint);
		viewer.getRenderContext().setEdgeDrawPaintTransformer(edgePaint);
		viewer.getRenderContext().setArrowDrawPaintTransformer(edgePaint);
		viewer.getRenderContext().setArrowFillPaintTransformer(edgePaint);

		viewer.setVertexToolTipTransformer(new Transformer<DrawnNeuron, String>() {
			public String transform(DrawnNeuron neuron) {
				return neuron.getToolTip();
			}
		});
		viewer.setEdgeToolTipTransformer(new Transformer<DrawnConnection, String>() {
			public String transform(DrawnConnection connection) {
				return connection.getToolTip();
			}
		});

		AbstractModalGraphMouse graphMouse = new DefaultModalGraphMouse<DrawnNeuron, DrawnConnection>();
		viewer.setGraphMouse(graphMouse);
		viewer.addKeyListener(graphMouse.getModeKeyListener());
		return viewer;
	}

	/** Creates the "+", "-" and "reset" zoom toolbar for the viewer held in {@link #vv}. */
	private JPanel createControls() {
		final ScalingControl scaler = new CrossoverScalingControl();

		JButton plus = new JButton("+");
		plus.addActionListener(new ActionListener() {
			public void actionPerformed(ActionEvent e) {
				scaler.scale(vv, ZOOM_STEP, vv.getCenter());
			}
		});

		JButton minus = new JButton("-");
		minus.addActionListener(new ActionListener() {
			public void actionPerformed(ActionEvent e) {
				scaler.scale(vv, 1 / ZOOM_STEP, vv.getCenter());
			}
		});

		// Reset clears both the layout and the view transforms.
		JButton reset = new JButton("reset");
		reset.addActionListener(new ActionListener() {
			public void actionPerformed(ActionEvent e) {
				vv.getRenderContext().getMultiLayerTransformer().getTransformer(Layer.LAYOUT).setToIdentity();
				vv.getRenderContext().getMultiLayerTransformer().getTransformer(Layer.VIEW).setToIdentity();
			}
		});

		JPanel controls = new JPanel(new FlowLayout(FlowLayout.LEFT));
		controls.add(plus);
		controls.add(minus);
		controls.add(reset);
		controls.setBorder(BorderFactory.createEtchedBorder());
		return controls;
	}

	/**
	 * Builds the graph for a NEAT network.
	 * <p>
	 * Each neuron is drawn at its split coordinates and named by type and
	 * position (B1, I1, O1, H1, ...). Every outbound link becomes a directed edge
	 * that carries the link weight.
	 *
	 * @param neat the NEAT network to draw
	 * @return a directed graph of the network's neurons and links
	 */
	private Graph<DrawnNeuron, DrawnConnection> buildGraph(NEATNetwork neat) {
		int inputCount = 1;
		int outputCount = 1;
		int hiddenCount = 1;
		int biasCount = 1;

		List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>();
		Map<NEATNeuron, DrawnNeuron> neuronMap = new HashMap<NEATNeuron, DrawnNeuron>();

		// Place all the neurons.
		for (NEATNeuron neatNeuron : neat.getNeurons()) {
			String name = "";
			DrawnNeuronType t = DrawnNeuronType.Hidden;
			switch (neatNeuron.getNeuronType()) {
			case Bias:
				t = DrawnNeuronType.Bias;
				name = "B" + (biasCount++);
				break;
			case Input:
				t = DrawnNeuronType.Input;
				name = "I" + (inputCount++);
				break;
			case Output:
				t = DrawnNeuronType.Output;
				name = "O" + (outputCount++);
				break;
			case Hidden:
				t = DrawnNeuronType.Hidden;
				name = "H" + (hiddenCount++);
				break;
			}

			DrawnNeuron neuron = new DrawnNeuron(t, name, neatNeuron.getSplitX(), neatNeuron.getSplitY());
			neurons.add(neuron);
			neuronMap.put(neatNeuron, neuron);
		}

		// Place all the connections.
		for (NEATNeuron neatNeuron : neat.getNeurons()) {
			for (NEATLink neatLink : neatNeuron.getOutputboundLinks()) {
				DrawnNeuron fromNeuron = neuronMap.get(neatLink.getFromNeuron());
				DrawnNeuron toNeuron = neuronMap.get(neatLink.getToNeuron());
				DrawnConnection connection = new DrawnConnection(fromNeuron, toNeuron, neatLink.getWeight());
				fromNeuron.getOutbound().add(connection);
				toNeuron.getInbound().add(connection);
			}
		}

		return toGraph(neurons);
	}

	/**
	 * Builds the graph for a flat (feed-forward, optionally recurrent) network.
	 * <p>
	 * Layer 0 holds the output neurons and is drawn rightmost. The last layer
	 * holds the input neurons and is drawn at x = 0. Within a layer, the first
	 * {@code layerFeedCounts[layer]} neurons are fed neurons, the next one is
	 * the bias, and any remaining neurons are context neurons.
	 * <p>
	 * Each neuron is connected to every fed (hidden/output) neuron of the
	 * previously processed layer. Context links are added afterwards as light
	 * gray edges.
	 *
	 * @param flat the flat network to draw
	 * @return a directed graph of the network's neurons and connections
	 */
	public Graph<DrawnNeuron, DrawnConnection> buildGraph(FlatNetwork flat) {
		int inputCount = 1;
		int outputCount = 1;
		int hiddenCount = 1;
		int biasCount = 1;
		int contextCount = 1;

		int[] layerCounts = flat.getLayerCounts();
		int[] layerFeedCounts = flat.getLayerFeedCounts();
		int layerCount = layerCounts.length;
		double layerSize = 1.0 / layerCount;

		List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>();
		// Fed neurons of the previously processed layer (the one nearer the output).
		List<DrawnNeuron> connections = new ArrayList<DrawnNeuron>();

		for (int currentLayer = 0; currentLayer < layerCount; currentLayer++) {
			List<DrawnNeuron> lastFedNeurons = new ArrayList<DrawnNeuron>();
			double x = (double) (layerCount - currentLayer - 1) / (double) layerCount;
			int neuronCount = layerCounts[currentLayer];
			int feedCount = layerFeedCounts[currentLayer];

			// Centres the column vertically; equals 1 / (2 * neuronCount).
			double margin = 1.0 - ((double) (neuronCount - 1) / (double) neuronCount);
			margin /= 2.0;

			for (int currentNeuron = 0; currentNeuron < neuronCount; currentNeuron++) {
				DrawnNeuronType type;
				double xOffset = 0;
				String name;

				if (currentNeuron < feedCount) {
					// A fed neuron: its role depends on the layer.
					if (currentLayer == 0) {
						type = DrawnNeuronType.Output;
						name = "O" + (outputCount++);
					} else if (currentLayer == (layerCount - 1)) {
						type = DrawnNeuronType.Input;
						name = "I" + (inputCount++);
					} else {
						type = DrawnNeuronType.Hidden;
						name = "H" + (hiddenCount++);
					}
				} else if (currentNeuron == feedCount) {
					type = DrawnNeuronType.Bias;
					name = "B" + (biasCount++);
				} else {
					// Context neurons are nudged right so they do not overlap their layer.
					type = DrawnNeuronType.Context;
					name = "C" + (contextCount++);
					xOffset = layerSize / 4;
				}

				double y = (double) currentNeuron / (double) neuronCount;
				DrawnNeuron neuron = new DrawnNeuron(type, name, x + xOffset, y + margin);
				neurons.add(neuron);

				if (type == DrawnNeuronType.Hidden || type == DrawnNeuronType.Output) {
					lastFedNeurons.add(neuron);
				}

				// TODO: weights are not read from the flat network yet, so every
				// feed connection is drawn with weight 0.
				for (DrawnNeuron connectTo : connections) {
					DrawnConnection connection = new DrawnConnection(neuron, connectTo, 0);
					neuron.getOutbound().add(connection);
					// The connection ends at connectTo, so it belongs to connectTo's inbound list.
					connectTo.getInbound().add(connection);
				}
			}
			connections = lastFedNeurons;
		}

		Graph<DrawnNeuron, DrawnConnection> result = toGraph(neurons);

		// Draw context links. The neuron list is in layer order, so getLayerIndex()
		// is used as an offset into it.
		int[] contextTargetSize = flat.getContextTargetSize();
		for (int currentLayer = 0; currentLayer < layerCount; currentLayer++) {
			int count = contextTargetSize[currentLayer];
			if (count <= 0) {
				continue;
			}
			int offset = flat.getContextTargetOffset()[currentLayer];
			int source = flat.getLayerIndex()[currentLayer];
			for (int i = 0; i < count; i++) {
				DrawnNeuron n1 = neurons.get(source + i);
				DrawnNeuron n2 = neurons.get(offset + i);
				DrawnConnection connection = new DrawnConnection(n1, n2, 0);
				connection.setContext(true);
				result.addEdge(connection, connection.getFrom(), connection.getTo(), EdgeType.DIRECTED);
			}
		}

		return result;
	}

	/** Creates a directed graph containing every neuron and each of its outbound connections. */
	private static Graph<DrawnNeuron, DrawnConnection> toGraph(List<DrawnNeuron> neurons) {
		Graph<DrawnNeuron, DrawnConnection> result = new SparseMultigraph<DrawnNeuron, DrawnConnection>();
		for (DrawnNeuron neuron : neurons) {
			result.addVertex(neuron);
			for (DrawnConnection connection : neuron.getOutbound()) {
				result.addEdge(connection, connection.getFrom(), connection.getTo(), EdgeType.DIRECTED);
			}
		}
		return result;
	}

	/**
	 * Returns the tab title.
	 *
	 * @return "Structure: " followed by the encog object's name, or just
	 *         "Structure" when no encog object is set
	 */
	@Override
	public String getName() {
		// The constructor passes null to the base tab, so the object may be absent
		// unless it is assigned later — TODO confirm against EncogCommonTab.
		if (this.getEncogObject() == null) {
			return "Structure";
		}
		return "Structure: " + this.getEncogObject().getName();
	}
}