crbm.CRBM.java Source code

Java tutorial

Introduction

Here is the source code for crbm.CRBM.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package crbm;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Random;

import org.apache.commons.io.FileUtils;
import org.jblas.FloatMatrix;
import org.jblas.MatrixFunctions;

import javax.imageio.ImageIO;

/**
 *
 * @author Radek
 */
public class CRBM {

    private final Random RANDOM = new Random();

    private final float[][] W;
    private final int filterEdgeLength;

    private float c_k = 0f;
    private float b = 0f;

    public CRBM(int K, int filterEdgeLength) {
        this.W = initW(K, filterEdgeLength);
        this.filterEdgeLength = filterEdgeLength;
    }

    private float[][] initW(int K, int filterEdgeLength) {
        int filterDimensions = filterEdgeLength * filterEdgeLength;
        float[][] result = new float[K][filterDimensions];

        for (int k = 0; k < K; k++) {
            for (int i = 0; i < filterDimensions; i++) {
                result[k][i] = (float) (((RANDOM.nextDouble() - 0.5) * 2) / (filterEdgeLength * filterEdgeLength));
            }
        }
        return result;
    }

    public void train(float[][][] data, int dataEdgeLength, int epochs, float learningRate, String exportPath) {

        for (int e = 0; e < epochs; e++) {
            float error = 0;
            for (float[][] aData : data) {
                error += train(aData, dataEdgeLength, learningRate);
            }
            error /= data.length;
            System.out.println(error);
        }
        //killFirst();

        System.out.println("Training finished");
    }

    public void train(float[][] data, int dataEdgeLength, int epochs, float learningRate, String exportPath) {

        for (int e = 0; e < epochs; e++) {
            float error = 0;
            for (float[] aData : data) {
                error += train(aData, dataEdgeLength, learningRate);
            }
            error /= data.length;
            System.out.println(error);
        }
        //killFirst();

        System.out.println("Training finished");
    }

    public void killFirst() {
        for (int i = 0; i < W[0].length; i++) {
            W[0][i] = 0;
        }
    }

    private float train(float[][] data, int dataEdgeLength, float learningRate) {
        int offset = filterEdgeLength - 1;
        int K = W.length;
        float[][] W1;

        float[][] PH0_k = new float[K][];
        float[][] Grad0_k = new float[K][];
        float[][] H0_k = new float[K][];

        for (int k = 0; k < K; k++) {
            PH0_k[k] = filter(data, W[k], dataEdgeLength, filterEdgeLength, true);
            if (k != 0) {
                addBias(PH0_k[k], c_k);
            }
            PH0_k[k] = logistic(PH0_k[k]);
            Grad0_k[k] = filter(data, PH0_k[k], dataEdgeLength, dataEdgeLength - offset, true);
            H0_k[k] = bernoulli(PH0_k[k]);
        }

        W1 = flip(W);

        float[] V1m = new float[(dataEdgeLength - offset * 2) * (dataEdgeLength - offset * 2)];
        for (int k = 0; k < W1.length; k++) {
            float[] r = filter(data, W1[k], dataEdgeLength - offset, this.filterEdgeLength, true);
            V1m = add(V1m, r);
        }

        addBias(V1m, b);
        float[] V1 = logistic(V1m);
        if (data != null) {
            V1 = concat(data, V1, dataEdgeLength, filterEdgeLength);
        }

        // float[] V1 = getVisible(H0_k, data, dataEdgeLength);

        float[][] PH1_k = new float[K][];
        float[][] Grad1_k = new float[K][];

        for (int k = 0; k < K; k++) {
            PH1_k[k] = filter(V1, W[k], dataEdgeLength, filterEdgeLength);
            if (k != 0) {
                addBias(PH0_k[k], c_k);
            }
            PH1_k[k] = logistic(PH1_k[k]);
            Grad1_k[k] = filter(V1, PH1_k[k], dataEdgeLength, dataEdgeLength - offset);
            W[k] = CD(W[k], Grad0_k[k], Grad1_k[k], learningRate);
        }

        return 0.0f;
    }

