org.deeplearning4j.datasets.canova.RecordReaderDataSetiteratorTest.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.datasets.canova.RecordReaderDataSetiteratorTest.java

Source

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.datasets.canova;

import org.apache.commons.io.FilenameUtils;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.records.reader.impl.CSVRecordReader;
import org.canova.api.records.reader.impl.CSVSequenceRecordReader;
import org.canova.api.split.FileSplit;
import org.canova.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;

import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import static org.junit.Assert.*;

/**
 * Created by agibsonccc on 3/6/15.
 */
public class RecordReaderDataSetiteratorTest {

    /**
     * Reads the classpath resource {@code csv-example.csv} through a {@link CSVRecordReader}
     * and verifies that the first mini-batch requested with a batch size of 34 contains
     * exactly 34 examples.
     */
    @Test
    public void testRecordReader() throws Exception {
        final int batchSize = 34;

        RecordReader csvReader = new CSVRecordReader();
        csvReader.initialize(new FileSplit(new ClassPathResource("csv-example.csv").getFile()));

        DataSetIterator batches = new RecordReaderDataSetIterator(csvReader, batchSize);
        DataSet firstBatch = batches.next();

        assertEquals(batchSize, firstBatch.numExamples());
    }

    /**
     * Iterates three feature/label sequence file pairs in classification mode (4 possible
     * label values, one example per mini-batch) and compares array shapes, feature values and
     * one-hot labels against hard-coded expectations.
     */
    @Test
    public void testSequenceRecordReader() throws Exception {
        //Derive the "%d" patterns from the file names only: replaceAll("0", "%d") on the absolute
        //path would also rewrite any '0' in a parent directory name and break the numbered split.
        File featuresDir = new ClassPathResource("csvsequence_0.txt").getFile().getParentFile();
        String featuresPath = new File(featuresDir, "csvsequence_%d.txt").getAbsolutePath();
        File labelsDir = new ClassPathResource("csvsequencelabels_0.txt").getFile().getParentFile();
        String labelsPath = new File(labelsDir, "csvsequencelabels_%d.txt").getAbsolutePath();

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader,
                labelReader, 1, 4, false);

        assertEquals(3, iter.inputColumns());
        assertEquals(4, iter.totalOutcomes());

        List<DataSet> dsList = new ArrayList<>();
        while (iter.hasNext()) {
            dsList.add(iter.next());
        }

        assertEquals(3, dsList.size()); //3 files
        for (int i = 0; i < 3; i++) {
            DataSet ds = dsList.get(i);
            INDArray features = ds.getFeatureMatrix();
            INDArray labels = ds.getLabels();
            assertEquals(1, features.size(0)); //1 example in mini-batch
            assertEquals(1, labels.size(0));
            assertEquals(3, features.size(1)); //3 values per line/time step
            assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector
            assertEquals(4, features.size(2)); //sequence length = 4
            assertEquals(4, labels.size(2));
        }

