Source code of com.anhth12.models.featuredetectors.rbm.run.App (App.java)

Java tutorial

Introduction

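Here is the source code of App.java, which defines the class com.anhth12.models.featuredetectors.rbm.run.App: a small command-line driver that configures an RBM and trains it on data loaded from a text file.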

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.models.featuredetectors.rbm.run;

import com.anhth12.models.featuredetectors.rbm.RBM;
import com.anhth12.nn.api.LayerFactory;
import com.anhth12.nn.conf.NeuralNetworkConfiguration;
import com.anhth12.nn.layers.factory.LayerFactories;
import com.anhth12.weights.WeightInit;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Sigmoid;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 *
 * @author anhth12
 */
/**
 * Command-line driver that trains an RBM on one batch of data read from a text file.
 *
 * <p>The data file path may be passed as the first command-line argument. When it is absent,
 * the original hard-coded development path is used.
 *
 * @author anhth12
 */
public class App {

    /** Data file used when no path is given on the command line (original development location). */
    private static final String DEFAULT_DATA_PATH =
            "D:\\1.Source\\LogisticRegression\\src\\resources\\mnist_784_1000.txt";

    /**
     * Loads the first {@link DataSet} from the data file, builds an RBM configuration sized to it,
     * and fits the RBM on that data set's feature matrix.
     *
     * @param args optional; {@code args[0]}, when present and non-blank, overrides the data file path
     * @throws IllegalStateException if loading the data or training fails. The original exception is
     *     attached as the cause, so the JVM reports it and the process exits with a non-zero status
     *     instead of silently succeeding.
     */
    public static void main(String[] args) {
        String path = resolveDataPath(args);
        try {
            FilleDatsetIterator dsi = new FilleDatsetIterator(path);
            // Only the first batch returned by the iterator is used for training.
            DataSet d = dsi.next();

            // Fixed seed (123), so the generator yields the same sequence on every run.
            RandomGenerator rng = new MersenneTwister(123);
            NeuralNetworkConfiguration conf = buildConfiguration(d, rng);

            LayerFactory rbmFactory = LayerFactories.getLayerFactory(RBM.class);
            RBM rbm = rbmFactory.create(conf);
            rbm.fit(d.getFeatureMatrix());
        } catch (Exception e) {
            // Previously the stack trace was printed and main returned normally, so the process
            // exited with status 0 even on failure. Rethrow with context instead.
            throw new IllegalStateException("RBM training failed for data file: " + path, e);
        }
    }

    /**
     * Chooses the data file path.
     *
     * @param args the command-line arguments (may be {@code null} or empty)
     * @return {@code args[0]} when present and non-blank, otherwise {@link #DEFAULT_DATA_PATH}
     */
    private static String resolveDataPath(String[] args) {
        if (args != null && args.length > 0 && args[0] != null && !args[0].trim().isEmpty()) {
            return args[0];
        }
        return DEFAULT_DATA_PATH;
    }

    /**
     * Builds the RBM configuration used by {@link #main}. Layer sizes come from {@code d}; all other
     * hyper-parameters are the fixed values this experiment uses.
     *
     * @param d   the training data, used to size the input and output of the layer
     * @param rng the random generator handed to the configuration
     * @return a fully populated configuration
     */
    private static NeuralNetworkConfiguration buildConfiguration(DataSet d, RandomGenerator rng) {
        NeuralNetworkConfiguration conf = new NeuralNetworkConfiguration();
        conf.setWeightInit(WeightInit.VI);
        conf.setDropOut(0.3f);
        conf.setHiddenUnit(RBM.HiddenUnit.SOFTMAX);
        conf.setVisibleUnit(RBM.VisibleUnit.SOFTMAX);
        conf.setLossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY);
        conf.setRng(rng);
        conf.setnIn(d.numInputs());
        // NOTE(review): for an RBM, nOut presumably sets the hidden-layer size. Here it is taken
        // from the data set's number of outcomes (label classes); confirm this is intended.
        conf.setnOut(d.numOutcomes());
        conf.setNumIterations(10);
        conf.setActivationFunction(new Sigmoid());
        return conf;
    }
}