com.simiacryptus.mindseye.test.unit.SerializationTest.java Source code

Java tutorial

Introduction

Here is the source code for com.simiacryptus.mindseye.test.unit.SerializationTest.java

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.test.unit;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.SerialPrecision;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.test.ToleranceStatistics;
import com.simiacryptus.notebook.NotebookOutput;
import com.simiacryptus.util.Util;
import org.apache.commons.io.IOUtils;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.zip.GZIPOutputStream;
import java.util.zip.ZipException;
import java.util.zip.ZipFile;

/**
 * Component test that checks a {@link Layer} survives serialization round trips.
 *
 * <p>The layer is first serialized to JSON and read back; afterwards, when the
 * JSON step produced no output or produced more than 64 KiB of text, the layer
 * is also written to a zip archive once per {@link SerialPrecision} and read
 * back. Each deserialized copy must be a distinct instance that is
 * {@code equals} to the original layer.
 */
public class SerializationTest extends ComponentTestBase<ToleranceStatistics> {
    /** Size (in characters) of pretty-printed JSON above which the zip round trips are run. */
    private static final int ZIP_TEST_JSON_THRESHOLD = 1024 * 64;

    /** Layers read back from zip archives, keyed by the precision they were written with. */
    @Nonnull
    private final HashMap<SerialPrecision, Layer> models = new HashMap<>();
    /** When {@code false} (the default) each zip archive is deleted after it has been read back. */
    private boolean persist = false;

    /**
     * GZIP-compresses the UTF-8 encoding of a string.
     *
     * @param prettyPrint the text to compress
     * @return the GZIP-compressed bytes
     */
    public static byte[] compressGZ(@Nonnull String prettyPrint) {
        return compressGZ(prettyPrint.getBytes(Charset.forName("UTF-8")));
    }

    /**
     * GZIP-compresses a byte array in memory.
     *
     * @param bytes the data to compress
     * @return the GZIP-compressed bytes
     * @throws RuntimeException wrapping any {@link IOException} raised while compressing
     */
    public static byte[] compressGZ(byte[] bytes) {
        @Nonnull
        final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // The GZIP stream must be closed (which writes the trailer) before the buffer is read.
        try (@Nonnull
        GZIPOutputStream out = new GZIPOutputStream(buffer)) {
            IOUtils.write(bytes, out);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return buffer.toByteArray();
    }

    /**
     * Runs the JSON round trip and, conditionally, the zip round trips for {@code layer}.
     *
     * <p>{@link RuntimeException}s and {@link OutOfMemoryError}s raised by either stage are
     * printed and swallowed so that the remaining steps still run. {@link AssertionError}s
     * raised by the integrity checks are not caught by this method.
     *
     * @param log            the notebook that receives headings, messages and file links
     * @param layer          the layer to serialize
     * @param inputPrototype not used by this test
     * @return always {@code null}
     */
    @Nullable
    @Override
    public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer layer,
            final Tensor... inputPrototype) {
        log.h1("Serialization");
        log.p("This apply will demonstrate the key's JSON serialization, and verify deserialization integrity.");

        String prettyPrint = "";
        log.h2("Raw Json");
        try {
            prettyPrint = log.eval(() -> {
                final JsonObject json = layer.getJson();
                @Nonnull
                final Layer echo = Layer.fromJson(json);
                if (echo == null)
                    throw new AssertionError("Failed to deserialize");
                if (layer == echo)
                    throw new AssertionError("Serialization did not copy");
                if (!layer.equals(echo))
                    throw new AssertionError("Serialization not equal");
                echo.freeRef();
                return new GsonBuilder().setPrettyPrinting().create().toJson(json);
            });
            @Nonnull
            final String filename = layer.getClass().getSimpleName() + "_" + log.getName() + ".json";
            log.p(log.file(prettyPrint, filename,
                    String.format("Wrote Model to %s; %s characters", filename, prettyPrint.length())));
        } catch (RuntimeException | OutOfMemoryError e) {
            // Best effort: report the failure, pause briefly, and continue with the zip stage.
            e.printStackTrace();
            Util.sleep(1000);
        }
        log.p("");

        // Serializes notebook output produced by the parallel zip round trips below.
        @Nonnull
        final Object outSync = new Object();
        // Zip round trips run only when the JSON stage yielded nothing (an empty string, e.g.
        // after a swallowed failure) or yielded a document larger than the threshold.
        if (prettyPrint.isEmpty() || prettyPrint.length() > ZIP_TEST_JSON_THRESHOLD) {
            Arrays.stream(SerialPrecision.values()).parallel()
                    .forEach(precision -> zipRoundTrip(log, layer, precision, outSync));
        }

        return null;
    }

