// Java tutorial
/*
 * Copyright (C) 2015 Seoul National University
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.snu.dolphin.dnn;

import com.google.protobuf.TextFormat;
import edu.snu.dolphin.bsp.examples.ml.parameters.MaxIterations;
import edu.snu.dolphin.dnn.conf.FullyConnectedLayerConfigurationBuilder;
import edu.snu.dolphin.dnn.conf.NeuralNetworkConfigurationBuilder;
import edu.snu.dolphin.dnn.layerparam.provider.GroupCommParameterProvider;
import edu.snu.dolphin.dnn.layerparam.provider.LocalNeuralNetParameterProvider;
import edu.snu.dolphin.dnn.layerparam.provider.ParameterProvider;
import edu.snu.dolphin.dnn.proto.NeuralNetworkProtos.*;
import edu.snu.dolphin.bsp.parameters.OnLocal;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Name;
import org.apache.reef.tang.annotations.NamedParameter;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.formats.CommandLine;
import org.apache.reef.tang.formats.ConfigurationSerializer;

import javax.inject.Inject;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.List;

/**
 * Class that manages command line parameters specific to the neural network for driver.
*/ public final class NeuralNetworkDriverParameters { private final String serializedNeuralNetworkConfiguration; private final String delimiter; private final int maxIterations; private final boolean groupComm; private final String inputShape; @NamedParameter(doc = "neural network configuration file path", short_name = "conf") public static class ConfigurationPath implements Name<String> { } @NamedParameter(doc = "delimiter that is used in input file", short_name = "delim", default_value = ",") public static class Delimiter implements Name<String> { } @NamedParameter(doc = "the shape of input data") public static class InputShape implements Name<String> { } /** * Delimiter that is used for distinguishing dimensions of input shape. */ private static final String SHAPE_DELIMITER = ","; /** * Converts a list of integer for an input shape to a string. * @param dimensionList a list of integers for an input shape. * @return a string for an input shape. */ public static String inputShapeToString(final List<Integer> dimensionList) { return StringUtils.join(dimensionList, SHAPE_DELIMITER); } /** * Converts a string for an input shape to an array of integers. * @param inputShapeString a string for an input shape. * @return an array of integers for an input shape. 
*/ public static int[] inputShapeFromString(final String inputShapeString) { final String[] inputShapeStrings = inputShapeString.split(SHAPE_DELIMITER); final int[] inputShape = new int[inputShapeStrings.length]; for (int i = 0; i < inputShapeStrings.length; ++i) { inputShape[i] = Integer.parseInt(inputShapeStrings[i]); } return inputShape; } @Inject private NeuralNetworkDriverParameters(final ConfigurationSerializer configurationSerializer, @Parameter(ConfigurationPath.class) final String configurationPath, @Parameter(Delimiter.class) final String delimiter, @Parameter(MaxIterations.class) final int maxIterations, @Parameter(OnLocal.class) final boolean onLocal) throws IOException { final NeuralNetworkConfiguration neuralNetConf = loadNeuralNetworkConfiguration(configurationPath, onLocal); // the method is being called twice: here and in `buildNeuralNetworkConfiguration` // this could be made to once by refactoring the code this.groupComm = neuralNetConf.getParameterProvider().getType().equals("groupcomm"); this.serializedNeuralNetworkConfiguration = configurationSerializer .toString(buildNeuralNetworkConfiguration(neuralNetConf)); this.delimiter = delimiter; this.maxIterations = maxIterations; // convert to string because Tang configuration serializer does not support List serialization. this.inputShape = inputShapeToString(neuralNetConf.getInputShape().getDimList()); } /** * @param parameterProvider a parameter provider string. * @return the parameter provider class that the given string indicates. */ private static Class<? 
extends ParameterProvider> getParameterProviderClass(final String parameterProvider) { switch (parameterProvider.toLowerCase()) { case "local": return LocalNeuralNetParameterProvider.class; case "groupcomm": return GroupCommParameterProvider.class; default: throw new IllegalArgumentException("Illegal parameter provider: " + parameterProvider); } } /** * Creates the layer configuration from the given protocol buffer layer configuration message. * @param layerConf the protocol buffer layer configuration message. * @return the layer configuration built from the protocol buffer layer configuration message. */ private static Configuration createLayerConfiguration(final LayerConfiguration layerConf) { switch (layerConf.getType().toLowerCase()) { case "fullyconnected": return FullyConnectedLayerConfigurationBuilder.newConfigurationBuilder() .fromProtoConfiguration(layerConf).build(); default: throw new IllegalArgumentException("Illegal layer type: " + layerConf.getType()); } } /** * Loads the protocol buffer text formatted neural network configuration. * <p/> * Loads the file from the local filesystem or HDFS depending on {@code onLocal}. * @param path the path for the neural network configuration. * @param onLocal the flag for the local runtime environment. * @return the neural network configuration protocol buffer message. * @throws IOException */ private static NeuralNetworkConfiguration loadNeuralNetworkConfiguration(final String path, final boolean onLocal) throws IOException { final NeuralNetworkConfiguration.Builder neuralNetProtoBuilder = NeuralNetworkConfiguration.newBuilder(); // Parses neural network builder protobuf message from the prototxt file. // Reads from the local filesystem. if (onLocal) { TextFormat.merge(new FileReader(path), neuralNetProtoBuilder); // Reads from HDFS. 
} else { final FileSystem fs = FileSystem.get(new JobConf()); TextFormat.merge(new InputStreamReader( Path(path))), neuralNetProtoBuilder); } return; } /** * Parses the protobuf message and builds neural network configuration. * @param neuralNetConf neural network configuration protobuf message. * @return the neural network configuration. */ private static Configuration buildNeuralNetworkConfiguration(final NeuralNetworkConfiguration neuralNetConf) { final NeuralNetworkConfigurationBuilder neuralNetConfBuilder = NeuralNetworkConfigurationBuilder .newConfigurationBuilder(); neuralNetConfBuilder.setBatchSize(neuralNetConf.getBatchSize()).setStepsize(neuralNetConf.getStepsize()) .setParameterProviderClass( getParameterProviderClass(neuralNetConf.getParameterProvider().getType())); // Adds the configuration of each layer. for (final LayerConfiguration layerConf : neuralNetConf.getLayerList()) { neuralNetConfBuilder.addLayerConfiguration(createLayerConfiguration(layerConf)); } return; } /** * Registers command line parameters for driver. * @param cl */ public static void registerShortNameOfClass(final CommandLine cl) { cl.registerShortNameOfClass(ConfigurationPath.class); cl.registerShortNameOfClass(Delimiter.class); cl.registerShortNameOfClass(MaxIterations.class); } /** * @return the configuration for driver. */ public Configuration getDriverConfiguration() { return Tang.Factory.getTang().newConfigurationBuilder() .bindNamedParameter(NeuralNetworkESParameters.SerializedNeuralNetConf.class, serializedNeuralNetworkConfiguration) .bindNamedParameter(Delimiter.class, delimiter) .bindNamedParameter(MaxIterations.class, String.valueOf(maxIterations)) .bindNamedParameter(InputShape.class, inputShape).build(); } /** * @return {@code true} if this Neural Network application uses REEF Group Communication, {@code false} otherwise */ public boolean isGroupComm() { return this.groupComm; } }