Java tutorial
/******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.examples; import java.io.DataInputStream; import java.io.IOException; import java.net.URI; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Stack; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.io.IOUtils; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.writable.Writable; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.examples.conf.Constants; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author: Ousmane A. Dia */ public class StackSequenceRecordReader extends FileRecordReader { private Stack<List<Path>> sequenceStack = new Stack<List<Path>>(); private FileSystem fs; private List<Double> featureMasks, labelMasks; private static final long serialVersionUID = 2949161683514893388L; private CSVSequenceRecordReader csvRecordReader = new CSVSequenceRecordReader(); private List<String> timeSteps = new ArrayList<String>(); private int timeIndex = Constants.TIME_INDEX; private static final Logger LOG = LoggerFactory.getLogger(StackSequenceRecordReader.class); public StackSequenceRecordReader(FileSystem fs, int startSeq, int endSeq) { if (startSeq > endSeq) throw new RuntimeException("Start is greater than End."); this.fs = fs; for (int i = startSeq; i <= endSeq; i++) { timeSteps.add(String.valueOf(i)); } } public void newRecord(Stack<Path> pathStack) { newRecord(pathStack, true); } public void newRecord(Stack<Path> pathStack, boolean reverse) { List<Path> paths = new ArrayList<Path>(); if (pathStack.isEmpty()) return; while (!pathStack.isEmpty()) { paths.add(pathStack.pop()); } if (reverse) Collections.reverse(paths); this.sequenceStack.push(paths); } public boolean hasNext() { return !sequenceStack.isEmpty() && !sequenceStack.peek().isEmpty(); } @Override public List<Writable> next() { featureMasks = new ArrayList<Double>(Collections.nCopies(timeSteps.size(), 0.0)); labelMasks = new ArrayList<Double>(Collections.nCopies(timeSteps.size(), 0.0)); List<Writable> writable = new ArrayList<Writable>(); List<Path> paths = sequenceStack.pop(); DataInputStream in = null; for (Path p : paths) { try { in = fs.open(p); URI uri = p.toUri(); String pathParts[] = uri.toString().split("_"); String currStep = pathParts == null || pathParts.length < 3 ? "" : pathParts[2]; int index = timeSteps.indexOf(currStep); index = index < 0 ? timeSteps.size() - 1 : index; List<String> seq = new ArrayList<String>(); List<List<Writable>> steps = csvRecordReader.sequenceRecord(uri, in); if (!CollectionUtils.isEmpty(steps)) { for (List<Writable> instance : steps) { if (timeIndex >= 1 && timeIndex <= instance.size()) { seq.add(instance.get(timeIndex - 1).toString()); if (timeIndex - 1 > 0) writable.addAll(instance.subList(0, timeIndex - 1)); writable.addAll(instance.subList(timeIndex, instance.size())); } else { writable.addAll(instance); } } if (!seq.isEmpty()) { String maxIndex = seq.get(0); for (int i = 0; i < seq.size(); i++) { if (Integer.parseInt(seq.get(i)) > Integer.parseInt(maxIndex)) maxIndex = seq.get(i); featureMasks.set(timeSteps.indexOf(seq.get(i)), 1.0); } labelMasks.set(timeSteps.indexOf(maxIndex), timeSteps.get(timeSteps.size() - 1) == maxIndex || paths.size() == 1 ? 1.0 : 0.0); } else { featureMasks.set(index, 1.0); labelMasks.set(index, timeSteps.get(timeSteps.size() - 1) == currStep || paths.size() == 1 ? 1.0 : 0.0); } } } catch (IOException e) { LOG.error(e.getLocalizedMessage(), e); } finally { if (in != null) { IOUtils.closeQuietly(in); } } } return writable; } private List<List<Writable>> records(List<List<Double>> lMasks, List<List<Double>> rMasks) { List<List<Writable>> writables = new ArrayList<List<Writable>>(); while (hasNext()) { List<Writable> w = next(); if (w.isEmpty()) continue; writables.add(w); lMasks.add(featureMasks); rMasks.add(labelMasks); } return writables; } private MultiDataSet zeroMDS(int numFeatures, int numLabels) { return new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(new double[numFeatures]) }, new INDArray[] { Nd4j.create(new double[numLabels]) }, new INDArray[] { Nd4j.create(new double[1]) }, new INDArray[] { Nd4j.create(new double[1]) }); } public MultiDataSet toMultiDataSet(int numFeatures, int numLabels) { List<List<Double>> featureMasks = new ArrayList<List<Double>>(); List<List<Double>> labelMasks = new ArrayList<List<Double>>(); List<List<Writable>> records = records(featureMasks, labelMasks); List<MultiDataSet> list = new ArrayList<MultiDataSet>(records.size()); int currHotIndex = 0; for (int i = 0; i < records.size(); i++) { Pair<INDArray[], INDArray[]> pair = toPair(records.get(i), numFeatures, numLabels); for (int j = 0; j < timeSteps.size(); j++) { if (featureMasks.get(i).get(j) != 1.0) { list.add(zeroMDS(numFeatures, numLabels)); } else { INDArray feature = pair.getFirst()[currHotIndex]; INDArray label = pair.getSecond()[currHotIndex++]; INDArray fMask = Nd4j.create(new double[] { featureMasks.get(i).get(j) }); INDArray lMask = Nd4j.create(new double[] { labelMasks.get(i).get(j) }); list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { feature }, new INDArray[] { label }, new INDArray[] { fMask }, new INDArray[] { lMask })); } } currHotIndex = 0; } return list.isEmpty() ? null : org.nd4j.linalg.dataset.MultiDataSet.merge(list); } private Pair<INDArray[], INDArray[]> toPair(Collection<Writable> record, int numFeatures, int numLabels) { Iterator<Writable> writables = record.iterator(); Writable firstWritable = writables.next(); INDArray vector1 = Nd4j.create(numFeatures); INDArray vector2 = Nd4j.create(numLabels); INDArray[] array1 = new INDArray[record.size() / (numFeatures + numLabels)]; INDArray[] array2 = new INDArray[record.size() / (numFeatures + numLabels)]; vector1.putScalar(0, firstWritable.toDouble()); int count1 = 1, count2 = 0, i = 0; while (writables.hasNext()) { Writable w = writables.next(); if (count1 < numFeatures) { double val = Double.isNaN(w.toDouble()) ? 0.0 : w.toDouble(); vector1.putScalar(count1++, val); } else { if (count2 < numLabels) { double val = Double.isNaN(w.toDouble()) ? 0.0 : w.toDouble(); vector2.putScalar(count2++, val); } } if (count1 == numFeatures && count2 == numLabels) { array1[i] = vector1; array2[i++] = vector2; count1 = 0; count2 = 0; } } return new Pair<INDArray[], INDArray[]>(array1, array2); } }