    /**
     * Writes {@code layer} to a zip archive at the given precision, reads it back, records the
     * copy in {@link #getModels()} and verifies it against the original.
     *
     * <p>May be invoked concurrently for different precisions; access to the shared model map
     * and to the notebook is therefore synchronized.
     *
     * @param log       the notebook that receives the heading and the file link
     * @param layer     the layer to serialize
     * @param precision the precision passed to {@code writeZip}
     * @param logLock   monitor guarding notebook output across concurrent invocations
     */
    private void zipRoundTrip(@Nonnull final NotebookOutput log, @Nonnull final Layer layer,
            @Nonnull final SerialPrecision precision, @Nonnull final Object logLock) {
        try {
            @Nonnull
            final File file = new File(log.getResourceDir(), log.getName() + "_" + precision.name() + ".zip");
            layer.writeZip(file, precision);
            final Layer echo;
            // Close the archive as soon as it has been read so the handle is not leaked and
            // the file can be deleted below. NOTE(review): assumes Layer.fromZip reads the
            // archive eagerly and the returned layer does not keep using the ZipFile - confirm.
            try (ZipFile zipFile = new ZipFile(file)) {
                echo = Layer.fromZip(zipFile);
            }
            // HashMap is not thread-safe and this method runs on a parallel stream.
            final HashMap<SerialPrecision, Layer> modelMap = getModels();
            synchronized (modelMap) {
                modelMap.put(precision, echo);
            }
            synchronized (logLock) {
                log.h2(String.format("Zipfile %s", precision.name()));
                log.p(log.link(file, String.format("Wrote Model apply %s precision to %s; %.3fMiB bytes",
                        precision, file.getName(), file.length() * 1.0 / (0x100000))));
            }
            if (!isPersist()) {
                file.delete();
            }
            if (echo == null) {
                throw new AssertionError("Failed to deserialize");
            }
            if (layer == echo) {
                throw new AssertionError("Serialization did not copy");
            }
            if (!layer.equals(echo)) {
                throw new AssertionError("Serialization not equal");
            }
        } catch (RuntimeException | OutOfMemoryError | IOException e) {
            // Best effort: a failure for one precision is reported but does not abort the others.
            // ZipException is an IOException and is covered by this handler.
            e.printStackTrace();
        }
    }

    /**
     * Returns the layers deserialized from zip archives by {@link #test}, keyed by precision.
     *
     * @return the live, mutable map held by this test (not a copy)
     */
    @Nonnull
    public HashMap<SerialPrecision, Layer> getModels() {
        return models;
    }

    /**
     * Reports whether zip archives written by the test are kept on disk.
     *
     * @return {@code true} if archives are kept, {@code false} if they are deleted after use
     */
    public boolean isPersist() {
        return persist;
    }

    /**
     * Sets whether zip archives written by the test are kept on disk.
     *
     * @param persist {@code true} to keep the archives, {@code false} to delete them after use
     * @return this instance, to allow call chaining
     */
    @Nonnull
    public SerializationTest setPersist(boolean persist) {
        this.persist = persist;
        return this;
    }

    @Nonnull
    @Override
    public String toString() {
        return "SerializationTest{" + "models=" + models + ", persist=" + persist + '}';
    }
}