org.deeplearning4j.nn.modelimport.keras.e2e.KerasCustomLayerTest.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.nn.modelimport.keras.e2e.KerasCustomLayerTest.java

Source

/*-
 *
 *  * Copyright 2017 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */
package org.deeplearning4j.nn.modelimport.keras.e2e;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasLRN;
import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasPoolHelper;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.Test;

import java.io.File;
import java.net.URL;

/**
 * Manual end-to-end test for importing a Keras model that uses custom layers (GoogLeNet with
 * Keras {@code PoolHelper} and {@code LRN} layers).
 *
 * <p>The test downloads a single HDF5 file containing both the model configuration and the weights
 * from http://blob.deeplearning4j.org/models/googlenet_keras_weightsandconfig.h5 into
 * {@code java.io.tmpdir}. If a file with that name is already present there, it is reused instead
 * of being downloaded again. A freshly downloaded file is scheduled for deletion when the JVM exits.
 *
 * <p>The model is imported with the custom layers registered. The resulting model is serialized to
 * {@code googlenet_dl4j_inference.zip} in {@code java.io.tmpdir}, restored as a
 * {@link ComputationGraph}, and its summary is logged.
 *
 * <p>{@link #testCustomLayerImport()} is deliberately not annotated with {@code @Test}, because
 * downloading and loading the model takes too long for the regular unit-test run. Run it manually.
 *
 * @author Justin Long (crockpotveggies)
 */
@Slf4j
public class KerasCustomLayerTest {

    /** Location of the combined Keras configuration + weights file for GoogLeNet. */
    private static final String KERAS_WEIGHTS_AND_CONFIG_URL =
            "http://blob.deeplearning4j.org/models/googlenet_keras_weightsandconfig.h5";

    /** File name used for the downloaded Keras file inside {@code java.io.tmpdir}. */
    private static final String KERAS_FILE_NAME = "googlenet_keras_weightsandconfig.h5";

    /** File name of the serialized DL4J model written inside {@code java.io.tmpdir}. */
    private static final String DL4J_OUTPUT_FILE_NAME = "googlenet_dl4j_inference.zip";

    // run manually, might take a long time to load (too long for unit tests)
    // @Test
    public void testCustomLayerImport() throws Exception {
        File tmpDir = new File(System.getProperty("java.io.tmpdir"));
        File cachedKerasFile = new File(tmpDir, KERAS_FILE_NAME);
        // Resolve against the directory instead of concatenating "/" so the path is portable.
        String outputPath = new File(tmpDir, DL4J_OUTPUT_FILE_NAME).getAbsolutePath();

        // Map the custom Keras layer names used by this model to their DL4J import implementations.
        KerasLayer.registerCustomLayer("PoolHelper", KerasPoolHelper.class);
        KerasLayer.registerCustomLayer("LRN", KerasLRN.class);

        // download file (an existing copy in tmpdir is reused)
        if (!cachedKerasFile.exists()) {
            downloadTo(KERAS_WEIGHTS_AND_CONFIG_URL, cachedKerasFile);
            cachedKerasFile.deleteOnExit();
        }

        org.deeplearning4j.nn.api.Model importedModel = KerasModelImport
                .importKerasModelAndWeights(cachedKerasFile.getAbsolutePath());
        ModelSerializer.writeModel(importedModel, outputPath, false);

        // Round-trip through ModelSerializer to check that the imported model can be restored.
        ComputationGraph serializedModel = ModelSerializer.restoreComputationGraph(outputPath);
        log.info(serializedModel.summary());
    }

    /**
     * Downloads {@code url} into {@code target}. If the download fails, any partially written file
     * is removed. Otherwise the truncated file would pass the {@code exists()} check on the next run
     * and be treated as a valid cached copy.
     *
     * @param url    the URL to download
     * @param target the destination file
     * @throws java.io.IOException if the URL is malformed or the download fails
     */
    private static void downloadTo(String url, File target) throws java.io.IOException {
        log.info("Downloading model to {}", target);
        boolean complete = false;
        try {
            FileUtils.copyURLToFile(new URL(url), target);
            complete = true;
        } finally {
            if (!complete && target.exists() && !target.delete()) {
                log.warn("Could not delete partially downloaded file {}", target);
            }
        }
    }
}