Java tutorial
/* * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.plot; import java.io.*; import java.util.*; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.factory.*; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.nd4j.linalg.api.ndarray.INDArray; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.io.ClassPathResource; /** * Credit to : * http://yosinski.com/media/papers/Yosinski2012VisuallyDebuggingRestrictedBoltzmannMachine.pdf * * for visualizations * @author Adam Gibson * */ public class NeuralNetPlotter implements Serializable { private static ClassPathResource script = new ClassPathResource("scripts" + File.separator + "plot.py"); private static final Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class); private static String ID_FOR_SESSION = UUID.randomUUID().toString(); private static String localPath = System.getProperty("java.io.tmpdir") + File.separator; private static String dataFilePath = localPath + "data" + File.separator; private static String graphPath = localPath + "graphs" + File.separator; private static String graphFilePath = graphPath + ID_FOR_SESSION + File.separator; private static String localPlotPath = loadIntoTmp(); private static String layerGraphFilePath = graphFilePath; public String getLayerGraphFilePath() { return layerGraphFilePath; } public void setLayerGraphFilePath(String newPath) { this.layerGraphFilePath = newPath; } public static void printDataFilePath() { log.info("Data stored at " + dataFilePath); } public static void printGraphFilePath() { log.warn("Graphs stored at " + graphFilePath + ". " + "Warning: You must manually delete the folder when you are done."); } private static String loadIntoTmp() { setupDirectory(dataFilePath); setupDirectory(graphFilePath); printDataFilePath(); printGraphFilePath(); File plotPath = new File(graphPath, "plot.py"); plotPath.deleteOnExit(); if (!plotPath.exists()) { try { List<String> lines = IOUtils.readLines(script.getInputStream()); FileUtils.writeLines(plotPath, lines); } catch (IOException e) { throw new IllegalStateException("Unable to load python file"); } } return plotPath.getAbsolutePath(); } protected static void setupDirectory(String path) { File newPath = new File(path); if (!newPath.isDirectory()) newPath.mkdir(); } public void updateGraphDirectory(Layer layer) { String layerType = layer.getClass().toString(); String[] layerPath = layerType.split("\\."); String layerName = Integer.toString(layer.getIndex()) + layerPath[layerPath.length - 1]; String newPath = graphFilePath + File.separator + layerName + File.separator; if (!new File(newPath).exists()) { setupDirectory(newPath); setLayerGraphFilePath(newPath); } } protected String writeMatrix(INDArray matrix) { try { String tmpFilePath = dataFilePath + UUID.randomUUID().toString(); File write = new File(tmpFilePath); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(write, true)); write.deleteOnExit(); for (int i = 0; i < matrix.rows(); i++) { INDArray row = matrix.getRow(i); StringBuilder sb = new StringBuilder(); for (int j = 0; j < row.length(); j++) { sb.append(String.format("%.10f", row.getDouble(j))); if (j < row.length() - 1) sb.append(","); } sb.append("\n"); String line = sb.toString(); bos.write(line.getBytes()); bos.flush(); } bos.close(); return tmpFilePath; } catch (IOException e) { throw new RuntimeException(e); } } public String writeArray(ArrayList data) { try { String tmpFilePath = dataFilePath + UUID.randomUUID().toString(); File write = new File(tmpFilePath); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(write, true)); write.deleteOnExit(); StringBuilder sb = new StringBuilder(); for (Object value : data) { sb.append(String.format("%.10f", (Double) value)); sb.append(","); } String line = sb.toString(); line = line.substring(0, line.length() - 1); bos.write(line.getBytes()); bos.flush(); bos.close(); return tmpFilePath; } catch (IOException e) { throw new RuntimeException(e); } } /** * Calls out to python for rendering charts * @param action the action to take * @param dataPath the path to the data * @param saveFilePath the saved file path for output of graphs */ public void renderGraph(String action, String dataPath, String saveFilePath) { try { log.info("Rendering " + action + " graphs for data analysis..."); Process is = Runtime.getRuntime() .exec("python " + localPlotPath + " " + action + " " + dataPath + " " + saveFilePath); log.info("Std out " + IOUtils.readLines(is.getInputStream()).toString()); log.error("Std error " + IOUtils.readLines(is.getErrorStream()).toString()); } catch (IOException e) { log.warn("Image closed"); throw new RuntimeException(e); } } /** * Calls out to python for rendering charts * @param action the action to take * @param dataPath the path to the data * @param saveFilePath the saved file path for output of graphs * @param feature_width width of feature * @param feature_height height of feature */ public void renderGraph(String action, String dataPath, String saveFilePath, int feature_width, int feature_height) { try { log.info("Rendering " + action + " graphs for data analysis..."); Process is = Runtime.getRuntime().exec("python " + localPlotPath + " " + action + " " + dataPath + " " + saveFilePath + " " + feature_width + " " + feature_height); log.info("Std out " + IOUtils.readLines(is.getInputStream()).toString()); log.error("Std error " + IOUtils.readLines(is.getErrorStream()).toString()); } catch (IOException e) { log.warn("Image closed"); throw new RuntimeException(e); } } /** * graphPlotType sets up data to pass to scripts that render graphs * @param plotType sets which plot to call whether "multi" for multiple matrices histograms, * "hist" for a histogram of one matrix, or "scatter" for scatter plot * @param titles the titles of the plots * @param matrices the matrices to plot */ public void graphPlotType(String plotType, List<String> titles, INDArray[] matrices, String saveFilePath) { String[] path = new String[matrices.length * 2]; if (titles.size() != matrices.length) throw new IllegalArgumentException("Titles and matrix lengths must be equal"); for (int i = 0; i < path.length - 1; i += 2) { path[i] = writeMatrix(matrices[i / 2].ravel()); path[i + 1] = titles.get(i / 2); } String dataPath = StringUtils.join(path, ","); renderGraph(plotType, dataPath, saveFilePath); } /** * plotWeightHistograms graphs values of vBias, W, and hBias on aggregate and * most recent mini-batch updates (-gradient) * @param network the trained neural net model * @param gradient latest updates to weights and biases */ public void plotWeightHistograms(Layer network, Gradient gradient) { Set<String> vars = new TreeSet<>(gradient.gradientForVariable().keySet()); List<String> titles = new ArrayList<>(vars); for (String s : vars) { titles.add(s + "-gradient"); } INDArray[] variablesAndGradients = new INDArray[network.conf().variables().size() * 2]; int count = 0; for (int i = 0; i < network.conf().variables().size(); i++) { String variable = network.conf().variables().get(i); variablesAndGradients[count++] = network.getParam(variable); } for (int i = 0; i < network.conf().variables().size(); i++) { String variable = network.conf().variables().get(i); variablesAndGradients[count++] = gradient.getGradientFor(variable); } graphPlotType("histogram", titles, variablesAndGradients, layerGraphFilePath + "weightHistograms.png"); } public void plotWeightHistograms(Layer network) { plotWeightHistograms(network, network.gradient()); } /** * plotActivations show how hidden neurons are used, how often on vs. off and correlation * @param layer the trained neural net layer **/ public void plotActivations(Layer layer) { if (layer.input() == null) throw new IllegalStateException("Unable to plot; missing input"); // TODO simplify hbiasMean as a sample of the data vs all examples (cut by % if over 40 examples & 100 neurons) INDArray hbiasMean = layer.activationMean(); String dataPath = writeMatrix(hbiasMean); renderGraph("activations", dataPath, layerGraphFilePath + "activationPlot.png"); } /** * renderFilter plot learned filter for each hidden neuron * @param layer the trained neural net layer in the model **/ public void renderFilter(Layer layer, int patchesPerRow) { INDArray weight = layer.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray w = weight.dup(); FilterRenderer render = new FilterRenderer(); try { if (w.shape().length > 2) { INDArray render2 = w.transpose(); render.renderFilters(render2, layerGraphFilePath + "renderFilter.png", w.columns(), w.rows(), w.slices()); } else { render.renderFilters(w, layerGraphFilePath + "renderFilter.png", (int) Math.sqrt(w.rows()), (int) Math.sqrt(w.columns()), patchesPerRow); } // Alternative python approach - work in progress // String dataPath = writeMatrix(w); // renderGraph("filter", dataPath, layerGraphFilePath + "renderFilter.png"); } catch (Exception e) { log.error("Unable to plot filter, continuing...", e); e.printStackTrace(); } } /** * plotNetworkGradient used for debugging RBM gradients with different data visualizations * * @param layer the neural net layer * @param gradient latest updates to weights and biases **/ public void plotNetworkGradient(Layer layer, Gradient gradient) { plotWeightHistograms(layer, gradient); plotActivations(layer); } public void plotNetworkGradient(Layer layer, INDArray gradient) { graphPlotType("histogram", Arrays.asList("W", "w-gradient"), new INDArray[] { layer.getParam(DefaultParamInitializer.WEIGHT_KEY), gradient }, layerGraphFilePath + "weightHistograms.png"); plotActivations(layer); } }