    private float train(float[] data, int dataEdgeLength, float learningRate) {
        int offset = filterEdgeLength - 1;
        int K = W.length;
        float[][] W1;

        float[][] PH0_k = new float[K][];
        float[][] Grad0_k = new float[K][];
        float[][] H0_k = new float[K][];

        for (int k = 0; k < K; k++) {
            PH0_k[k] = filter(data, W[k], dataEdgeLength, filterEdgeLength);
            if (k != 0) {
                addBias(PH0_k[k], c_k);
            }
            PH0_k[k] = logistic(PH0_k[k]);
            Grad0_k[k] = filter(data, PH0_k[k], dataEdgeLength, dataEdgeLength - offset);
            H0_k[k] = bernoulli(PH0_k[k]);
        }

        float[] V1 = getVisible(H0_k, data, dataEdgeLength);

        float[][] PH1_k = new float[K][];
        float[][] Grad1_k = new float[K][];

        for (int k = 0; k < K; k++) {
            PH1_k[k] = filter(V1, W[k], dataEdgeLength, filterEdgeLength);
            if (k != 0) {
                addBias(PH0_k[k], c_k);
            }
            PH1_k[k] = logistic(PH1_k[k]);
            Grad1_k[k] = filter(V1, PH1_k[k], dataEdgeLength, dataEdgeLength - offset);
            W[k] = CD(W[k], Grad0_k[k], Grad1_k[k], learningRate);
        }

        FloatMatrix V0M = new FloatMatrix(data);
        FloatMatrix V1M = new FloatMatrix(V1);

        return (float) Math.sqrt(MatrixFunctions.pow(V0M.sub(V1M), 2.0f).sum());
    }

    private float[] filter(float[] data, float[] filter, int dataEdgeLength, int filterEdgeLength) {
        int offset = filterEdgeLength - 1;
        final int rEdgeLength = dataEdgeLength - offset;

        float[] r = new float[rEdgeLength * rEdgeLength];

        for (int y = 0; y < rEdgeLength; y++) {
            for (int x = 0; x < rEdgeLength; x++) {

                float sum = 0;
                for (int yh = 0; yh < filterEdgeLength; yh++) {
                    for (int xh = 0; xh < filterEdgeLength; xh++) {
                        int pos = (y + yh) * dataEdgeLength + x + xh;
                        sum += data[pos] * filter[yh * filterEdgeLength + xh];
                    }
                }

                r[y * rEdgeLength + x] = sum;
            }
        }

        return r;
    }

    private float[] filter(float[][] data, float[] filter, int dataEdgeLength, int filterEdgeLength,
            boolean normalize) {
        int offset = filterEdgeLength - 1;
        final int rEdgeLength = dataEdgeLength - offset;

        float[] r = new float[rEdgeLength * rEdgeLength];

        for (int y = 0; y < rEdgeLength; y++) {
            for (int x = 0; x < rEdgeLength; x++) {

                float sum = 0;
                for (int yh = 0; yh < filterEdgeLength; yh++) {
                    for (int xh = 0; xh < filterEdgeLength; xh++) {
                        int pos = (y + yh) * dataEdgeLength + x + xh;
                        int fPos = yh * filterEdgeLength + xh;
                        for (float[] dataK : data) {
                            sum += dataK[pos] * filter[fPos];
                        }
                    }
                }
                if (normalize) {
                    sum /= (float) data.length;
                }

                r[y * rEdgeLength + x] = sum;
            }
        }
        return r;
    }

    private float[] logistic(float[] data) {
        float[] result = new float[data.length];

        for (int i = 0; i < data.length; i++) {
            result[i] = 1.f / (float) (1. + Math.exp(-data[i]));
        }

        return result;
    }

    private float[] bernoulli(float[] data) {
        float[] r = new float[data.length];

        for (int i = 0; i < data.length; i++) {
            r[i] = (data[i] > RANDOM.nextDouble()) ? 1 : 0;
        }

        return r;
    }

