org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.nn.modelimport.keras.e2e.KerasModelEndToEndTest.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.modelimport.keras.e2e;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Unit tests for end-to-end Keras model import.
 *
 * @author dave@skymind.io, Max Pumperla
 */
@Slf4j
public class KerasModelEndToEndTest {
    // Names of the HDF5 groups/attributes read from the "inputs and outputs" reference archives.
    private static final String GROUP_ATTR_INPUTS = "inputs";
    private static final String GROUP_ATTR_OUTPUTS = "outputs";
    private static final String GROUP_PREDICTIONS = "predictions";
    private static final String GROUP_ACTIVATIONS = "activations";
    // File name parts passed to createTempFile when copying classpath resources to disk.
    private static final String TEMP_OUTPUTS_FILENAME = "tempOutputs";
    private static final String TEMP_MODEL_FILENAME = "tempModel";
    private static final String H5_EXTENSION = ".h5";
    // Tolerance used when comparing Keras reference values with DL4J results.
    private static final double EPS = 1E-5;

    // While true, the gradient-check branch of importEndModelTest never runs,
    // regardless of the checkGradients argument.
    private static final boolean SKIP_GRAD_CHECKS = true;

    // JUnit TemporaryFolder rule providing scratch files for tests.
    @Rule
    public final TemporaryFolder testDir = new TemporaryFolder();

    @Test(expected = FileNotFoundException.class)
    public void fileNotFoundEndToEnd() throws Exception {
        // The test passes only if the import below throws FileNotFoundException.
        final String missingModel = "modelimport/keras/examples/foo/bar.h5";
        final String noInputsOutputs = null;
        importEndModelTest(missingModel, noInputsOutputs, true, true, false);
    }

