com.anhth12.util.WeightInitUtil.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.util.WeightInitUtil.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.util;

import com.anhth12.distributions.Distributions;
import com.anhth12.weights.WeightInit;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers for creating initial weight arrays.
 *
 * <p>Every method here that passes an explicit random generator uses a {@link MersenneTwister}
 * seeded with the same fixed value, so those schemes produce identical values on every call for
 * the same arguments.
 *
 * @author anhth12
 */
public class WeightInitUtil {

    /** Seed for every explicitly constructed random generator in this class. */
    private static final int SEED = 123;

    /**
     * Creates a weight array of the given shape using the requested initialization scheme.
     *
     * <ul>
     *   <li>{@code VI}: values in {@code [-r, r)} with {@code r = sqrt(6) / sqrt(sum(shape) + 1)}
     *       (assumes {@code Nd4j.rand(shape)} samples uniformly from {@code [0, 1)}).</li>
     *   <li>{@code ZERO}: the array returned by {@code Nd4j.create(shape)} (expected to be
     *       zero-filled).</li>
     *   <li>{@code SIZE}: {@link #uniformBasedOnInAndOut(int[], int, int)} with
     *       {@code nIn = shape[0]} and {@code nOut = shape[1]}.</li>
     *   <li>{@code DISTRIBUTION}: each slice is overwritten with {@code columns()} samples drawn
     *       from {@code dist}.</li>
     *   <li>{@code NORMALIZED}: {@link #normalized(int[], int)} with {@code nIn = shape[0]}.</li>
     *   <li>{@code UNIFORM}: random values between {@code -1.0 / shape[0]} and
     *       {@code 1.0 / shape[0]} from a fixed-seed generator.</li>
     * </ul>
     *
     * @param shape shape of the array to create; SIZE reads {@code shape[0]} and {@code shape[1]},
     *     NORMALIZED and UNIFORM read {@code shape[0]}
     * @param initScheme the initialization scheme to apply
     * @param act activation function; currently unused by every scheme
     * @param dist distribution sampled by the DISTRIBUTION scheme; ignored by the others
     * @return a newly created weight array
     * @throws AssertionError if {@code initScheme} is not one of the handled schemes
     */
    public static INDArray initWeight(int[] shape, WeightInit initScheme, ActivationFunction act,
            RealDistribution dist) {
        switch (initScheme) {
        case VI:
            return viUniform(shape);
        case ZERO:
            return Nd4j.create(shape);
        case SIZE:
            return uniformBasedOnInAndOut(shape, shape[0], shape[1]);
        case DISTRIBUTION:
            // Must return here: previously this case fell through into NORMALIZED and the
            // sampled values were thrown away.
            return sampleSlicesFrom(shape, dist);
        case NORMALIZED:
            return normalized(shape, shape[0]);
        case UNIFORM: {
            // 1.0 forces floating-point division; the integer expression 1 / shape[0] is 0 for
            // any shape[0] > 1, which collapsed the range to [0, 0].
            double bound = 1.0 / shape[0];
            return Nd4j.rand(shape, -bound, bound, new MersenneTwister(SEED));
        }
        default:
            throw new AssertionError(initScheme.name());
        }
    }

    /** VI scheme: random values scaled into {@code [-r, r)}, r = sqrt(6) / sqrt(sum(shape) + 1). */
    private static INDArray viUniform(int[] shape) {
        int len = 0;
        for (int dim : shape) {
            len += dim;
        }
        double r = Math.sqrt(6) / Math.sqrt(len + 1);
        INDArray ret = Nd4j.rand(shape);
        // x * 2 * r - r maps [0, 1) onto [-r, r).
        ret.muli(2).muli(r).subi(r);
        return ret;
    }

    /** DISTRIBUTION scheme: overwrites every slice with {@code columns()} samples from dist. */
    private static INDArray sampleSlicesFrom(int[] shape, RealDistribution dist) {
        INDArray ret = Nd4j.rand(shape);
        // NOTE(review): assumes each slice holds ret.columns() elements (true for 2-D shapes).
        for (int i = 0; i < ret.slices(); i++) {
            ret.putSlice(i, Nd4j.create(dist.sample(ret.columns())));
        }
        return ret;
    }

    /**
     * Convenience overload of {@link #initWeight} for a two-dimensional {@code nIn x nOut} array.
     *
     * @param nIn size of the first dimension
     * @param nOut size of the second dimension
     * @param initScheme the initialization scheme to apply
     * @param act activation function; currently unused
     * @param dist distribution sampled by the DISTRIBUTION scheme
     * @return a newly created {@code [nIn, nOut]} weight array
     */
    public static INDArray initWeights(int nIn, int nOut, WeightInit initScheme, ActivationFunction act,
            RealDistribution dist) {
        return initWeight(new int[] { nIn, nOut }, initScheme, act, dist);
    }

    /**
     * Creates an array of random values drawn from a uniform distribution over
     * {@code [-4 * sqrt(6 / (nIn + nOut)), 4 * sqrt(6 / (nIn + nOut))]}, using a fixed-seed
     * generator.
     *
     * @param shape shape of the array to create
     * @param nIn fan-in used to compute the bound
     * @param nOut fan-out used to compute the bound
     * @return a newly created array
     */
    public static INDArray uniformBasedOnInAndOut(int[] shape, int nIn, int nOut) {
        // Sum in double so very large nIn + nOut cannot overflow int.
        double bound = 4.0 * Math.sqrt(6.0 / ((double) nOut + nIn));
        return Nd4j.rand(shape, Distributions.uniform(new MersenneTwister(SEED), -bound, bound));
    }

    /**
     * Creates an array of random values between {@code min} and {@code max} from a fixed-seed
     * generator.
     *
     * @param shape shape of the array to create
     * @param min lower bound of the generated values
     * @param max upper bound of the generated values
     * @return a newly created array
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        return Nd4j.rand(shape, min, max, new MersenneTwister(SEED));
    }

    /**
     * Creates {@code (rand - 0.5) / nIn} for an array of random values of the given shape.
     *
     * @param shape shape of the array to create
     * @param nIn divisor applied after centring
     * @return a newly created array
     */
    public static INDArray normalized(int[] shape, int nIn) {
        return Nd4j.rand(shape).subi(0.5).divi((double) nIn);
    }

}