    private float[][] flip(float[][] h) {
        float[][] result = new float[h.length][];

        for (int i = 0; i < h.length; i++) {
            result[i] = flip(h[i]);
        }

        return result;
    }

    private float[] flip(float[] h) {
        float[] result = new float[h.length];

        for (int i = 0; i < h.length; i++) {
            result[i] = h[h.length - i - 1];
        }

        return result;
    }

    private float[] add(float[] data1, float[] data2) {
        float[] result = new float[data1.length];

        for (int i = 0; i < result.length; i++) {
            result[i] = data1[i] + data2[i];
        }

        return result;
    }

    private float[] concat(float[][] V0, float[] V1m, int dataEdgeLength, int filterEdgeLength) {
        float[] result = new float[V0[0].length];

        int offset = filterEdgeLength - 1;

        for (int y = 0; y < dataEdgeLength; y++) {
            for (int x = 0; x < dataEdgeLength; x++) {
                int pos = y * dataEdgeLength + x;
                if (y < offset || x < offset || y >= dataEdgeLength - offset || x >= dataEdgeLength - offset) {
                    for (int k = 0; k < V0.length; k++) {
                        result[pos] += V0[k][pos];
                    }
                    result[pos] /= (float) V0.length;
                } else {
                    int posm = (y - offset) * (dataEdgeLength - 2 * offset) + x - offset;
                    result[pos] = V1m[posm];
                }
            }
        }

        return result;
    }

    private float[] concat(float[] V0, float[] V1m, int dataEdgeLength, int filterEdgeLength) {
        float[] result = new float[V0.length];

        int offset = filterEdgeLength - 1;

        for (int y = 0; y < dataEdgeLength; y++) {
            for (int x = 0; x < dataEdgeLength; x++) {
                int pos = y * dataEdgeLength + x;
                if (y < offset || x < offset || y >= dataEdgeLength - offset || x >= dataEdgeLength - offset) {
                    result[pos] = V0[pos];
                } else {
                    int posm = (y - offset) * (dataEdgeLength - 2 * offset) + x - offset;
                    result[pos] = V1m[posm];
                }
            }
        }

        return result;
    }

    private float[] CD(float[] W_k, float[] Grad0_k, float[] Grad1_k, float learningRate) {
        float[] result = new float[W_k.length];

        for (int i = 0; i < W_k.length; i++) {
            result[i] += W_k[i] + learningRate * (Grad0_k[i] - Grad1_k[i]);
        }

        return result;
    }

    public float[][][] getHidden(float[][][] data, int dataEdgeLength) {
        float[][][] result = new float[data.length][][];
        for (int i = 0; i < data.length; i++) {
            result[i] = getHidden2D(data[i], dataEdgeLength);
        }
        return result;
    }

    public float[][] getHidden2D(float[][] data, int dataEdgeLength) {
        int offset = filterEdgeLength - 1;
        int K = W.length;
        float[][] W1;

        float[][] PH0_k = new float[K][];
        float[][] Grad0_k = new float[K][];
        float[][] H0_k = new float[K][];

        for (int k = 0; k < K; k++) {
            PH0_k[k] = filter(data, W[k], dataEdgeLength, filterEdgeLength, true);
            if (k != 0) {
                addBias(PH0_k[k], c_k);
            }
            PH0_k[k] = logistic(PH0_k[k]);
            Grad0_k[k] = filter(data, PH0_k[k], dataEdgeLength, dataEdgeLength - offset, true);
            H0_k[k] = bernoulli(PH0_k[k]);
        }

        return PH0_k;
    }

    public float[][][] getHidden(float[][] data, int dataEdgeLength) {
        float[][][] result = new float[data.length][][];
        for (int i = 0; i < data.length; i++) {
            result[i] = getHidden(data[i], dataEdgeLength);
        }
        return result;
    }

    public float[][] getHidden(float[] data, int dataEdgeLength) {

        float[][] PH0_k = new float[W.length][];
        float[][] H0_k = new float[W.length][];

        for (int k = 0; k < W.length; k++) {
            PH0_k[k] = filter(data, W[k], dataEdgeLength, filterEdgeLength);
            if (k != 0) {
                addBias(PH0_k[k], c_k);
            }
            PH0_k[k] = logistic(PH0_k[k]);
            H0_k[k] = bernoulli(PH0_k[k]);
        }
        return PH0_k;
    }

