Java tutorial: configuring and training an RBM on an MNIST subset
package com.anhth12.models.featuredetectors.rbm.run;

import com.anhth12.models.featuredetectors.rbm.RBM;
import com.anhth12.nn.api.LayerFactory;
import com.anhth12.nn.conf.NeuralNetworkConfiguration;
import com.anhth12.nn.layers.factory.LayerFactories;
import com.anhth12.weights.WeightInit;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.nd4j.linalg.api.activation.Sigmoid;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Loads a small MNIST subset from a text file, configures an RBM, and trains
 * it unsupervised on the feature matrix.
 *
 * @author anhth12
 */
public class App {

    public static void main(String[] args) {
        try {
            // Text file holding 1,000 MNIST examples with 784 pixel values each.
            String path = "D:\\1.Source\\LogisticRegression\\src\\resources\\mnist_784_1000.txt";

            // Custom iterator (defined elsewhere in the project) that reads the
            // file into an ND4J DataSet of features and labels.
            FilleDatsetIterator dsi = new FilleDatsetIterator(path);
            DataSet d = dsi.next();

            // Fixed seed so weight initialization and sampling are reproducible.
            RandomGenerator rng = new MersenneTwister(123);

            // Configuration for the RBM layer.
            NeuralNetworkConfiguration conf = new NeuralNetworkConfiguration();
            conf.setWeightInit(WeightInit.VI);
            conf.setDropOut(0.3f);
            conf.setHiddenUnit(RBM.HiddenUnit.SOFTMAX);
            conf.setVisibleUnit(RBM.VisibleUnit.SOFTMAX);
            conf.setLossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY);
            conf.setRng(rng);
            // Layer sizes are taken from the loaded data.
            conf.setnIn(d.numInputs());
            conf.setnOut(d.numOutcomes());
            conf.setNumIterations(10);
            conf.setActivationFunction(new Sigmoid());

            // Build the RBM through the layer factory rather than instantiating it directly.
            LayerFactory rbmFactory = LayerFactories.getLayerFactory(RBM.class);
            RBM rbm = rbmFactory.create(conf);
            // RBM rbm = new RBM(conf);

            // Unsupervised training: only the feature matrix is used, labels are ignored.
            rbm.fit(d.getFeatureMatrix());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
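The listing above depends on FilleDatsetIterator, which is not shown. The sketch below is purely illustrative: it keeps the class name from the listing, assumes the class lives in the same package as App, and assumes the text file stores one comma-separated example per line with 784 raw pixel values followed by 10 one-hot label values. The real file layout and the original implementation may differ.

// Assumed to live in the same package as App, so no import is needed there.
package com.anhth12.models.featuredetectors.rbm.run;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Hypothetical loader for the MNIST text file used in the listing above.
 * Assumes one comma-separated example per line: 784 pixel values followed
 * by 10 one-hot label values. Adjust the constants and delimiter to match
 * the real file.
 */
public class FilleDatsetIterator {

    private static final int NUM_FEATURES = 784;
    private static final int NUM_LABELS = 10;

    private final String path;

    public FilleDatsetIterator(String path) {
        this.path = path;
    }

    /**
     * Reads the whole file into a single DataSet (features + labels). The
     * IOException is caught by the caller's catch (Exception e) block.
     */
    public DataSet next() throws IOException {
        List<double[]> featureRows = new ArrayList<>();
        List<double[]> labelRows = new ArrayList<>();

        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] tokens = line.split(",");
                double[] features = new double[NUM_FEATURES];
                double[] labels = new double[NUM_LABELS];
                for (int i = 0; i < NUM_FEATURES; i++) {
                    // Scale raw pixel values (0-255) into [0, 1] for the RBM.
                    features[i] = Double.parseDouble(tokens[i]) / 255.0;
                }
                for (int i = 0; i < NUM_LABELS; i++) {
                    labels[i] = Double.parseDouble(tokens[NUM_FEATURES + i]);
                }
                featureRows.add(features);
                labelRows.add(labels);
            }
        }

        // Stack the rows into ND4J matrices and wrap them in a DataSet.
        INDArray featureMatrix = Nd4j.create(featureRows.toArray(new double[0][]));
        INDArray labelMatrix = Nd4j.create(labelRows.toArray(new double[0][]));
        return new DataSet(featureMatrix, labelMatrix);
    }
}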