    /**
     * MNIST MLP tests
     */
    @Test
    public void importMnistMlpTfKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5",
                "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importMnistMlpThKeras1() throws Exception {
        final boolean tfOrdering = false;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5",
                "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importMnistMlpTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5",
                "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importMnistMlpReshapeTfKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = true;
        importEndModelTest(
                "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5",
                "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * MNIST CNN tests
     */
    @Test
    public void importMnistCnnTfKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = false;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5",
                "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importMnistCnnThKeras1() throws Exception {
        final boolean tfOrdering = false;
        final boolean checkPredictions = true;
        final boolean checkGradients = true;
        importEndModelTest(
                "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5",
                "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importMnistCnnTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = true;
        importEndModelTest(
                "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5",
                "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * IMDB Embedding and LSTM test
     */
    @Test
    public void importImdbLstmTfKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5",
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importImdbLstmThKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5",
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importImdbLstmTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5",
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importImdbLstmThKeras2() throws Exception {
        final boolean tfOrdering = false;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5",
                "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * IMDB LSTM fasttext
     */
    // TODO: prediction checks fail due to globalpooling for fasttext, very few grads fail as well
    @Test
    public void importImdbFasttextTfKeras1() throws Exception {
        final boolean tfOrdering = false;
        final boolean checkPredictions = false;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5",
                "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importImdbFasttextThKeras1() throws Exception {
        final boolean tfOrdering = false;
        final boolean checkPredictions = false;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5",
                "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importImdbFasttextTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = false;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5",
                "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * Simple LSTM (return sequences = false) into Dense layer test
     */
    @Test
    public void importSimpleLstmTfKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5",
                "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importSimpleLstmThKeras1() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5",
                "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    @Test
    public void importSimpleLstmTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = false;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5",
                "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * Simple LSTM (return sequences = true) into flatten into Dense layer test
     */
    @Test
    public void importSimpleFlattenLstmTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5",
                "modelimport/keras/examples/simple_flatten_lstm/"
                        + "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * Simple RNN (return sequences = true) into flatten into Dense layer test
     */
    @Test
    public void importSimpleFlattenRnnTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5",
                "modelimport/keras/examples/simple_flatten_rnn/"
                        + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * Simple RNN (return sequences = false) into Dense layer test
     */
    @Test
    public void importSimpleRnnTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = true;
        final boolean checkGradients = false;
        importEndModelTest(
                "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5",
                "modelimport/keras/examples/simple_rnn/"
                        + "simple_rnn_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * CNN without bias test
     */
    @Test
    public void importCnnNoBiasTfKeras2() throws Exception {
        final boolean tfOrdering = true;
        final boolean checkPredictions = false;
        final boolean checkGradients = true;
        importEndModelTest(
                "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5",
                "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5",
                tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * GAN import tests
     */
    @Test
    public void importDcganMnistDiscriminator() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/mnist_dcgan/dcgan_discriminator_epoch_50.h5";
        importSequentialModelH5Test(modelPath);
    }

    @Test
    public void importDcganMnistGenerator() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/mnist_dcgan/dcgan_generator_epoch_50.h5";
        importSequentialModelH5Test(modelPath);
    }

    /**
     * Auxiliary classifier GAN import test
     */
    @Test
    public void importAcganDiscriminator() throws Exception {
        final String modelPath = "modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5";
        ComputationGraph discriminator = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray images = Nd4j.create(10, 1, 28, 28);
        discriminator.output(images);
    }

    @Test
    public void importAcganGenerator() throws Exception {
        final String modelPath = "modelimport/keras/examples/acgan/acgan_generator_1_epochs.h5";
        ComputationGraph generator = importFunctionalModelH5Test(modelPath);
        // Two-input forward pass (latent vector and label); the result is not inspected.
        INDArray latentBatch = Nd4j.create(10, 100);
        INDArray labelBatch = Nd4j.create(10, 1);
        generator.output(latentBatch, labelBatch);
    }

    @Test
    public void importAcganCombined() throws Exception {
        // TODO: imports, but incorrectly. Has only one input, should have two.
        final String modelPath = "modelimport/keras/examples/acgan/acgan_combined_1_epochs.h5";
        importFunctionalModelH5Test(modelPath);
    }

    /**
     * Deep convolutional GAN import test
     */
    @Test
    public void importDcganDiscriminator() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/gans/dcgan_discriminator.h5";
        importSequentialModelH5Test(modelPath);
    }

    @Test
    public void importDcganGenerator() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/gans/dcgan_generator.h5";
        importSequentialModelH5Test(modelPath);
    }

    /**
     * Wasserstein GAN import test
     */
    @Test
    public void importWganDiscriminator() throws Exception {
        final String modelPath = "modelimport/keras/examples/gans/wgan_discriminator.h5";
        // Repeat the same import many times to make sure HDF5 doesn't crash
        int remaining = 100;
        while (remaining > 0) {
            importSequentialModelH5Test(modelPath);
            remaining--;
        }
    }

    @Test
    public void importWganGenerator() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/gans/wgan_generator.h5";
        importSequentialModelH5Test(modelPath);
    }

    @Test
    public void importCnn1d() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/cnn1d/cnn1d_flatten_tf_keras2.h5";
        importSequentialModelH5Test(modelPath);
    }

    /**
     * DGA classifier test
     */
    @Test
    public void importDgaClassifier() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/dga_classifier/keras2_dga_classifier_tf_model.h5";
        importSequentialModelH5Test(modelPath);
    }

    /**
     * Reshape flat input into 3D to fit into an LSTM model
     */
    @Test
    public void importFlatIntoLSTM() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/reshape_to_rnn/reshape_model.h5";
        importFunctionalModelH5Test(modelPath);
    }

    /**
     * Functional LSTM test
     */
    @Test
    public void importFunctionalLstmTfKeras2() throws Exception {
        final String h5Path = "modelimport/keras/examples/functional_lstm/lstm_functional_tf_keras_2.h5";

        // First import: training configuration not enforced
        ComputationGraph inferenceGraph = importFunctionalModelH5Test(h5Path, null, false);
        System.out.println(inferenceGraph.summary());

        // Second import: training configuration enforced
        ComputationGraph trainableGraph = importFunctionalModelH5Test(h5Path, null, true);
        System.out.println(trainableGraph.summary());

        // Run a forward pass on an all-ones batch ...
        final int batchSize = 32;
        INDArray features = Nd4j.ones(batchSize, 4, 10);
        INDArray[] predictions = trainableGraph.output(features);

        // ... and reuse those predictions as labels for a single fit call
        INDArray[] featureArrays = new INDArray[] { features };
        trainableGraph.fit(featureArrays, predictions);
    }

