org.nd4j.linalg.Nd4jTestsComparisonFortran.java Source code

Java tutorial

Introduction

Here is the source code for org.nd4j.linalg.Nd4jTestsComparisonFortran.java

Source

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg;

import static org.junit.Assert.*;

import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.Pair;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.CheckUtil;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Random;

import static org.junit.Assert.assertArrayEquals;

/**
 * Tests comparing Nd4j ops to other libraries, run with Fortran ('f') ordering
 * (see {@link #ordering()}).
 *
 * <p>Every test runs with the data type forced to {@code DOUBLE} and the Nd4j
 * random generator seeded with {@link #SEED}; the data type that was active when
 * the test instance was constructed is restored after each test.
 */
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
    private static final Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);

    /** Seed used for the Nd4j random generator and for most generated test matrices. */
    public static final int SEED = 123;

    /** Data type that was active when this instance was constructed; restored in {@link #after()}. */
    DataBuffer.Type initialType;

    /**
     * Creates the test for the given backend and remembers the current global data type.
     *
     * @param backend backend supplied by the parameterized runner
     */
    public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
        super(backend);
        this.initialType = Nd4j.dataType();
    }

    /**
     * Runs the base-class setup, then switches the data type to {@code DOUBLE} and
     * seeds the Nd4j random generator with {@link #SEED}.
     */
    @Before
    public void before() throws Exception {
        super.before();
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        Nd4j.getRandom().setSeed(SEED);
    }

    /**
     * Runs the base-class teardown and restores the data type captured at construction.
     * The restore sits in a {@code finally} so that a failing teardown cannot leave
     * {@code DOUBLE} set for whatever runs next.
     */
    @After
    public void after() throws Exception {
        try {
            super.after();
        } finally {
            DataTypeUtil.setDTypeForContext(initialType);
        }
    }

    /** @return {@code 'f'}: this suite runs with Fortran ordering. */
    @Override
    public char ordering() {
        return 'f';
    }

    /**
     * Checks {@code mmul} for every pairing of the generated 3x5 and 5x4 test matrices
     * via {@link CheckUtil#checkMmul}.
     */
    @Test
    public void testMmulWithOpsCommonsMath() {
        List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
        List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED);

        for (int i = 0; i < first.size(); i++) {
            Pair<INDArray, String> p1 = first.get(i);
            for (int j = 0; j < second.size(); j++) {
                Pair<INDArray, String> p2 = second.get(j);
                String errorMsg = getTestWithOpsErrorMsg(i, j, "mmul", p1, p2);
                assertTrue(errorMsg, CheckUtil.checkMmul(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6));
            }
        }
    }

    /**
     * Checks {@code gemm} via {@link CheckUtil#checkGemm} for every pairing of the generated
     * A/B test matrices, for all four transpose-flag combinations, and for every
     * combination of the alpha and beta values below.
     *
     * <p>Each check gets its own fresh 'f'-ordered copy of the same pseudo-random 3x4 C
     * matrix, so no check observes a C matrix touched by an earlier one.
     */
    @Test
    public void testGemmWithOpsCommonsMath() {
        List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
        List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED);
        List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED);
        List<Pair<INDArray, String>> secondT = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, SEED);
        double[] alpha = { 1.0, -0.5, 2.5 };
        double[] beta = { 0.0, -0.25, 1.5 };

        // Reference C matrix, filled element by element from a fixed-seed generator.
        INDArray cOrig = Nd4j.create(new int[] { 3, 4 });
        Random r = new Random(12345);
        for (int i = 0; i < cOrig.size(0); i++) {
            for (int j = 0; j < cOrig.size(1); j++) {
                cOrig.putScalar(new int[] { i, j }, r.nextDouble());
            }
        }

        for (int i = 0; i < first.size(); i++) {
            // The i-th entry of the transposed-shape list is paired with the i-th plain entry.
            Pair<INDArray, String> p1 = first.get(i);
            Pair<INDArray, String> p1T = firstT.get(i);

            for (int j = 0; j < second.size(); j++) {
                Pair<INDArray, String> p2 = second.get(j);
                Pair<INDArray, String> p2T = secondT.get(j);

                for (int k = 0; k < alpha.length; k++) {
                    for (int m = 0; m < beta.length; m++) {
                        log.info(String.format("Running iteration %d %d %d %d", i, j, k, m));

                        INDArray cff = fortranCopy(cOrig);
                        INDArray cft = fortranCopy(cOrig);
                        INDArray ctf = fortranCopy(cOrig);
                        INDArray ctt = fortranCopy(cOrig);

                        double a = alpha[k];
                        // beta is indexed by its own loop variable so that every
                        // (alpha, beta) combination is exercised, not only the pairs
                        // that share an index.
                        double b = beta[m];

                        String errorMsgff = getGemmErrorMsg(i, j, false, false, a, b, p1, p2);
                        String errorMsgft = getGemmErrorMsg(i, j, false, true, a, b, p1, p2T);
                        String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2);
                        String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T);

                        assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false,
                                a, b, 1e-4, 1e-6));
                        assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true,
                                a, b, 1e-4, 1e-6));
                        assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false,
                                a, b, 1e-4, 1e-6));
                        assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true,
                                a, b, 1e-4, 1e-6));
                    }
                }
            }
        }
    }

    /**
     * Compares matrix-by-column-vector {@code mmul} against Apache Commons Math.
     *
     * <p>For each (rows, cols) shape, every generated matrix is multiplied by every
     * generated cols x 1 vector; both operands are copied element-wise into
     * {@link BlockRealMatrix} instances and the two products are compared entry by entry
     * with an absolute tolerance of 1e-5.
     */
    @Test
    public void testGemvApacheCommons() {

        int[] rowsArr = new int[] { 4, 4, 4, 8, 8, 8 };
        int[] colsArr = new int[] { 2, 1, 10, 2, 1, 10 };

        for (int x = 0; x < rowsArr.length; x++) {
            int rows = rowsArr[x];
            int cols = colsArr[x];

            List<Pair<INDArray, String>> matrices = NDArrayCreationUtil.getAllTestMatricesWithShape(rows, cols,
                    12345);
            List<Pair<INDArray, String>> vectors = NDArrayCreationUtil.getAllTestMatricesWithShape(cols, 1, 12345);

            for (int i = 0; i < matrices.size(); i++) {
                Pair<INDArray, String> p1 = matrices.get(i);
                INDArray m = p1.getFirst();
                RealMatrix rm = toRealMatrix(m);

                for (int j = 0; j < vectors.size(); j++) {
                    Pair<INDArray, String> p2 = vectors.get(j);
                    String errorMsg = getTestWithOpsErrorMsg(i, j, "mmul", p1, p2);

                    INDArray v = p2.getFirst();

                    // Column vector: only column 0 is copied.
                    RealMatrix rv = new BlockRealMatrix(cols, 1);
                    for (int r = 0; r < v.rows(); r++) {
                        rv.setEntry(r, 0, v.getDouble(r, 0));
                    }

                    INDArray gemv = m.mmul(v);
                    RealMatrix gemv2 = rm.multiply(rv);

                    assertArrayEquals(new int[] { rows, 1 }, gemv.shape());
                    assertArrayEquals(new int[] { rows, 1 },
                            new int[] { gemv2.getRowDimension(), gemv2.getColumnDimension() });

                    //Check entries:
                    for (int r = 0; r < rows; r++) {
                        double exp = gemv2.getEntry(r, 0);
                        double act = gemv.getDouble(r, 0);
                        assertEquals(errorMsg, exp, act, 1e-5);
                    }
                }
            }
        }
    }

    /**
     * Checks element-wise add and subtract for every pairing of the generated 3x5 test
     * matrices via {@link CheckUtil#checkAdd} and {@link CheckUtil#checkSubtract}.
     */
    @Test
    public void testAddSubtractWithOpsCommonsMath() {
        List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
        List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
        for (int i = 0; i < first.size(); i++) {
            Pair<INDArray, String> p1 = first.get(i);
            for (int j = 0; j < second.size(); j++) {
                Pair<INDArray, String> p2 = second.get(j);
                String errorMsg1 = getTestWithOpsErrorMsg(i, j, "add", p1, p2);
                String errorMsg2 = getTestWithOpsErrorMsg(i, j, "sub", p1, p2);
                // The check helpers return true when the results agree.
                boolean addOk = CheckUtil.checkAdd(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6);
                assertTrue(errorMsg1, addOk);
                boolean subOk = CheckUtil.checkSubtract(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6);
                assertTrue(errorMsg2, subOk);
            }
        }
    }

    /**
     * Checks element-wise multiply and divide for every pairing of the generated 3x5 test
     * matrices via {@link CheckUtil#checkMulManually} and {@link CheckUtil#checkDivManually}.
     */
    @Test
    public void testMulDivOnCheckUtilMatrices() {
        List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
        List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
        for (int i = 0; i < first.size(); i++) {
            Pair<INDArray, String> p1 = first.get(i);
            for (int j = 0; j < second.size(); j++) {
                Pair<INDArray, String> p2 = second.get(j);
                String errorMsg1 = getTestWithOpsErrorMsg(i, j, "mul", p1, p2);
                String errorMsg2 = getTestWithOpsErrorMsg(i, j, "div", p1, p2);
                assertTrue(errorMsg1, CheckUtil.checkMulManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6));
                assertTrue(errorMsg2, CheckUtil.checkDivManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6));
            }
        }
    }

    /** Returns a new 'f'-ordered array with the same shape and values as {@code source}. */
    private static INDArray fortranCopy(INDArray source) {
        INDArray copy = Nd4j.create(source.shape(), 'f');
        copy.assign(source);
        return copy;
    }

    /** Copies a 2d array element by element into a Commons Math {@link BlockRealMatrix}. */
    private static RealMatrix toRealMatrix(INDArray array) {
        RealMatrix result = new BlockRealMatrix(array.rows(), array.columns());
        for (int r = 0; r < array.rows(); r++) {
            for (int c = 0; c < array.columns(); c++) {
                result.setEntry(r, c, array.getDouble(r, c));
            }
        }
        return result;
    }

    /**
     * Builds the failure message for a binary-op check.
     *
     * @param i      index of the first operand in its list
     * @param j      index of the second operand in its list
     * @param op     operation name to show in the message
     * @param first  first operand; only its description string is used
     * @param second second operand; only its description string is used
     * @return a message of the form {@code "i,j - <first>.<op>(<second>)"}
     */
    private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
            Pair<INDArray, String> second) {
        return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
    }

    /**
     * Builds the failure message for a gemm check, including the transpose flags, the
     * alpha/beta scalars and the description strings of both operands.
     */
    private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
            double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
        return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= "
                + beta + "). A=" + first.getSecond() + ", B=" + second.getSecond();
    }
}