com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer.java Source code

Java tutorial

Introduction

Here is the source code for com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer.java

Source

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.layers.java;

import com.google.gson.JsonObject;
import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.util.MonitoredItem;
import com.simiacryptus.util.MonitoredObject;
import com.simiacryptus.util.data.PercentileStatistics;
import com.simiacryptus.util.data.ScalarStatistics;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A diagnostic wrapper that collects performance data and, if enabled, value statistics of output and backprop
 * signals.
 *
 * <p>Units: {@link #getForwardPerformance()} records seconds per forward batch, while
 * {@link #getBackwardPerformance()} records seconds per item (backward time divided by the batch size) and excludes
 * time spent propagating gradients into the wrapped layer's inputs. When signal recording is enabled, the signal
 * statistics are cleared on every batch and therefore describe only the most recent batch.
 */
@SuppressWarnings({ "serial", "FieldCanBeLocal" })
public final class MonitoringWrapperLayer extends WrapperLayer implements MonitoredItem {

    private final PercentileStatistics backwardPerformance = new PercentileStatistics();
    private final ScalarStatistics backwardSignal = new PercentileStatistics();
    private final PercentileStatistics forwardPerformance = new PercentileStatistics();
    private final ScalarStatistics forwardSignal = new PercentileStatistics();
    private final boolean verbose = false;
    private boolean recordSignalMetrics = false;
    private int totalBatches = 0;
    private int totalItems = 0;

    /**
     * Instantiates a new monitoring wrapper layer from its JSON form.
     *
     * <p>Statistics blocks and counters are optional. Keys that are absent leave the corresponding field at its
     * default, so JSON written by older or partial serializers still loads.
     *
     * @param json the json
     * @param rs   the resource map passed through to the wrapper base class
     */
    protected MonitoringWrapperLayer(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        super(json, rs);
        if (json.has("forwardPerf")) {
            forwardPerformance.readJson(json.getAsJsonObject("forwardPerf"));
        }
        if (json.has("backwardPerf")) {
            backwardPerformance.readJson(json.getAsJsonObject("backwardPerf"));
        }
        if (json.has("backpropStatistics")) {
            backwardSignal.readJson(json.getAsJsonObject("backpropStatistics"));
        }
        if (json.has("outputStatistics")) {
            forwardSignal.readJson(json.getAsJsonObject("outputStatistics"));
        }
        // Guarded like the statistics above: a missing key previously caused a NullPointerException.
        if (json.has("recordSignalMetrics")) {
            recordSignalMetrics = json.get("recordSignalMetrics").getAsBoolean();
        }
        if (json.has("totalBatches")) {
            totalBatches = json.get("totalBatches").getAsInt();
        }
        if (json.has("totalItems")) {
            totalItems = json.get("totalItems").getAsInt();
        }
    }

    /**
     * Instantiates a new monitoring wrapper layer.
     *
     * @param inner the layer to wrap and monitor
     */
    public MonitoringWrapperLayer(final Layer inner) {
        super(inner);
    }

    /**
     * Deserializes a monitoring wrapper layer from JSON.
     *
     * @param json the json
     * @param rs   the resource map
     * @return the monitoring wrapper layer
     */
    public static MonitoringWrapperLayer fromJson(@Nonnull final JsonObject json, Map<CharSequence, byte[]> rs) {
        return new MonitoringWrapperLayer(json, rs);
    }

    /**
     * Registers this layer with a monitor under the wrapped layer's current name.
     *
     * @param obj the monitor to register with
     * @return this layer
     */
    @Nonnull
    public MonitoringWrapperLayer addTo(@Nonnull final MonitoredObject obj) {
        return addTo(obj, getInner().getName());
    }

    /**
     * Renames the wrapped layer and registers this layer with a monitor under that name.
     *
     * @param obj  the monitor to register with
     * @param name the name to assign and register under
     * @return this layer
     */
    @Nonnull
    public MonitoringWrapperLayer addTo(@Nonnull final MonitoredObject obj, final String name) {
        setName(name);
        obj.addObj(getName(), this);
        return this;
    }

    @Override
    public Result evalAndFree(@Nonnull final Result... inObj) {
        // Read the batch size before running the inner layer. The wrapped inputs' _free() releases the
        // corresponding inObj entries, so their data must not be accessed after the inner evalAndFree call.
        final int items = Arrays.stream(inObj).mapToInt(x -> x.getData().length()).max().orElse(1);
        // Time spent inside the inputs' own accumulate calls; subtracted from the backward timing below so that
        // only this layer's backward work is attributed to it.
        @Nonnull
        final AtomicLong passbackNanos = new AtomicLong(0);
        final Result[] wrappedInput = Arrays.stream(inObj).map(result -> {
            return new Result(result.getData(),
                    (@Nonnull final DeltaSet<UUID> buffer, @Nonnull final TensorList data) -> {
                        data.addRef();
                        passbackNanos.addAndGet(TimedResult.time(() -> result.accumulate(buffer, data)).timeNanos);
                    }) {

                @Override
                protected void _free() {
                    result.freeRef();
                }

                @Override
                public boolean isAlive() {
                    return result.isAlive();
                }
            };
        }).toArray(i -> new Result[i]);
        @Nonnull
        TimedResult<Result> timedResult = TimedResult.time(() -> getInner().evalAndFree(wrappedInput));
        final Result output = timedResult.result;
        // Forward timing is recorded per batch, in seconds.
        forwardPerformance.add((timedResult.timeNanos) / 1000000000.0);
        totalBatches++;
        totalItems += items;
        if (recordSignalMetrics) {
            // Only the latest batch is kept.
            forwardSignal.clear();
            output.getData().stream().parallel().forEach(t -> {
                forwardSignal.add(t.getData());
                t.freeRef();
            });
        }
        return new Result(output.getData(),
                (@Nonnull final DeltaSet<UUID> buffer, @Nonnull final TensorList data) -> {
                    if (recordSignalMetrics) {
                        backwardSignal.clear();
                        data.stream().parallel().forEach(t -> {
                            backwardSignal.add(t.getData());
                            t.freeRef();
                        });
                    }
                    data.addRef();
                    // Backward timing is recorded per item, in seconds, excluding input passback time.
                    backwardPerformance.add((TimedResult.time(() -> output.accumulate(buffer, data)).timeNanos
                            - passbackNanos.getAndSet(0)) / (items * 1e9));
                }) {

            @Override
            protected void _free() {
                output.freeRef();
            }

            @Override
            public boolean isAlive() {
                return output.isAlive();
            }
        };
    }

    /**
     * Gets backward performance, recorded as seconds per item.
     *
     * @return the backward performance
     */
    @Nonnull
    public PercentileStatistics getBackwardPerformance() {
        return backwardPerformance;
    }

    /**
     * Gets the statistics of the most recently recorded backprop signal.
     *
     * @return the backward signal
     */
    @Nonnull
    public ScalarStatistics getBackwardSignal() {
        return backwardSignal;
    }

    /**
     * Gets forward performance, recorded as seconds per batch.
     *
     * @return the forward performance
     */
    @Nonnull
    public PercentileStatistics getForwardPerformance() {
        return forwardPerformance;
    }

    /**
     * Gets the statistics of the most recently recorded output signal.
     *
     * @return the forward signal
     */
    @Nonnull
    public ScalarStatistics getForwardSignal() {
        return forwardSignal;
    }

    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
        @Nonnull
        final JsonObject json = super.getJson(resources, dataSerializer);
        // Performance and signal statistics are not written here; only the counters and the recording flag
        // are persisted.
        json.addProperty("totalBatches", totalBatches);
        json.addProperty("totalItems", totalItems);
        json.addProperty("recordSignalMetrics", recordSignalMetrics);
        return json;
    }

    @Nonnull
    @Override
    public Map<CharSequence, Object> getMetrics() {
        @Nonnull
        final HashMap<CharSequence, Object> map = new HashMap<>();
        map.put("class", getInner().getClass().getName());
        map.put("totalBatches", totalBatches);
        map.put("totalItems", totalItems);
        map.put("outputStatistics", forwardSignal.getMetrics());
        map.put("backpropStatistics", backwardSignal.getMetrics());
        if (verbose) {
            map.put("forwardPerformance", forwardPerformance.getMetrics());
            map.put("backwardPerformance", backwardPerformance.getMetrics());
        }
        // Forward timings are per batch, so they are scaled by batches/items to obtain a per-item figure.
        final double batchesPerItem = totalBatches * 1.0 / totalItems;
        map.put("avgMsPerItem", 1000 * batchesPerItem * forwardPerformance.getMean());
        map.put("medianMsPerItem", 1000 * batchesPerItem * forwardPerformance.getPercentile(0.5));
        // Backward timings are already divided by the batch size in evalAndFree. Scaling by batchesPerItem again
        // would divide by the batch size twice, so only the seconds-to-milliseconds conversion is applied.
        final double backpropMean = backwardPerformance.getMean();
        final double backpropMedian = backwardPerformance.getPercentile(0.5);
        map.put("avgMsPerItem_Backward", 1000 * backpropMean);
        map.put("medianMsPerItem_Backward", 1000 * backpropMedian);
        @Nullable
        final List<double[]> state = state();
        @Nonnull
        final ScalarStatistics statistics = new PercentileStatistics();
        if (null != state) {
            for (@Nonnull final double[] s : state) {
                for (final double v : s) {
                    statistics.add(v);
                }
            }
        }
        // A non-zero count implies state was non-null, so state.size() below is safe.
        if (statistics.getCount() > 0) {
            @Nonnull
            final HashMap<CharSequence, Object> weightStats = new HashMap<>();
            weightStats.put("buffers", state.size());
            weightStats.putAll(statistics.getMetrics());
            map.put("weights", weightStats);
        }
        return map;
    }

    @Nullable
    @Override
    public String getName() {
        return getInner().getName();
    }

    /**
     * Returns whether output and backprop signal statistics are being recorded.
     *
     * @return true if signal metrics are recorded
     */
    public boolean recordSignalMetrics() {
        return recordSignalMetrics;
    }

    @Nonnull
    @Override
    public Layer setName(final String name) {
        // The name lives on the wrapped layer; getName() reads it back from there.
        if (null != getInner()) {
            getInner().setName(name);
        }
        return this;
    }

    /**
     * Enables or disables recording of output and backprop signal statistics.
     *
     * @param recordSignalMetrics whether to record signal metrics
     * @return this layer
     */
    @Nonnull
    public MonitoringWrapperLayer shouldRecordSignalMetrics(final boolean recordSignalMetrics) {
        this.recordSignalMetrics = recordSignalMetrics;
        return this;
    }
}