    /**
     * U-Net
     */
    @Test
    public void importUnetTfKeras2() throws Exception {
        final String modelPath = "modelimport/keras/examples/unet/unet_keras_2_tf.h5";
        final boolean train = true;
        importFunctionalModelH5Test(modelPath, null, train);
    }

    /**
     * ResNet50
     */
    @Test
    public void importResnet50() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5";
        importFunctionalModelH5Test(modelPath);
    }

    /**
     * DenseNet
     */
    @Test
    public void importDenseNet() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/densenet/densenet121_tf_keras_2.h5";
        importFunctionalModelH5Test(modelPath);
    }

    /**
     * SqueezeNet
     */
    @Test
    public void importSqueezeNet() throws Exception {
        // Import-only check: no assertions beyond the import not throwing.
        final String modelPath = "modelimport/keras/examples/squeezenet/squeezenet.h5";
        importFunctionalModelH5Test(modelPath);
    }

    /**
     * MobileNet
     */
    @Test
    public void importMobileNet() throws Exception {
        final String modelPath = "modelimport/keras/examples/mobilenet/alternative.hdf5";
        ComputationGraph mobileNet = importFunctionalModelH5Test(modelPath);
        // Forward pass on an all-ones batch; the result is not inspected.
        INDArray batch = Nd4j.ones(10, 3, 299, 299);
        mobileNet.output(batch);
    }

    /**
     * InceptionV3 Keras 2 no top
     */
    @Test
    public void importInceptionKeras2() throws Exception {
        final String modelPath = "modelimport/keras/examples/inception/inception_tf_keras_2.h5";
        final int[] shape = { 299, 299, 3 };
        ComputationGraph inception = importFunctionalModelH5Test(modelPath, shape, false);
        // Forward pass on an all-ones batch; the result is not inspected.
        INDArray batch = Nd4j.ones(10, 3, 299, 299);
        inception.output(batch);
        System.out.println(inception.summary());
    }

    /**
     * InceptionV3
     */
    @Test
    @Ignore
    // Takes unreasonably long, but works
    public void importInception() throws Exception {
        final String modelPath = "modelimport/keras/examples/inception/inception_v3_complete.h5";
        ComputationGraph inception = importFunctionalModelH5Test(modelPath);
        // Forward pass on an all-ones batch; the result is not inspected.
        INDArray batch = Nd4j.ones(10, 3, 299, 299);
        inception.output(batch);
        System.out.println(inception.summary());
    }

    /**
     * Inception V4
     */
    @Test
    @Ignore
    // Model and weights have about 170mb, too large for test resources and also too excessive to enable as unit test
    public void importInceptionV4() throws Exception {
        String modelUrl = DL4JResources.getURLString("models/inceptionv4_keras_imagenet_weightsandconfig.h5");

        // TemporaryFolder.newFile creates the (empty) file on disk, so guarding the download with
        // an exists() check would always skip it. Download unconditionally; the rule removes the
        // file after the test, so no deleteOnExit is needed.
        File kerasFile = testDir.newFile("inceptionv4_keras_imagenet_weightsandconfig.h5");
        FileUtils.copyURLToFile(new URL(modelUrl), kerasFile);

        // The downloaded model lives on the file system, not on the classpath, so build it
        // directly from the file instead of going through the classpath-resource helper.
        int[] inputShape = new int[] { 299, 299, 3 };
        KerasModelBuilder builder = new KerasModel().modelBuilder()
                .modelHdf5Filename(kerasFile.getAbsolutePath()).enforceTrainingConfig(false);
        builder.inputShape(inputShape);
        builder.buildModel().getComputationGraph();
    }

    /**
     * Xception
     */
    @Test
    public void importXception() throws Exception {
        final String modelPath = "modelimport/keras/examples/xception/xception_tf_keras_2.h5";
        final int[] shape = { 299, 299, 3 };
        // Import-only check: the returned graph is not used.
        importFunctionalModelH5Test(modelPath, shape, false);
    }

    /**
     * Seq2seq model
     */
    @Test
    @Ignore // does not work yet, needs DL4J enhancements
    public void importSeq2Seq() throws Exception {
        final String modelPath = "modelimport/keras/examples/seq2seq/full_model_seq2seq_5549.h5";
        importFunctionalModelH5Test(modelPath);
    }

    /**
     * Import all AlphaGo Zero model variants, i.e.
     * - Dual residual architecture
     * - Dual convolutional architecture
     * - Separate (policy and value) residual architecture
     * - Separate (policy and value) convolutional architecture
     */
    @Test
    public void importSepConvPolicy() throws Exception {
        final String modelPath = "modelimport/keras/examples/agz/sep_conv_policy.h5";
        ComputationGraph network = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(32, 19, 19, 10);
        network.output(batch);
    }

    @Test
    public void importSepResPolicy() throws Exception {
        final String modelPath = "modelimport/keras/examples/agz/sep_res_policy.h5";
        ComputationGraph network = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(32, 19, 19, 10);
        network.output(batch);
    }

    @Test
    public void importSepConvValue() throws Exception {
        final String modelPath = "modelimport/keras/examples/agz/sep_conv_value.h5";
        ComputationGraph network = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(32, 19, 19, 10);
        network.output(batch);
    }

    @Test
    public void importSepResValue() throws Exception {
        final String modelPath = "modelimport/keras/examples/agz/sep_res_value.h5";
        ComputationGraph network = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(32, 19, 19, 10);
        network.output(batch);
    }

    @Test
    public void importDualRes() throws Exception {
        final String modelPath = "modelimport/keras/examples/agz/dual_res.h5";
        ComputationGraph network = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(32, 19, 19, 10);
        network.output(batch);
    }

    @Test
    public void importDualConv() throws Exception {
        final String modelPath = "modelimport/keras/examples/agz/dual_conv.h5";
        ComputationGraph network = importFunctionalModelH5Test(modelPath);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(32, 19, 19, 10);
        network.output(batch);
    }

    /**
     * MTCNN
     */
    @Test
    public void importMTCNN() throws Exception {
        // Import-only check: the returned graph is not used.
        final String modelPath = "modelimport/keras/examples/48net_complete.h5";
        importFunctionalModelH5Test(modelPath);
    }

    @Test
    public void importMTCNN2D() throws Exception {
        final String modelPath = "modelimport/keras/examples/12net.h5";
        final int[] shape = { 24, 24, 3 };
        ComputationGraph network = importFunctionalModelH5Test(modelPath, shape, false);
        // Forward pass only; the result is not inspected.
        INDArray batch = Nd4j.create(10, 3, 24, 24);
        network.output(batch);
        System.out.println(network.summary());
    }

    /**
     * Masking layers (simple Masking into LSTM)
     */
    @Test
    public void testMaskingZeroValue() throws Exception {
        final String modelPath = "modelimport/keras/examples/masking/masking_zero_lstm.h5";
        MultiLayerNetwork network = importSequentialModelH5Test(modelPath);
        // summary() is invoked for its own sake; the returned value is discarded
        network.summary();
    }

    @Test
    public void testMaskingTwoValue() throws Exception {
        final String modelPath = "modelimport/keras/examples/masking/masking_two_lstm.h5";
        MultiLayerNetwork network = importSequentialModelH5Test(modelPath);
        // summary() is invoked for its own sake; the returned value is discarded
        network.summary();
    }

    /**
     * Convenience overload: imports a functional model without an explicit input shape
     * and with the {@code train} flag set to {@code false}.
     *
     * @param modelPath path handed on to the three-argument overload
     * @return the imported computation graph
     */
    private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
        final int[] noInputShape = null;
        final boolean train = false;
        return importFunctionalModelH5Test(modelPath, noInputShape, train);
    }

    /**
     * Copies a Keras HDF5 model from the classpath into a temp file and imports it as a
     * {@link ComputationGraph}.
     *
     * @param modelPath  classpath location of the HDF5 model file
     * @param inputShape input shape forwarded to the builder; not set when {@code null}
     * @param train      value passed to {@code enforceTrainingConfig}
     * @return the computation graph of the imported model
     * @throws Exception if the resource cannot be read or the import fails
     */
    private ComputationGraph importFunctionalModelH5Test(String modelPath, int[] inputShape, boolean train)
            throws Exception {
        ClassPathResource modelResource = new ClassPathResource(modelPath,
                KerasModelEndToEndTest.class.getClassLoader());
        File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
        // Files.copy does not close its source stream, so close it explicitly
        try (InputStream modelStream = modelResource.getInputStream()) {
            Files.copy(modelStream, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        }
        KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
                .enforceTrainingConfig(train);
        if (inputShape != null) {
            builder.inputShape(inputShape);
        }
        KerasModel model = builder.buildModel();
        return model.getComputationGraph();
    }

    /**
     * Convenience overload: imports a sequential model without an explicit input shape.
     *
     * @param modelPath path handed on to the two-argument overload
     * @return the imported multi-layer network
     */
    private MultiLayerNetwork importSequentialModelH5Test(String modelPath) throws Exception {
        final int[] noInputShape = null;
        return importSequentialModelH5Test(modelPath, noInputShape);
    }

    /**
     * Copies a Keras HDF5 model from the classpath into a temp file and imports it as a
     * {@link MultiLayerNetwork}, with {@code enforceTrainingConfig(false)}.
     *
     * @param modelPath  classpath location of the HDF5 model file
     * @param inputShape input shape forwarded to the builder; not set when {@code null}
     * @return the multi-layer network of the imported sequential model
     * @throws Exception if the resource cannot be read or the import fails
     */
    private MultiLayerNetwork importSequentialModelH5Test(String modelPath, int[] inputShape) throws Exception {
        ClassPathResource modelResource = new ClassPathResource(modelPath,
                KerasModelEndToEndTest.class.getClassLoader());
        File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
        // Files.copy does not close its source stream, so close it explicitly
        try (InputStream modelStream = modelResource.getInputStream()) {
            Files.copy(modelStream, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        }
        KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
                .enforceTrainingConfig(false);
        if (inputShape != null) {
            builder.inputShape(inputShape);
        }
        KerasSequentialModel model = builder.buildSequential();
        return model.getMultiLayerNetwork();
    }

    /**
     * Convenience overload that runs the end-to-end import test with gradient checks switched off.
     */
    private void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering,
            boolean checkPredictions) throws Exception {
        final boolean checkGradients = false;
        importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients);
    }

    /**
     * Imports a sequential Keras model from the classpath and optionally compares it against
     * reference data stored in a second HDF5 archive.
     *
     * @param modelPath         classpath location of the Keras model HDF5 file
     * @param inputsOutputsPath classpath location of the HDF5 archive holding inputs, activations,
     *                          predictions and outputs
     * @param tfOrdering        forwarded to getInputs, getActivations and getPredictions as their
     *                          dimension-ordering flag
     * @param checkPredictions  whether to compare per-layer activations, predictions and AUC values
     * @param checkGradients    whether to run gradient checks; has no effect while
     *                          SKIP_GRAD_CHECKS is true
     * @throws Exception if a resource cannot be read or the import fails
     */
    public void importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering,
            boolean checkPredictions, boolean checkGradients) throws Exception {
        ClassPathResource modelResource = new ClassPathResource(modelPath,
                KerasModelEndToEndTest.class.getClassLoader());
        File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
        // Files.copy does not close its source stream, so close it explicitly
        try (InputStream modelStream = modelResource.getInputStream()) {
            Files.copy(modelStream, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        }
        KerasSequentialModel kerasModel = new KerasModel().modelBuilder()
                .modelHdf5Filename(modelFile.getAbsolutePath()).enforceTrainingConfig(false).buildSequential();

        MultiLayerNetwork model = kerasModel.getMultiLayerNetwork();

        ClassPathResource outputsResource = new ClassPathResource(inputsOutputsPath,
                KerasModelEndToEndTest.class.getClassLoader());
        File outputsFile = createTempFile(TEMP_OUTPUTS_FILENAME, H5_EXTENSION);
        try (InputStream outputsStream = outputsResource.getInputStream()) {
            Files.copy(outputsStream, outputsFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        }
        try (Hdf5Archive outputsArchive = new Hdf5Archive(outputsFile.getAbsolutePath())) {

            if (checkPredictions) {
                INDArray input = getInputs(outputsArchive, tfOrdering)[0];
                Map<String, INDArray> activationsKeras = getActivations(outputsArchive, tfOrdering);
                // Compare every layer for which the archive holds a reference activation
                for (int i = 0; i < model.getLayers().length; i++) {
                    String layerName = model.getLayerNames().get(i);
                    if (activationsKeras.containsKey(layerName)) {
                        INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1);
                        // Rank-3 activations have their last two axes swapped before comparison
                        if (activationsDl4j.shape().length == 3) {
                            activationsDl4j = activationsDl4j.permute(0, 2, 1);
                        }
                        compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS);
                    }
                }

                INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0];
                INDArray predictionsDl4j = model.output(input, false);
                compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS);
                INDArray outputs = getOutputs(outputsArchive, true)[0];

                // A single-row output is turned into a column before evaluation
                if (outputs.shape()[0] == 1) {
                    outputs = outputs.reshape(outputs.shape()[1], outputs.shape()[0]);
                }

                // FIXME: int cast
                int nOut = (int) outputs.shape()[outputs.shape().length - 1];
                compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS);
            }

            if (checkGradients && !SKIP_GRAD_CHECKS) {
                Random r = new Random(12345);
                INDArray input = getInputs(outputsArchive, tfOrdering)[0];
                INDArray predictionsDl4j = model.output(input, false);

                //Infer one-hot labels... this probably won't work for all
                INDArray testLabels = Nd4j.create(predictionsDl4j.shape());
                if (testLabels.rank() == 2) {
                    for (int i = 0; i < testLabels.size(0); i++) {
                        // FIXME: int cast
                        testLabels.putScalar(i, r.nextInt((int) testLabels.size(1)), 1.0);
                    }
                } else if (testLabels.rank() == 3) {
                    // Labels mirror the DL4J output, assumed to be [miniBatch, nOut, timeSteps]:
                    // choose one random class (axis 1) for every time step (axis 2). Iterating
                    // over axis 1 and drawing the axis-2 index from size(1) could index out of
                    // bounds and did not yield one-hot labels per step.
                    for (int i = 0; i < testLabels.size(0); i++) {
                        for (int t = 0; t < testLabels.size(2); t++) {
                            // FIXME: int cast
                            testLabels.putScalar(i, r.nextInt((int) testLabels.size(1)), t, 1.0);
                        }
                    }
                } else {
                    throw new RuntimeException("Cannot gradient check 4d output array");
                }
                checkGradients(model, input, testLabels);
            }
        }
    }

    /**
     * Reads every input data set named in the archive's "inputs" attribute.
     *
     * @param archive                    archive to read from
     * @param tensorFlowImageDimOrdering when true, rank-4 arrays are permuted with (0, 3, 1, 2)
     * @return the input arrays, in the order their names are listed
     */
    private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {
        List<String> names = (List<String>) KerasModelUtils
                .parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_INPUTS)).get(GROUP_ATTR_INPUTS);
        final int count = names.size();
        INDArray[] result = new INDArray[count];
        for (int idx = 0; idx < count; idx++) {
            INDArray data = archive.readDataSet(names.get(idx), GROUP_ATTR_INPUTS);
            boolean isRank4 = data.shape().length == 4;
            if (isRank4 && tensorFlowImageDimOrdering) {
                data = data.permute(0, 3, 1, 2);
            }
            result[idx] = data;
        }
        return result;
    }

    /**
     * Reads every data set in the archive's "activations" group, keyed by data set name.
     *
     * @param archive                    archive to read from
     * @param tensorFlowImageDimOrdering when true, rank-4 arrays are permuted with (0, 3, 1, 2)
     * @return map from data set name to its array
     */
    private static Map<String, INDArray> getActivations(Hdf5Archive archive, boolean tensorFlowImageDimOrdering)
            throws Exception {
        Map<String, INDArray> byLayer = new HashMap<>();
        for (String name : archive.getDataSets(GROUP_ACTIVATIONS)) {
            INDArray values = archive.readDataSet(name, GROUP_ACTIVATIONS);
            boolean isRank4 = values.shape().length == 4;
            if (isRank4 && tensorFlowImageDimOrdering) {
                values = values.permute(0, 3, 1, 2);
            }
            byLayer.put(name, values);
        }
        return byLayer;
    }

    /**
     * Reads, from the "outputs" group, every data set named in the archive's "outputs" attribute.
     *
     * @param archive                    archive to read from
     * @param tensorFlowImageDimOrdering when true, rank-4 arrays are permuted with (0, 3, 1, 2)
     * @return the output arrays, in the order their names are listed
     */
    private static INDArray[] getOutputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {
        List<String> names = (List<String>) KerasModelUtils
                .parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_OUTPUTS)).get(GROUP_ATTR_OUTPUTS);
        final int count = names.size();
        INDArray[] result = new INDArray[count];
        for (int idx = 0; idx < count; idx++) {
            INDArray data = archive.readDataSet(names.get(idx), GROUP_ATTR_OUTPUTS);
            boolean isRank4 = data.shape().length == 4;
            if (isRank4 && tensorFlowImageDimOrdering) {
                data = data.permute(0, 3, 1, 2);
            }
            result[idx] = data;
        }
        return result;
    }

    /**
     * Reads, from the "predictions" group, every data set named in the archive's "outputs" attribute.
     *
     * @param archive                    archive to read from
     * @param tensorFlowImageDimOrdering when true, rank-4 arrays are permuted with (0, 3, 1, 2)
     * @return the prediction arrays, in the order the output names are listed
     */
    private static INDArray[] getPredictions(Hdf5Archive archive, boolean tensorFlowImageDimOrdering)
            throws Exception {
        List<String> names = (List<String>) KerasModelUtils
                .parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_OUTPUTS)).get(GROUP_ATTR_OUTPUTS);
        final int count = names.size();
        INDArray[] result = new INDArray[count];
        for (int idx = 0; idx < count; idx++) {
            INDArray data = archive.readDataSet(names.get(idx), GROUP_PREDICTIONS);
            boolean isRank4 = data.shape().length == 4;
            if (isRank4 && tensorFlowImageDimOrdering) {
                data = data.permute(0, 3, 1, 2);
            }
            result[idx] = data;
        }
        return result;
    }

    /**
     * Logs how closely two arrays agree and asserts that they are equal within {@code eps}.
     * <p>
     * The assertion is skipped when the largest absolute value of either array is at or below 1e-7,
     * so that arrays consisting of near-zero values do not cause failures.
     *
     * @param label name used in the log line and in the assertion failure message
     * @param a     first array
     * @param b     second array
     * @param eps   tolerance passed to {@code equalsWithEps}
     */
    private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) {
        INDArray diff = a.sub(b);
        double min = diff.minNumber().doubleValue();
        double max = diff.maxNumber().doubleValue();

        // Evaluate the comparison once and reuse it for both the log line and the assertion.
        boolean matches = a.equalsWithEps(b, eps);
        log.info("{}: {}, {}, {}", label, matches, min, max);

        double threshold = 1e-7;
        double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue()));
        double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue()));

        // skip too small absolute inputs (aAbsMax / bAbsMax are already non-negative)
        if (aAbsMax > threshold && bAbsMax > threshold) {
            assertTrue(label + ": arrays differ by more than eps=" + eps
                    + " (diff min=" + min + ", max=" + max + ")", matches);
        }
    }

    /**
     * Asserts that two sets of predictions yield the same multi-class AUC against a common target.
     * <p>
     * The average AUC is always compared; per-class AUCs are compared only when {@code nbClasses > 1}.
     *
     * @param label     name included in assertion failure messages
     * @param target    labels both prediction arrays are evaluated against
     * @param a         first prediction array
     * @param b         second prediction array
     * @param nbClasses number of classes whose individual AUC is compared
     * @param eps       maximum allowed absolute difference between corresponding AUC values
     */
    private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses,
            double eps) {
        ROCMultiClass evalA = new ROCMultiClass(100);
        evalA.eval(target, a);
        double avgAucA = evalA.calculateAverageAUC();
        ROCMultiClass evalB = new ROCMultiClass(100);
        evalB.eval(target, b);
        double avgAucB = evalB.calculateAverageAUC();
        // Honour the caller-supplied tolerance instead of the class-level constant.
        assertEquals(label + ": average AUC mismatch", avgAucA, avgAucB, eps);

        if (nbClasses > 1) {
            double[] aucA = new double[nbClasses];
            double[] aucB = new double[nbClasses];
            for (int i = 0; i < nbClasses; i++) {
                aucA[i] = evalA.calculateAUC(i);
                aucB[i] = evalB.calculateAUC(i);
            }
            assertArrayEquals(label + ": per-class AUC mismatch", aucA, aucB, eps);
        }
    }

    /**
     * Runs a numerical gradient check on the given network and asserts that it passes.
     * <p>
     * If the network's last layer is not an {@link IOutputLayer}, an MSE/identity loss layer is appended
     * via transfer learning (a {@code LossLayer} for rank-2 labels, otherwise an {@code RnnOutputLayer}).
     * Dropout is cleared on every layer, and ReLU / leaky ReLU / hard tanh activations on feed-forward
     * layers are replaced by softplus / tanh before checking.
     *
     * @param net    network to check
     * @param input  input features passed to the gradient check
     * @param labels labels passed to the gradient check; their rank selects the appended loss layer
     */
    public static void checkGradients(MultiLayerNetwork net, INDArray input, INDArray labels) {
        final double epsilon = 1e-6;
        final double maxRelativeError = 1e-3;
        final double minAbsoluteError = 1e-8;

        MultiLayerNetwork netToTest = net;
        if (!(net.getOutputLayer() instanceof IOutputLayer)) {
            // No usable output layer: append one so that a loss can be computed.
            org.deeplearning4j.nn.conf.layers.Layer lossHead;
            if (labels.rank() != 2) {
                //Rank 3
                lossHead = new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(labels.size(1)).nOut(labels.size(1)).build();
            } else {
                lossHead = new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).build();
            }
            TransferLearning.Builder transfer = new TransferLearning.Builder(net);
            FineTuneConfiguration fineTune =
                    new FineTuneConfiguration.Builder().updater(new NoOp()).dropOut(0.0).build();
            netToTest = transfer.fineTuneConfiguration(fineTune).addLayer(lossHead).build();
        }

        log.info("Num params: " + net.numParams());

        for (Layer layer : netToTest.getLayers()) {
            org.deeplearning4j.nn.conf.layers.Layer layerConf = layer.conf().getLayer();

            // Remove any dropout manually - until this is fixed:
            // https://github.com/deeplearning4j/deeplearning4j/issues/4368
            layerConf.setIDropout(null);

            if (!(layerConf instanceof FeedForwardLayer)) {
                continue;
            }

            // Swap out activation functions as well - a bit of a hack, but it should make the
            // net gradient checkable.
            FeedForwardLayer feedForward = (FeedForwardLayer) layerConf;
            IActivation current = feedForward.getActivationFn();
            boolean reluLike = current instanceof ActivationReLU || current instanceof ActivationLReLU;
            if (reluLike) {
                feedForward.setActivationFn(new ActivationSoftPlus());
            } else if (current instanceof ActivationHardTanH) {
                feedForward.setActivationFn(new ActivationTanH());
            }
        }

        Nd4j.setDataType(DataBuffer.Type.DOUBLE);
        boolean gradientsOk = GradientCheckUtil.checkGradients(netToTest, epsilon, maxRelativeError,
                minAbsoluteError, true, false, input, labels, null, null, true, 9);
        assertTrue("Gradient check failed", gradientsOk);
    }

    /**
     * Creates a new file in the test directory named {@code prefix-<nanoTime>suffix}.
     *
     * @param prefix leading part of the file name
     * @param suffix trailing part of the file name, appended directly after the timestamp
     * @return the newly created file
     * @throws IOException if the file cannot be created
     */
    private File createTempFile(String prefix, String suffix) throws IOException {
        String uniqueName = prefix + "-" + System.nanoTime() + suffix;
        return testDir.newFile(uniqueName);
    }
}