Java tutorial
/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dataset; import ml.shifu.shifu.core.dtrain.nn.ActivationLeakyReLU; import ml.shifu.shifu.core.dtrain.nn.ActivationPTANH; import ml.shifu.shifu.core.dtrain.nn.ActivationReLU; import ml.shifu.shifu.core.dtrain.nn.ActivationSwish; import org.apache.commons.lang.StringUtils; import org.encog.engine.network.activation.ActivationFunction; import org.encog.neural.flat.FlatNetwork; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.PersistBasicNetwork; import org.encog.persist.*; import org.encog.util.csv.CSVFormat; import java.io.*; import java.util.*; import java.util.Map.Entry; /** * Support {@link BasicFloatNetwork} serialization and de-serialization. This is copied from {@link PersistBasicNetwork} * and only {@link #getPersistClassString()} is changed to 'BasicFloatNetwork'. * * <p> * Because of all final methods in {@link PersistBasicNetwork}, we have to copy code while not take extension. */ public class PersistBasicFloatNetwork implements EncogPersistor { /** * {@inheritDoc} */ @Override public final int getFileVersion() { return 1; } /** * {@inheritDoc} */ @Override public final String getPersistClassString() { return "BasicFloatNetwork"; } /** * {@inheritDoc} */ @Override public final Object read(final InputStream is) { final BasicFloatNetwork result = new BasicFloatNetwork(); final FlatNetwork flat = new FlatNetwork(); final EncogReadHelper in = new EncogReadHelper(is); EncogFileSection section; while ((section = in.readNextSection()) != null) { if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("PARAMS")) { final Map<String, String> params = section.parseParams(); result.getProperties().putAll(params); } if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("NETWORK")) { final Map<String, String> params = section.parseParams(); flat.setBeginTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_BEGIN_TRAINING)); flat.setConnectionLimit(EncogFileSection.parseDouble(params, BasicNetwork.TAG_CONNECTION_LIMIT)); flat.setContextTargetOffset( EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_OFFSET)); flat.setContextTargetSize( EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_SIZE)); flat.setEndTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_END_TRAINING)); flat.setHasContext(EncogFileSection.parseBoolean(params, BasicNetwork.TAG_HAS_CONTEXT)); flat.setInputCount(EncogFileSection.parseInt(params, PersistConst.INPUT_COUNT)); flat.setLayerCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_COUNTS)); flat.setLayerFeedCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_FEED_COUNTS)); flat.setLayerContextCount( EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_CONTEXT_COUNT)); flat.setLayerIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_INDEX)); flat.setLayerOutput(EncogFileSection.parseDoubleArray(params, PersistConst.OUTPUT)); flat.setLayerSums(new double[flat.getLayerOutput().length]); flat.setOutputCount(EncogFileSection.parseInt(params, PersistConst.OUTPUT_COUNT)); flat.setWeightIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_WEIGHT_INDEX)); flat.setWeights(EncogFileSection.parseDoubleArray(params, PersistConst.WEIGHTS)); flat.setBiasActivation(EncogFileSection.parseDoubleArray(params, BasicNetwork.TAG_BIAS_ACTIVATION)); } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("ACTIVATION")) { int index = 0; flat.setActivationFunctions(new ActivationFunction[flat.getLayerCounts().length]); for (final String line : section.getLines()) { ActivationFunction af = null; final List<String> cols = EncogFileSection.splitColumns(line); String name = "org.encog.engine.network.activation." + cols.get(0); if (cols.get(0).equals("ActivationReLU")) { name = "ml.shifu.shifu.core.dtrain.nn.ActivationReLU"; } else if (cols.get(0).equals("ActivationLeakyReLU")) { name = "ml.shifu.shifu.core.dtrain.nn.ActivationLeakyReLU"; } else if (cols.get(0).equals("ActivationSwish")) { name = "ml.shifu.shifu.core.dtrain.nn.ActivationSwish"; } else if (cols.get(0).equals("ActivationPTANH")) { name = "ml.shifu.shifu.core.dtrain.nn.ActivationPTANH"; } try { final Class<?> clazz = Class.forName(name); af = (ActivationFunction) clazz.newInstance(); } catch (final ClassNotFoundException e) { throw new PersistError(e); } catch (final InstantiationException e) { throw new PersistError(e); } catch (final IllegalAccessException e) { throw new PersistError(e); } for (int i = 0; i < af.getParamNames().length; i++) { af.setParam(i, CSVFormat.EG_FORMAT.parse(cols.get(i + 1))); } flat.getActivationFunctions()[index++] = af; } } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("SUBSET")) { final Map<String, String> params = section.parseParams(); String subsetStr = params.get("SUBSETFEATURES"); if (StringUtils.isBlank(subsetStr)) { result.setFeatureSet(null); } else { String[] splits = subsetStr.split(","); Set<Integer> subFeatures = new HashSet<Integer>(); for (String split : splits) { int featureIndex = Integer.parseInt(split); subFeatures.add(featureIndex); } result.setFeatureSet(subFeatures); } } } result.getStructure().setFlat(flat); return result; } /** * {@inheritDoc} */ @Override public final void save(final OutputStream os, final Object obj) { final EncogWriteHelper out = new EncogWriteHelper(os); final BasicFloatNetwork net = (BasicFloatNetwork) obj; final FlatNetwork flat = net.getStructure().getFlat(); out.addSection("BASIC"); out.addSubSection("PARAMS"); out.addProperties(net.getProperties()); out.addSubSection("NETWORK"); out.writeProperty(BasicNetwork.TAG_BEGIN_TRAINING, flat.getBeginTraining()); out.writeProperty(BasicNetwork.TAG_CONNECTION_LIMIT, flat.getConnectionLimit()); out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_OFFSET, flat.getContextTargetOffset()); out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_SIZE, flat.getContextTargetSize()); out.writeProperty(BasicNetwork.TAG_END_TRAINING, flat.getEndTraining()); out.writeProperty(BasicNetwork.TAG_HAS_CONTEXT, flat.getHasContext()); out.writeProperty(PersistConst.INPUT_COUNT, flat.getInputCount()); out.writeProperty(BasicNetwork.TAG_LAYER_COUNTS, flat.getLayerCounts()); out.writeProperty(BasicNetwork.TAG_LAYER_FEED_COUNTS, flat.getLayerFeedCounts()); out.writeProperty(BasicNetwork.TAG_LAYER_CONTEXT_COUNT, flat.getLayerContextCount()); out.writeProperty(BasicNetwork.TAG_LAYER_INDEX, flat.getLayerIndex()); out.writeProperty(PersistConst.OUTPUT, flat.getLayerOutput()); out.writeProperty(PersistConst.OUTPUT_COUNT, flat.getOutputCount()); out.writeProperty(BasicNetwork.TAG_WEIGHT_INDEX, flat.getWeightIndex()); out.writeProperty(PersistConst.WEIGHTS, flat.getWeights()); out.writeProperty(BasicNetwork.TAG_BIAS_ACTIVATION, flat.getBiasActivation()); out.addSubSection("ACTIVATION"); for (final ActivationFunction af : flat.getActivationFunctions()) { out.addColumn(af.getClass().getSimpleName()); for (int i = 0; i < af.getParams().length; i++) { out.addColumn(af.getParams()[i]); } out.writeLine(); } out.addSubSection("SUBSET"); Set<Integer> featureList = net.getFeatureSet(); if (featureList == null || featureList.size() == 0) { out.writeProperty("SUBSETFEATURES", ""); } else { String subFeaturesStr = StringUtils.join(featureList, ","); out.writeProperty("SUBSETFEATURES", subFeaturesStr); } out.flush(); } public BasicFloatNetwork readNetwork(final DataInput in) throws IOException { final BasicFloatNetwork result = new BasicFloatNetwork(); final FlatNetwork flat = new FlatNetwork(); // read properties Map<String, String> properties = new HashMap<String, String>(); int size = in.readInt(); for (int i = 0; i < size; i++) { properties.put(ml.shifu.shifu.core.dtrain.StringUtils.readString(in), ml.shifu.shifu.core.dtrain.StringUtils.readString(in)); } result.getProperties().putAll(properties); // read fields flat.setBeginTraining(in.readInt()); flat.setConnectionLimit(in.readDouble()); flat.setContextTargetOffset(readIntArray(in)); flat.setContextTargetSize(readIntArray(in)); flat.setEndTraining(in.readInt()); flat.setHasContext(in.readBoolean()); flat.setInputCount(in.readInt()); flat.setLayerCounts(readIntArray(in)); flat.setLayerFeedCounts(readIntArray(in)); flat.setLayerContextCount(readIntArray(in)); flat.setLayerIndex(readIntArray(in)); flat.setLayerOutput(readDoubleArray(in)); flat.setOutputCount(in.readInt()); flat.setLayerSums(new double[flat.getLayerOutput().length]); flat.setWeightIndex(readIntArray(in)); flat.setWeights(readDoubleArray(in)); flat.setBiasActivation(readDoubleArray(in)); // read activations flat.setActivationFunctions(new ActivationFunction[flat.getLayerCounts().length]); int acSize = in.readInt(); for (int i = 0; i < acSize; i++) { String name = ml.shifu.shifu.core.dtrain.StringUtils.readString(in); if (name.equals("ActivationReLU")) { name = ActivationReLU.class.getName(); } else if (name.equals("ActivationLeakyReLU")) { name = ActivationLeakyReLU.class.getName(); } else if (name.equals("ActivationSwish")) { name = ActivationSwish.class.getName(); } else if (name.equals("ActivationPTANH")) { name = ActivationPTANH.class.getName(); } else { name = "org.encog.engine.network.activation." + name; } ActivationFunction af = null; try { final Class<?> clazz = Class.forName(name); af = (ActivationFunction) clazz.newInstance(); } catch (final ClassNotFoundException e) { throw new PersistError(e); } catch (final InstantiationException e) { throw new PersistError(e); } catch (final IllegalAccessException e) { throw new PersistError(e); } double[] params = readDoubleArray(in); for (int j = 0; j < params.length; j++) { af.setParam(j, params[j]); } flat.getActivationFunctions()[i] = af; } // read subset int subsetSize = in.readInt(); Set<Integer> featureList = new HashSet<Integer>(); for (int i = 0; i < subsetSize; i++) { featureList.add(in.readInt()); } result.setFeatureSet(featureList); result.getStructure().setFlat(flat); return result; } public void saveNetwork(DataOutput out, final BasicFloatNetwork network) throws IOException { final FlatNetwork flat = network.getStructure().getFlat(); // write general properties Map<String, String> properties = network.getProperties(); if (properties == null) { out.writeInt(0); } else { out.writeInt(properties.size()); for (Entry<String, String> entry : properties.entrySet()) { ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getKey()); ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getValue()); } } // write fields values in BasicFloatNetwork out.writeInt(flat.getBeginTraining()); out.writeDouble(flat.getConnectionLimit()); writeIntArray(out, flat.getContextTargetOffset()); writeIntArray(out, flat.getContextTargetSize()); out.writeInt(flat.getEndTraining()); out.writeBoolean(flat.getHasContext()); out.writeInt(flat.getInputCount()); writeIntArray(out, flat.getLayerCounts()); writeIntArray(out, flat.getLayerFeedCounts()); writeIntArray(out, flat.getLayerContextCount()); writeIntArray(out, flat.getLayerIndex()); writeDoubleArray(out, flat.getLayerOutput()); out.writeInt(flat.getOutputCount()); writeIntArray(out, flat.getWeightIndex()); writeDoubleArray(out, flat.getWeights()); writeDoubleArray(out, flat.getBiasActivation()); // write activation list out.writeInt(flat.getActivationFunctions().length); for (final ActivationFunction af : flat.getActivationFunctions()) { ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, af.getClass().getSimpleName()); writeDoubleArray(out, af.getParams()); } // write sub sets Set<Integer> featureList = network.getFeatureSet(); if (featureList == null || featureList.size() == 0) { out.writeInt(0); } else { out.writeInt(featureList.size()); for (Integer integer : featureList) { out.writeInt(integer); } } } private int[] readIntArray(DataInput in) throws IOException { int size = in.readInt(); int[] array = new int[size]; for (int i = 0; i < size; i++) { array[i] = in.readInt(); } return array; } private double[] readDoubleArray(DataInput in) throws IOException { int size = in.readInt(); double[] array = new double[size]; for (int i = 0; i < size; i++) { array[i] = in.readDouble(); } return array; } private void writeIntArray(DataOutput out, int[] array) throws IOException { if (array == null) { out.writeInt(0); } else { out.writeInt(array.length); for (int i : array) { out.writeInt(i); } } } private void writeDoubleArray(DataOutput out, double[] array) throws IOException { if (array == null) { out.writeInt(0); } else { out.writeInt(array.length); for (double d : array) { out.writeDouble(d); } } } }