    public float[][][] getVisible2D(float[][][] data, float[][] original, int dataEdgeLength) {

        int offset = filterEdgeLength - 1;
        float[][][] result = new float[data.length][W.length][];

        float[][] W1 = flip(W);
        for (int i = 0; i < data.length; i++) {
            for (int k = 0; k < W1.length; k++) {
                result[i][k] = filter(data[i], W1[k], dataEdgeLength - offset, this.filterEdgeLength, true);
            }
        }

        return result;

    }

    public float[][] getVisible(float[][][] data, float[] original, int dataEdgeLength) {
        float[][] result = new float[data.length][];

        for (int i = 0; i < data.length; i++) {
            result[i] = getVisible(data[i], original, dataEdgeLength);
        }

        return result;
    }

    public float[] getVisible(float[][] data, float[] original, int dataEdgeLength) {

        float[][] W1 = flip(W);
        int offset = filterEdgeLength - 1;

        float[] V1m = new float[(dataEdgeLength - offset * 2) * (dataEdgeLength - offset * 2)];
        for (int k = 0; k < W1.length; k++) {
            float[] r = filter(data[k], W1[k], dataEdgeLength - offset, this.filterEdgeLength);
            V1m = add(V1m, r);
        }

        addBias(V1m, b);
        float[] V1 = logistic(V1m);
        if (original != null) {
            V1 = concat(original, V1, dataEdgeLength, filterEdgeLength);
        }

        //
        return V1;
    }

    private void addBias(float[] r, float bias) {
        for (int i = 0; i < r.length; i++) {
            r[i] += bias;
        }
    }

    //    private void exportAsImage(float[][] data, String name, int count) {
    //        for (int k = 0; k < data.length; k++) {
    //            exportAsImage(data[k], name, count, k);
    //        }
    //    }

    //    private void exportAsImage(float[] data, String name, int count){
    //        exportAsImage(data, name, count, 0);
    //    }

    //    private void exportAsImage(float[] data, String name, int count, int k) {
    //        new File(EXPORT_PATH + "/" + name + "/").mkdirs();
    //
    //        BufferedImage image = DataConverter.pixelDataToImage(data, 0.0f, false);
    //        File outputfile = new File(EXPORT_PATH + "/" + name + "/" + count + "_" + k + ".png");
    //        try {
    //            ImageIO.write(image, "png", outputfile);
    //        } catch (IOException ex) {
    //            Logger.getLogger(CRBM.class.getName()).log(Level.SEVERE, null, ex);
    //        }
    //    }

    //    private void print(float[][] W_k) {
    //
    //        for (int k = 0; k < K; k++) {
    //            System.out.println("K:" + k);
    //            for (int i = 0; i < FILTEREDGELENGTH; i++) {
    //                for (int j = 0; j < FILTEREDGELENGTH; j++) {
    //                    System.out.print(W_k[k][i * FILTEREDGELENGTH + j] + " ");
    //                }
    //                System.out.println();
    //            }
    //            System.out.println();
    //        }
    //
    //    }

    private void exportAsImage(float[][] data, String name) {

        String exportPath = "export";

        try {
            FileUtils.deleteDirectory(new File(exportPath));
        } catch (IOException e) {
            e.printStackTrace();
        }

        for (int i = 0; i < data.length; i++) {
            exportAsImage(data[i], name, i);
        }
    }

    private void exportAsImage(float[] data, String name, int count) {
        String exportPath = "export";

        new File(exportPath + "/" + name + "/").mkdirs();

        BufferedImage image = DataConverter.pixelDataToImage(data, 0.0f, false);
        File outputfile = new File(exportPath + "/" + name + "/" + count + ".png");
        try {
            ImageIO.write(image, "png", outputfile);
        } catch (IOException ex) {
            ex.printStackTrace();
        }
    }

}