// Java tutorial
/* * To change this license header, choose License Headers in Project Properties. * To change this template file, choose Tools | Templates * and open the template in the editor. */ package com.anhth12.util; import com.anhth12.distributions.Distributions; import com.anhth12.weights.WeightInit; import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.random.MersenneTwister; import org.nd4j.linalg.api.activation.ActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; /** * * @author anhth12 */ public class WeightInitUtil { public static INDArray initWeight(int[] shape, WeightInit initScheme, ActivationFunction act, RealDistribution dist) { INDArray ret; switch (initScheme) { case VI: //Rand .* 2 .*r - r ret = Nd4j.rand(shape); int len = 0; for (int i = 0; i < shape.length; i++) { len += shape[i]; } double r = Math.sqrt(6) / Math.sqrt(len + 1); ret.muli(2).muli(r).subi(r); return ret; case ZERO: return Nd4j.create(shape); case SIZE: return uniformBasedOnInAndOut(shape, shape[0], shape[1]); case DISTRIBUTION: ret = Nd4j.rand(shape); for (int i = 0; i < ret.slices(); i++) { ret.putSlice(i, Nd4j.create(dist.sample(ret.columns()))); } case NORMALIZED: ret = Nd4j.rand(shape); return ret.subi(0.5).divi(shape[0]); case UNIFORM: double a = 1 / shape[0]; return Nd4j.rand(shape, -a, a, new MersenneTwister(123)); default: throw new AssertionError(initScheme.name()); } } public static INDArray initWeights(int nIn, int nOut, WeightInit initScheme, ActivationFunction act, RealDistribution dist) { return initWeight(new int[] { nIn, nOut }, initScheme, act, dist); } public static INDArray uniformBasedOnInAndOut(int[] shape, int nIn, int nOut) { double min = -4.0 * Math.sqrt(6.0 / (double) (nOut + nIn)); double max = 4.0 * Math.sqrt(6.0 / (double) (nOut + nIn)); return Nd4j.rand(shape, Distributions.uniform(new MersenneTwister(123), min, max)); } public static INDArray initWeights(int[] shape, float 
min, float max) { return Nd4j.rand(shape, min, max, new MersenneTwister(123)); } public static INDArray normalized(int[] shape, int nIn) { return Nd4j.rand(shape).subi(0.5).divi((double) nIn); } }