        //Check features vs. expected (JUnit argument order: expected first, actual second):
        INDArray expF0 = Nd4j.create(1, 3, 4);
        expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 2 }));
        expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 10, 11, 12 }));
        expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 20, 21, 22 }));
        expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 30, 31, 32 }));
        assertEquals(expF0, dsList.get(0).getFeatureMatrix());

        INDArray expF1 = Nd4j.create(1, 3, 4);
        expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 100, 101, 102 }));
        expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 110, 111, 112 }));
        expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 120, 121, 122 }));
        expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 130, 131, 132 }));
        assertEquals(expF1, dsList.get(1).getFeatureMatrix());

        INDArray expF2 = Nd4j.create(1, 3, 4);
        expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 200, 201, 202 }));
        expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 210, 211, 212 }));
        expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 220, 221, 222 }));
        expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 230, 231, 232 }));
        assertEquals(expF2, dsList.get(2).getFeatureMatrix());

        //Check labels vs. expected:
        INDArray expL0 = Nd4j.create(1, 4, 4);
        expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 }));
        expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        expL0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 }));
        expL0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 }));
        assertEquals(expL0, dsList.get(0).getLabels());

        INDArray expL1 = Nd4j.create(1, 4, 4);
        expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 }));
        expL1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 }));
        expL1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        expL1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 }));
        assertEquals(expL1, dsList.get(1).getLabels());

        INDArray expL2 = Nd4j.create(1, 4, 4);
        expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 }));
        expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 }));
        expL2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 }));
        assertEquals(expL2, dsList.get(2).getLabels());
    }

    /**
     * Uses the same three {@code csvsequence_N.txt} files as both features and labels with the
     * regression flag set, and verifies that every mini-batch has shape {@code [1, 3, 4]} for
     * both arrays and that the labels equal the features.
     */
    @Test
    public void testSequenceRecordReaderRegression() throws Exception {
        //Derive the "%d" pattern from the file name only: replaceAll("0", "%d") on the absolute
        //path would also rewrite any '0' in a parent directory name and break the numbered split.
        File sequenceDir = new ClassPathResource("csvsequence_0.txt").getFile().getParentFile();
        String featuresPath = new File(sequenceDir, "csvsequence_%d.txt").getAbsolutePath();
        //Labels are read from the very same files as the features.
        String labelsPath = featuresPath;

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader,
                labelReader, 1, 0, true);

        assertEquals(3, iter.inputColumns());
        assertEquals(3, iter.totalOutcomes());

        List<DataSet> dsList = new ArrayList<>();
        while (iter.hasNext()) {
            dsList.add(iter.next());
        }

        assertEquals(3, dsList.size()); //3 files
        for (int i = 0; i < 3; i++) {
            DataSet ds = dsList.get(i);
            INDArray features = ds.getFeatureMatrix();
            INDArray labels = ds.getLabels();
            assertArrayEquals(new int[] { 1, 3, 4 }, features.shape()); //1 examples, 3 values, 4 time steps
            assertArrayEquals(new int[] { 1, 3, 4 }, labels.shape());

            assertEquals(features, labels);
        }
    }

    /**
     * Resets the sequence iterator five times and, after each reset, verifies that it yields
     * exactly three mini-batches with feature shape {@code [1, 3, 4]} and label shape
     * {@code [1, 4, 4]}.
     */
    @Test
    public void testSequenceRecordReaderReset() throws Exception {
        //Derive the "%d" patterns from the file names only: replaceAll("0", "%d") on the absolute
        //path would also rewrite any '0' in a parent directory name and break the numbered split.
        File featuresDir = new ClassPathResource("csvsequence_0.txt").getFile().getParentFile();
        String featuresPath = new File(featuresDir, "csvsequence_%d.txt").getAbsolutePath();
        File labelsDir = new ClassPathResource("csvsequencelabels_0.txt").getFile().getParentFile();
        String labelsPath = new File(labelsDir, "csvsequencelabels_%d.txt").getAbsolutePath();

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader,
                labelReader, 1, 4, false);

        assertEquals(3, iter.inputColumns());
        assertEquals(4, iter.totalOutcomes());

        int nResets = 5;
        for (int i = 0; i < nResets; i++) {
            //Reset before every pass, including the first one
            iter.reset();
            int count = 0;
            while (iter.hasNext()) {
                DataSet ds = iter.next();
                INDArray features = ds.getFeatureMatrix();
                INDArray labels = ds.getLabels();
                assertArrayEquals(new int[] { 1, 3, 4 }, features.shape());
                assertArrayEquals(new int[] { 1, 4, 4 }, labels.shape());
                count++;
            }
            assertEquals(3, count);
        }
    }

    /**
     * Writes a random CSV file (label in column 0 followed by 5 feature columns), reads it back
     * in mini-batches of 10 through a {@link RecordReaderDataSetIterator} and compares every
     * label and feature value with the generated data.
     */
    @Test
    public void testCSVLoadingRegression() throws Exception {
        int nLines = 30;
        int nFeatures = 5;
        int miniBatchSize = 10;
        int labelIdx = 0;

        //Unique temp file instead of a fixed name inside java.io.tmpdir, so that concurrent
        //or stale runs cannot clash over the same file.
        File csvFile = File.createTempFile("rr_csv_test_rand", ".csv");
        csvFile.deleteOnExit();
        String path = csvFile.getAbsolutePath();
        double[][] data = makeRandomCSV(path, nLines, nFeatures);
        RecordReader testReader = new CSVRecordReader();
        testReader.initialize(new FileSplit(new File(path)));

        DataSetIterator iter = new RecordReaderDataSetIterator(testReader, null, miniBatchSize, labelIdx, 1, true);
        int miniBatch = 0;
        while (iter.hasNext()) {
            DataSet test = iter.next();
            INDArray features = test.getFeatureMatrix();
            INDArray labels = test.getLabels();
            assertArrayEquals(new int[] { miniBatchSize, nFeatures }, features.shape());
            assertArrayEquals(new int[] { miniBatchSize, 1 }, labels.shape());

            //Row of the generated data that corresponds to the first example of this mini-batch
            int startRow = miniBatch * miniBatchSize;
            for (int i = 0; i < miniBatchSize; i++) {
                double labelExp = data[startRow + i][labelIdx];
                double labelAct = labels.getDouble(i);
                assertEquals(labelExp, labelAct, 1e-5f);

                //Walk all CSV columns, skipping the label column; featureCount tracks the
                //column index within the features array.
                int featureCount = 0;
                for (int j = 0; j < nFeatures + 1; j++) {
                    if (j == labelIdx)
                        continue;
                    double featureExp = data[startRow + i][j];
                    double featureAct = features.getDouble(i, featureCount++);
                    assertEquals(featureExp, featureAct, 1e-5f);
                }
            }

            miniBatch++;
        }
        assertEquals(nLines / miniBatchSize, miniBatch);
    }

    /**
     * Writes a comma-separated file of pseudo-random doubles and returns the values written.
     * Each line holds one label value (first column) followed by {@code nFeatures} feature
     * values. The generator uses a fixed seed, so repeated calls with the same sizes produce
     * the same data. The file is registered for deletion on JVM exit.
     *
     * @param tempFile  path of the file to (over)write
     * @param nLines    number of lines to write
     * @param nFeatures number of feature columns per line, in addition to the label column
     * @return array of shape {@code [nLines][nFeatures + 1]} holding the written values
     * @throws IllegalStateException if the file cannot be opened or written; previously the
     *                               error was only printed and the data returned regardless
     */
    public static double[][] makeRandomCSV(String tempFile, int nLines, int nFeatures) {
        File temp = new File(tempFile);
        temp.deleteOnExit();
        Random rand = new Random(12345);

        double[][] dArr = new double[nLines][nFeatures + 1];

        try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(temp)))) {
            for (int i = 0; i < nLines; i++) {
                dArr[i][0] = rand.nextDouble(); //First column: label
                out.print(dArr[i][0]);
                for (int j = 0; j < nFeatures; j++) {
                    dArr[i][j + 1] = rand.nextDouble();
                    out.print("," + dArr[i][j + 1]);
                }
                out.println();
            }
            //PrintWriter never throws on write failures; it only records them, so ask explicitly.
            if (out.checkError()) {
                throw new IOException("Error while writing to " + temp.getAbsolutePath());
            }
        } catch (IOException e) {
            //Fail loudly: returning data that the file does not contain would only surface
            //later as a confusing mismatch in the caller.
            throw new IllegalStateException("Could not write random CSV file: " + tempFile, e);
        }

        return dArr;
    }

    /**
     * Pairs the three {@code csvsequence_N.txt} feature files with the shorter
     * {@code csvsequencelabelsShort_N.txt} label files and checks, for both
     * {@code ALIGN_START} and {@code ALIGN_END}, the array shapes, feature values, label
     * values and the feature/label mask arrays against hard-coded expectations.
     */
    @Test
    public void testVariableLengthSequence() throws Exception {
        //Derive the "%d" patterns from the file names only: replaceAll("0", "%d") on the absolute
        //path would also rewrite any '0' in a parent directory name and break the numbered split.
        File featuresDir = new ClassPathResource("csvsequence_0.txt").getFile().getParentFile();
        String featuresPath = new File(featuresDir, "csvsequence_%d.txt").getAbsolutePath();
        File labelsDir = new ClassPathResource("csvsequencelabelsShort_0.txt").getFile().getParentFile();
        String labelsPath = new File(labelsDir, "csvsequencelabelsShort_%d.txt").getAbsolutePath();

        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        //Second, independent pair of readers over the same files for the ALIGN_END iterator
        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

        SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
                labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);

        SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
                labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        assertEquals(3, iterAlignStart.inputColumns());
        assertEquals(4, iterAlignStart.totalOutcomes());

        assertEquals(3, iterAlignEnd.inputColumns());
        assertEquals(4, iterAlignEnd.totalOutcomes());

        List<DataSet> dsListAlignStart = new ArrayList<>();
        while (iterAlignStart.hasNext()) {
            dsListAlignStart.add(iterAlignStart.next());
        }

        List<DataSet> dsListAlignEnd = new ArrayList<>();
        while (iterAlignEnd.hasNext()) {
            dsListAlignEnd.add(iterAlignEnd.next());
        }

        assertEquals(3, dsListAlignStart.size()); //3 files
        assertEquals(3, dsListAlignEnd.size()); //3 files

        for (int i = 0; i < 3; i++) {
            DataSet ds = dsListAlignStart.get(i);
            INDArray features = ds.getFeatureMatrix();
            INDArray labels = ds.getLabels();
            assertEquals(1, features.size(0)); //1 example in mini-batch
            assertEquals(1, labels.size(0));
            assertEquals(3, features.size(1)); //3 values per line/time step
            assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector
            assertEquals(4, features.size(2)); //sequence length = 4
            assertEquals(4, labels.size(2));

            DataSet ds2 = dsListAlignEnd.get(i);
            features = ds2.getFeatureMatrix();
            labels = ds2.getLabels();
            assertEquals(1, features.size(0)); //1 example in mini-batch
            assertEquals(1, labels.size(0));
            assertEquals(3, features.size(1)); //3 values per line/time step
            assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector
            assertEquals(4, features.size(2)); //sequence length = 4
            assertEquals(4, labels.size(2));
        }

        //Check features vs. expected:
        //Here: features always longer than labels -> same features for align start and align end
        INDArray expF0 = Nd4j.create(1, 3, 4);
        expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 2 }));
        expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 10, 11, 12 }));
        expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 20, 21, 22 }));
        expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 30, 31, 32 }));
        assertEquals(expF0, dsListAlignStart.get(0).getFeatureMatrix());
        assertEquals(expF0, dsListAlignEnd.get(0).getFeatureMatrix());

        INDArray expF1 = Nd4j.create(1, 3, 4);
        expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 100, 101, 102 }));
        expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 110, 111, 112 }));
        expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 120, 121, 122 }));
        expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 130, 131, 132 }));
        assertEquals(expF1, dsListAlignStart.get(1).getFeatureMatrix());
        assertEquals(expF1, dsListAlignEnd.get(1).getFeatureMatrix());

        INDArray expF2 = Nd4j.create(1, 3, 4);
        expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 200, 201, 202 }));
        expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 210, 211, 212 }));
        expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 220, 221, 222 }));
        expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 230, 231, 232 }));
        assertEquals(expF2, dsListAlignStart.get(2).getFeatureMatrix());
        assertEquals(expF2, dsListAlignEnd.get(2).getFeatureMatrix());

        //Check features mask array:
        INDArray featuresMaskExpected = Nd4j.ones(1, 4); //1 example, 4 values: same for both start/end align here
        for (int i = 0; i < 3; i++) {
            INDArray featuresMaskStart = dsListAlignStart.get(i).getFeaturesMaskArray();
            INDArray featuresMaskEnd = dsListAlignEnd.get(i).getFeaturesMaskArray();
            assertEquals(featuresMaskExpected, featuresMaskStart);
            assertEquals(featuresMaskExpected, featuresMaskEnd);
        }

        //Check labels vs. expected:
        //First: aligning start (label steps fill the first time steps, the rest stay zero)
        INDArray expL0 = Nd4j.create(1, 4, 4);
        expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 }));
        expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        assertEquals(expL0, dsListAlignStart.get(0).getLabels());

        INDArray expL1 = Nd4j.create(1, 4, 4);
        expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        assertEquals(expL1, dsListAlignStart.get(1).getLabels());

        INDArray expL2 = Nd4j.create(1, 4, 4);
        expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 }));
        expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 }));
        expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        assertEquals(expL2, dsListAlignStart.get(2).getLabels());

        //Second: align end (label steps fill the last time steps, the rest stay zero)
        INDArray expL0end = Nd4j.create(1, 4, 4);
        expL0end.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 }));
        expL0end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        assertEquals(expL0end, dsListAlignEnd.get(0).getLabels());

        INDArray expL1end = Nd4j.create(1, 4, 4);
        expL1end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        assertEquals(expL1end, dsListAlignEnd.get(1).getLabels());

        INDArray expL2end = Nd4j.create(1, 4, 4);
        expL2end.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 }));
        expL2end.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 }));
        expL2end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 }));
        assertEquals(expL2end, dsListAlignEnd.get(2).getLabels());

        //Check labels mask array
        INDArray[] labelsMaskExpectedStart = new INDArray[] {
                Nd4j.create(new float[] { 1, 1, 0, 0 }, new int[] { 1, 4 }),
                Nd4j.create(new float[] { 1, 0, 0, 0 }, new int[] { 1, 4 }),
                Nd4j.create(new float[] { 1, 1, 1, 0 }, new int[] { 1, 4 }) };
        INDArray[] labelsMaskExpectedEnd = new INDArray[] {
                Nd4j.create(new float[] { 0, 0, 1, 1 }, new int[] { 1, 4 }),
                Nd4j.create(new float[] { 0, 0, 0, 1 }, new int[] { 1, 4 }),
                Nd4j.create(new float[] { 0, 1, 1, 1 }, new int[] { 1, 4 }) };

        for (int i = 0; i < 3; i++) {
            INDArray labelsMaskStart = dsListAlignStart.get(i).getLabelsMaskArray();
            INDArray labelsMaskEnd = dsListAlignEnd.get(i).getLabelsMaskArray();
            assertEquals(labelsMaskExpectedStart[i], labelsMaskStart);
            assertEquals(labelsMaskExpectedEnd[i], labelsMaskEnd);
        }
    }

    /**
     * Exercises the single-reader constructor of {@link SequenceRecordReaderDataSetIterator},
     * where features and labels come from the same {@code csvsequenceSingle_N.txt} files.
     * Runs once with the last flag {@code false} (one-hot labels of size 3 expected) and once
     * with {@code true} (single raw label value expected), and checks features, labels,
     * absence of mask arrays, batch count and {@code totalOutcomes()} for both.
     */
    @Test
    public void testSequenceRecordReaderSingleReader() throws Exception {
        //Derive the "%d" pattern from the file name only: replaceAll("0", "%d") on the absolute
        //path would also rewrite any '0' in a parent directory name and break the numbered split.
        File sequenceDir = new ClassPathResource("csvsequenceSingle_0.txt").getFile().getParentFile();
        String path = new File(sequenceDir, "csvsequenceSingle_%d.txt").getAbsolutePath();

        SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
        reader.initialize(new NumberedFileInputSplit(path, 0, 2));
        SequenceRecordReaderDataSetIterator iteratorClassification = new SequenceRecordReaderDataSetIterator(reader,
                1, 3, 0, false);

        SequenceRecordReader reader2 = new CSVSequenceRecordReader(1, ",");
        reader2.initialize(new NumberedFileInputSplit(path, 0, 2));
        SequenceRecordReaderDataSetIterator iteratorRegression = new SequenceRecordReaderDataSetIterator(reader2, 1,
                3, 0, true);

        //Expected features: identical for the classification and the regression iterator
        INDArray expF0 = Nd4j.create(1, 2, 4);
        expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 2 }));
        expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 11, 12 }));
        expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 21, 22 }));
        expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 31, 32 }));

        INDArray expF1 = Nd4j.create(1, 2, 4);
        expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 101, 102 }));
        expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 111, 112 }));
        expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 121, 122 }));
        expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 131, 132 }));

        INDArray expF2 = Nd4j.create(1, 2, 4);
        expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 201, 202 }));
        expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 211, 212 }));
        expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 221, 222 }));
        expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 231, 232 }));

        INDArray[] expF = new INDArray[] { expF0, expF1, expF2 };

        //Expected out for classification:
        INDArray expOut0 = Nd4j.create(1, 3, 4);
        expOut0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 0, 0 }));
        expOut0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 1, 0 }));
        expOut0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 1 }));
        expOut0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 1, 0, 0 }));

        INDArray expOut1 = Nd4j.create(1, 3, 4);
        expOut1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0 }));
        expOut1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 1 }));
        expOut1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 1, 0, 0 }));
        expOut1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 1 }));

        INDArray expOut2 = Nd4j.create(1, 3, 4);
        expOut2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0 }));
        expOut2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 1, 0, 0 }));
        expOut2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 1, 0 }));
        expOut2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 1 }));
        INDArray[] expOutClassification = new INDArray[] { expOut0, expOut1, expOut2 };

        //Expected out for regression:
        INDArray expOutR0 = Nd4j.create(1, 1, 4);
        expOutR0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0 }));
        expOutR0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 1 }));
        expOutR0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 2 }));
        expOutR0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0 }));

        INDArray expOutR1 = Nd4j.create(1, 1, 4);
        expOutR1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1 }));
        expOutR1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 2 }));
        expOutR1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0 }));
        expOutR1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 2 }));

        INDArray expOutR2 = Nd4j.create(1, 1, 4);
        expOutR2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1 }));
        expOutR2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0 }));
        expOutR2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 1 }));
        expOutR2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 2 }));
        INDArray[] expOutRegression = new INDArray[] { expOutR0, expOutR1, expOutR2 };

        int countC = 0;
        while (iteratorClassification.hasNext()) {
            DataSet ds = iteratorClassification.next();
            INDArray f = ds.getFeatures();
            INDArray l = ds.getLabels();
            assertNull(ds.getFeaturesMaskArray());
            assertNull(ds.getLabelsMaskArray());

            assertArrayEquals(new int[] { 1, 2, 4 }, f.shape());
            assertArrayEquals(new int[] { 1, 3, 4 }, l.shape()); //One-hot representation

            assertEquals(expF[countC], f);
            assertEquals(expOutClassification[countC++], l);
        }
        assertEquals(3, countC);
        assertEquals(3, iteratorClassification.totalOutcomes());

        int countF = 0;
        while (iteratorRegression.hasNext()) {
            DataSet ds = iteratorRegression.next();
            INDArray f = ds.getFeatures();
            INDArray l = ds.getLabels();
            assertNull(ds.getFeaturesMaskArray());
            assertNull(ds.getLabelsMaskArray());

            assertArrayEquals(new int[] { 1, 2, 4 }, f.shape());
            assertArrayEquals(new int[] { 1, 1, 4 }, l.shape()); //Regression (single output)

            assertEquals(expF[countF], f);
            assertEquals(expOutRegression[countF++], l);
        }
        assertEquals(3, countF);
        assertEquals(1, iteratorRegression.totalOutcomes());
    }
}