jcuda.jcublas.kernel.TestMatrixOperations.java Source code

Java tutorial

Introduction

Here is the source code for jcuda.jcublas.kernel.TestMatrixOperations.java

Source

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package jcuda.jcublas.kernel;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.lang3.time.StopWatch;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.*;
import org.nd4j.linalg.api.ops.impl.transforms.Log;
import org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.executors.ExecutorServiceProvider;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;

import static org.junit.Assert.*;

public class TestMatrixOperations {

    /**
     * Verifies the BLAS dot product of a four-element linspace vector with itself.
     * The test expects 30 (within a tolerance of 0.1).
     */
    @Test
    public void testDot() {
        INDArray vector = Nd4j.linspace(1, 4, 4);
        double selfDot = Nd4j.getBlasWrapper().dot(vector, vector);
        assertEquals(30, selfDot, 1e-1);
    }

    /**
     * Checks the sum along dimension 1 and the total sum of a 2x2 matrix reshaped
     * from linspace(1, 4, 4). Two tensor-along-dimension lookups and their
     * element-wise strides are evaluated but not asserted on.
     */
    @Test
    public void testSums() {
        INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);

        // Evaluated only so that TAD creation and the stride lookup run; the results are unused.
        INDArray tadAlongZero = matrix.tensorAlongDimension(1, 0);
        INDArray tadAlongOne = matrix.tensorAlongDimension(1, 1);
        int strideAlongZero = tadAlongZero.elementWiseStride();
        int strideAlongOne = tadAlongOne.elementWiseStride();

        // The matching check of sum(0) against {4, 6} is disabled in this test.
        INDArray expectedSumAlongOne = Nd4j.create(new float[] { 3, 7 });
        assertEquals(expectedSumAlongOne, matrix.sum(1));

        double total = matrix.sumNumber().doubleValue();
        assertEquals(10, total, 1e-1);
    }

    /**
     * Checks the mean along each dimension of a 2x2 matrix reshaped from
     * linspace(1, 4, 4), plus the overall mean of both the flat vector and the matrix.
     */
    @Test
    public void testMeans() {
        INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);

        INDArray meanAlongOne = matrix.mean(1);
        INDArray expectedAlongOne = Nd4j.create(new double[] { 1.5, 3.5 });
        assertEquals(expectedAlongOne, meanAlongOne);

        INDArray expectedAlongZero = Nd4j.create(new double[] { 2, 3 });
        assertEquals(expectedAlongZero, matrix.mean(0));

        // The overall mean is checked on a fresh flat vector and on the reshaped matrix.
        double flatMean = Nd4j.linspace(1, 4, 4).meanNumber().doubleValue();
        assertEquals(2.5, flatMean, 1e-1);

        double matrixMean = matrix.meanNumber().doubleValue();
        assertEquals(2.5, matrixMean, 1e-1);
    }

    /**
     * Prints the offsets of the first five tensors along dimension 1 of a
     * 2x10x10x10x10 array of ones. There are no assertions; the test only
     * requires that the lookups complete without throwing.
     */
    @Test
    public void testTad() {
        INDArray ones = Nd4j.ones(2, 10, 10, 10, 10);
        int tadIndex = 0;
        while (tadIndex < 5) {
            System.out.println(ones.tensorAlongDimension(tadIndex, 1).offset());
            tadIndex++;
        }
    }

    /**
     * Smoke test: sums 3d, 4d and 5d arrays of ones (leading dimension 2) along
     * every dimension. Nothing is asserted; the test passes when no sum throws.
     */
    @Test
    public void testSumWithRow2() {
        // 3d: sum along each of the three dimensions.
        INDArray array3d = Nd4j.ones(2, 10, 10);
        for (int dim = 0; dim < 3; dim++) {
            array3d.sum(dim);
        }

        INDArray array4d = Nd4j.ones(2, 10, 10, 10);

        // Stride and TAD-count lookups are evaluated but their values are not used.
        int firstTadStride = array4d.tensorAlongDimension(0, 0).elementWiseStride();
        int tadCount = array4d.tensorssAlongDimension(0);

        // NOTE(review): printing starts at TAD index 10 rather than 0 - confirm this is intended.
        int tadIndex = 10;
        while (tadIndex < array4d.tensorssAlongDimension(0)) {
            System.out.println(array4d.tensorAlongDimension(tadIndex, 0).offset());
            tadIndex++;
        }

        // 4d: sum along each of the four dimensions.
        for (int dim = 0; dim < 4; dim++) {
            array4d.sum(dim);
        }

        // 5d: sum along each of the five dimensions.
        INDArray array5d = Nd4j.ones(2, 10, 10, 10, 10);
        for (int dim = 0; dim < 5; dim++) {
            array5d.sum(dim);
        }
    }

    /**
     * Runs the Eps comparison op with the same five-element vector of ones as
     * both inputs and as the result array, and expects the result to sum to 5.
     */
    @Test
    public void testEps() {
        INDArray ones = Nd4j.ones(5);
        Eps epsOp = new Eps(ones, ones, ones, ones.length());
        INDArray result = Nd4j.getExecutioner().exec(epsOp).z();
        assertEquals(5, result.sumNumber().doubleValue(), 1e-1);
    }

    /** Expects the mean of linspace(1, 5, 5) to be 3 (within a tolerance of 0.1). */
    @Test
    public void testMean() {
        INDArray oneToFive = Nd4j.linspace(1, 5, 5);
        double mean = oneToFive.meanNumber().doubleValue();
        assertEquals(3, mean, 1e-1);
    }

    /** Expects the BLAS asum of linspace(1, 4, 4) to be 10 (within a tolerance of 0.1). */
    @Test
    public void testBlasSum() {
        INDArray oneToFour = Nd4j.linspace(1, 4, 4);
        double absoluteSum = Nd4j.getBlasWrapper().asum(oneToFour);
        assertEquals(10, absoluteSum, 1e-1);
    }

    /**
     * Checks the sums of a 2x2 matrix holding {1, 2, 3, 4}: along dimension 1
     * the test expects {3, 7}, along dimension 0 it expects {4, 6}.
     */
    @Test
    public void testSum2() {
        INDArray matrix = Nd4j.create(new float[] { 1, 2, 3, 4 }, new int[] { 2, 2 });

        INDArray sumAlongOne = matrix.sum(1);
        INDArray expectedAlongOne = Nd4j.create(new float[] { 3, 7 });
        assertEquals(expectedAlongOne, sumAlongOne);

        INDArray expectedAlongZero = Nd4j.create(new double[] { 4, 6 });
        assertEquals(expectedAlongZero, matrix.sum(0));
    }

    /**
     * Applies SoftMax to linspace(1, 6, 6) and expects the entries of the
     * result to sum to 1 (within a tolerance of 0.1).
     */
    @Test
    public void testRowSoftmax() {
        INDArray input = Nd4j.linspace(1, 6, 6);
        SoftMax softMaxOp = new SoftMax(input);
        Nd4j.getExecutioner().exec(softMaxOp);

        double total = softMaxOp.z().sumNumber().doubleValue();
        assertEquals(1.0, total, 1e-1);
    }

    /**
     * Checks that the LogSoftMax op agrees with log(SoftMax(x)) on random
     * matrices of several shapes, comparing every (row, column) element.
     */
    @Test
    public void testRowLogSoftMax() {
        // For moderate input values the LogSoftMax op should give the same result as
        // log(softmax); the original note suggests LogSoftMax is the numerically more
        // stable of the two.
        int[][] shapes = new int[][] { { 5, 3 }, { 5, 100 }, { 1, 5 }, { 1, 100 } };

        double eps = 1e-3;

        for (int[] shape : shapes) {
            INDArray orig = Nd4j.rand(shape);

            // Independent copies so the two computations cannot affect each other.
            INDArray orig1 = orig.dup();
            INDArray orig2 = orig.dup();

            //First: standard log(softmax)
            Nd4j.getExecutioner().exec(new SoftMax(orig1), 1);
            Nd4j.getExecutioner().exec(new Log(orig1));

            //Second: LogSoftMax op
            Nd4j.getExecutioner().exec(new LogSoftMax(orig2), 1);

            for (int i = 0; i < shape[0]; i++) {
                for (int j = 0; j < shape[1]; j++) {
                    // Index by (row, column); indexing by i alone would re-check only the
                    // first shape[0] linear elements and skip most of the matrix.
                    double o1 = orig1.getDouble(i, j);
                    double o2 = orig2.getDouble(i, j);
                    assertEquals("Mismatch at [" + i + ", " + j + "] for shape " + Arrays.toString(shape),
                            o1, o2, eps);
                }
            }
        }
    }

    /**
     * Checks sums of a 2x2x2 array built from linspace(1, 8, 8) along the last
     * dimension (given as -1), dimension 0 and dimension 1. The tensors along
     * dimension 1 are printed before the final assertion.
     */
    @Test
    public void testSum() {
        INDArray cube = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] { 2, 2, 2 });

        // Sum along the last dimension.
        INDArray expectedLast = Nd4j.create(new float[] { 3, 7, 11, 15 }, new int[] { 2, 2 });
        INDArray sumLast = cube.sum(-1);
        assertEquals(expectedLast, sumLast);

        // Sum along dimension 0.
        INDArray sumAlongZero = cube.sum(0);
        INDArray expectedAlongZero = Nd4j.create(new double[] { 6, 8, 10, 12 }, new int[] { 2, 2 });
        assertEquals(expectedAlongZero, sumAlongZero);

        // Sum along dimension 1.
        INDArray sumAlongOne = cube.sum(1);
        int tadIndex = 0;
        while (tadIndex < cube.tensorssAlongDimension(1)) {
            System.out.println(cube.tensorAlongDimension(tadIndex, 1));
            tadIndex++;
        }
        INDArray expectedAlongOne = Nd4j.create(new double[] { 4, 6, 12, 14 }, new int[] { 2, 2 });
        assertEquals(expectedAlongOne, sumAlongOne);
    }

    /**
     * Prints the tensors along dimension 0 of a 4x3x2 linspace array followed by
     * its max along dimension 0. Despite the name, no arg-max is computed and
     * nothing is asserted; the test passes when no call throws.
     */
    @Test
    public void testArgMax() {
        INDArray source = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2);
        System.out.println(source.tensorssAlongDimension(0));

        // Evaluated but unused.
        int firstTadStride = source.tensorAlongDimension(0, 0).elementWiseStride();

        int tadIndex = 0;
        while (tadIndex < source.tensorssAlongDimension(0)) {
            System.out.println(source.tensorAlongDimension(tadIndex, 0));
            tadIndex++;
        }

        // Evaluated but unused.
        INDArray firstTad = source.tensorAlongDimension(0, 0);

        System.out.println(source.max(0));
        System.out.println();
    }

    /**
     * Smoke test: applies the sigmoid transform to a 5x5 array of ones.
     * Nothing is asserted; the test passes when the call does not throw.
     */
    @Test
    public void testElementWiseOp() {
        INDArray ones = Nd4j.ones(5, 5);
        Transforms.sigmoid(ones);
    }

    /**
     * For a 4x5x7 linspace array, takes the first tensor along each ordered pair
     * of dimensions and compares its element sum with a precomputed value.
     */
    @Test
    public void testTensorAlongDimension() {
        int[] shape = new int[] { 4, 5, 7 };
        int elementCount = ArrayUtil.prod(shape);
        INDArray source = Nd4j.linspace(1, elementCount, elementCount).reshape(shape);

        // Entry k of each array belongs together: (firstDims[k], secondDims[k]) -> expectedSums[k].
        int[] firstDims = { 0, 1, 2, 0, 1, 2 };
        int[] secondDims = { 1, 0, 0, 2, 2, 1 };
        double[] expectedSums = { 1350., 1350., 1582, 1582, 630, 630 };

        for (int k = 0; k < firstDims.length; k++) {
            INDArray firstTad = source.tensorAlongDimension(0, firstDims[k], secondDims[k]);
            double actualSum = firstTad.sumNumber().doubleValue();
            assertEquals("I " + k + " failed ", expectedSums[k], actualSum, 1e-1);
        }
    }

    /**
     * Checks the Euclidean (L2) norm with the global data type switched to DOUBLE:
     * first for the vector {1, 2, 3, 4}, then for row 1 of the same values laid
     * out as a 2x2 matrix.
     *
     * <p>The previous value of {@code Nd4j.dtype} is restored afterwards so the
     * switch to DOUBLE does not leak into tests that run later.
     */
    @Test
    public void testNorm2Double() {
        DataBuffer.Type previousType = Nd4j.dtype;
        Nd4j.dtype = DataBuffer.Type.DOUBLE;
        try {
            INDArray n = Nd4j.create(new double[] { 1, 2, 3, 4 });
            double assertion = 5.47722557505;
            double norm3 = n.norm2Number().doubleValue();
            assertEquals(assertion, norm3, 1e-1);

            INDArray row = Nd4j.create(new double[] { 1, 2, 3, 4 }, new int[] { 2, 2 });
            INDArray row1 = row.getRow(1);
            double norm2 = row1.norm2Number().doubleValue();
            double assertion2 = 5.0;
            assertEquals(assertion2, norm2, 1e-1);
        } finally {
            // Undo the global data type change even when an assertion fails.
            Nd4j.dtype = previousType;
        }
    }

    /**
     * Checks the Euclidean (L2) norm on float data: first for the vector
     * {1, 2, 3, 4}, then for row 1 of the same values laid out as a 2x2 matrix.
     */
    @Test
    public void testNorm2() {
        // Whole vector.
        INDArray vector = Nd4j.create(new float[] { 1, 2, 3, 4 });
        float expectedVectorNorm = 5.47722557505f;
        float vectorNorm = vector.norm2Number().floatValue();
        assertEquals(expectedVectorNorm, vectorNorm, 1e-1);

        // Second row of the 2x2 matrix.
        INDArray matrix = Nd4j.create(new float[] { 1, 2, 3, 4 }, new int[] { 2, 2 });
        INDArray secondRow = matrix.getRow(1);
        float expectedRowNorm = 5.0f;
        float rowNorm = secondRow.norm2Number().floatValue();
        assertEquals(expectedRowNorm, rowNorm, 1e-1);
    }

    /**
     * Builds two 2x2 matrices element by element, prints their tensors along
     * dimension 1, and checks that the "euclidean" accumulation executed along
     * dimension 1 yields a 2x1 array filled with 2.
     */
    @Test
    public void testLength() {
        // First operand: every entry 0 except (1, 1) = 2.
        INDArray first = Nd4j.create(2, 2);
        first.put(0, 0, 0);
        first.put(1, 0, 0);
        first.put(0, 1, 0);
        first.put(1, 1, 2);

        // Second operand: every entry 2 except (0, 1) = 0.
        INDArray second = Nd4j.create(2, 2);
        second.put(0, 0, 2);
        second.put(1, 0, 2);
        second.put(0, 1, 0);
        second.put(1, 1, 2);

        int tadIndex = 0;
        while (tadIndex < first.tensorssAlongDimension(1)) {
            System.out.println("X tad " + tadIndex + " is " + first.tensorAlongDimension(tadIndex, 1));
            System.out.println("Y tad " + tadIndex + " is " + second.tensorAlongDimension(tadIndex, 1));
            tadIndex++;
        }

        INDArray expectedDistances = Nd4j.repeat(Nd4j.scalar(2), 2).reshape(2, 1);

        Accumulation euclidean = Nd4j.getOpFactory().createAccum("euclidean", first, second);
        INDArray distances = Nd4j.getExecutioner().exec(euclidean, 1);
        assertEquals(expectedDistances, distances);
    }

    /**
     * Divides a 2x2 linspace matrix in place by the row vector linspace(1, 2, 2)
     * and expects {{1, 1}, {3, 2}}.
     */
    @Test
    public void testDivRowVector() {
        INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        INDArray divisor = Nd4j.linspace(1, 2, 2);
        matrix.diviRowVector(divisor);

        INDArray expected = Nd4j.create(new double[][] { { 1, 1 }, { 3, 2 } });
        assertEquals(expected, matrix);
    }

    /**
     * Multiplies a 2x2 linspace matrix in place by the row vector linspace(1, 2, 2)
     * and expects {{1, 4}, {3, 8}}.
     */
    @Test
    public void testMulRowVector() {
        INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        INDArray factor = Nd4j.linspace(1, 2, 2);
        matrix.muliRowVector(factor);

        INDArray expected = Nd4j.create(new double[][] { { 1, 4 }, { 3, 8 } });
        assertEquals(expected, matrix);
    }

    /**
     * Multiplies a 1x5 row vector of ones by a 5x1 column vector of ones and
     * expects a 1x1 result holding 5. The check is done twice: once with vectors
     * from {@code Nd4j.ones}, and once with a 'c'-ordered column vector and an
     * 'f'-ordered row vector filled by hand.
     */
    @Test
    public void testMMulColVectorRowVectorMixedOrder() {
        INDArray colVec = Nd4j.ones(5, 1);
        INDArray rowVec = Nd4j.ones(1, 5);
        INDArray out = rowVec.mmul(colVec);
        // JUnit takes (expected, actual); keeping that order makes failure messages read correctly.
        assertArrayEquals(new int[] { 1, 1 }, out.shape());
        assertTrue("row x column product should be a 1x1 array holding 5",
                out.equals(Nd4j.ones(1, 1).muli(5)));

        INDArray colVectorC = Nd4j.create(new int[] { 5, 1 }, 'c');
        INDArray rowVectorF = Nd4j.create(new int[] { 1, 5 }, 'f');
        for (int i = 0; i < colVectorC.length(); i++) {
            colVectorC.putScalar(i, 1.0);
        }
        for (int i = 0; i < rowVectorF.length(); i++) {
            rowVectorF.putScalar(i, 1.0);
        }
        // The hand-filled vectors must compare equal to the Nd4j.ones vectors.
        assertTrue("hand-filled 'c' column vector should equal Nd4j.ones(5, 1)", colVec.equals(colVectorC));
        assertTrue("hand-filled 'f' row vector should equal Nd4j.ones(1, 5)", rowVec.equals(rowVectorF));

        INDArray outCF = rowVectorF.mmul(colVectorC);
        assertArrayEquals(new int[] { 1, 1 }, outCF.shape());
        assertTrue("mixed-order row x column product should be a 1x1 array holding 5",
                outCF.equals(Nd4j.ones(1, 1).muli(5)));
    }

    /**
     * Broadcast-adds a length-5 linspace vector along dimension 0 of a
     * 5x7x9x11x13 linspace array and verifies the result twice: tensor by tensor
     * against {@code tad.add(vector)}, and element by element against
     * {@code orig + vector} using a relative-error bound of 1e-6.
     */
    @Test
    public void testNdVectorOpLinSpace() {
        int[] shape = { 5, 7, 9, 11, 13 };
        int length = ArrayUtil.prod(shape);
        INDArray orig = Nd4j.linspace(1, length, length).reshape(shape);
        int dimension = 0;

        // Debug output: TAD count, then the element offsets of the first five TADs.
        System.out.println(orig.tensorssAlongDimension(dimension));
        for (int i = 0; i < 5; i++) {
            StringBuilder sb = new StringBuilder();
            INDArray tad = orig.tensorAlongDimension(i, dimension);
            for (int j = 0; j < tad.length(); j++) {
                sb.append(tad.get(NDArrayIndex.point(j)).offset());
                sb.append(",");
            }
            System.out.println(sb);
        }
        System.out.println();

        INDArray vector = Nd4j.linspace(1, shape[dimension], shape[dimension]);
        // The result goes into a copy, so orig stays available for the comparisons below.
        BroadcastOp op = new BroadcastAddOp(orig, vector, orig.dup(), dimension);
        Nd4j.getExecutioner().exec(op);
        for (int i = 0; i < 5; i++) {
            System.out.println(op.z().tensorAlongDimension(i, dimension));
        }

        //Compare expected vs. actual, one tensor at a time:
        for (int i = 0; i < orig.tensorssAlongDimension(dimension); i++) {
            INDArray tad = orig.tensorAlongDimension(i, dimension);
            INDArray zDim = op.z().tensorAlongDimension(i, dimension);
            INDArray assertion = tad.add(vector);
            assertEquals("Failed on tad with original tad " + tad + " at " + i, assertion, zDim);
        }

        //Compare expected vs. actual, one element at a time:
        NdIndexIterator iter = new NdIndexIterator(orig.shape());
        while (iter.hasNext()) {
            int[] next = iter.next();
            double origValue = orig.getDouble(next);
            double vectorValue = vector.getDouble(next[dimension]); //current index in vector
            double exp = origValue + vectorValue;

            double actual = op.z().getDouble(next);
            double relError = Math.abs(exp - actual) / (Math.abs(exp) + Math.abs(actual));
            assertTrue("Failed on shape " + Arrays.toString(shape) + " at index " + Arrays.toString(next),
                    relError < 1e-6);
        }
    }

    /**
     * Broadcast-divides a 5x7x9x11x13 linspace array by a length-5 linspace
     * vector along dimension 0 and verifies the result twice: tensor by tensor
     * against {@code tad.div(vector)}, and element by element against
     * {@code orig / vector} using a relative-error bound of 1e-6.
     */
    @Test
    public void testNdVectorOpLinSpaceDiv() {
        int[] shape = { 5, 7, 9, 11, 13 };
        int length = ArrayUtil.prod(shape);
        INDArray orig = Nd4j.linspace(1, length, length).reshape(shape);
        int dimension = 0;

        // Debug output: TAD count, then the element offsets of the first five TADs.
        System.out.println(orig.tensorssAlongDimension(dimension));
        for (int i = 0; i < 5; i++) {
            StringBuilder sb = new StringBuilder();
            INDArray tad = orig.tensorAlongDimension(i, dimension);
            for (int j = 0; j < tad.length(); j++) {
                sb.append(tad.get(NDArrayIndex.point(j)).offset());
                sb.append(",");
            }
            System.out.println(sb);
        }
        System.out.println();

        INDArray vector = Nd4j.linspace(1, shape[dimension], shape[dimension]);
        // The result goes into a copy, so orig stays available for the comparisons below.
        BroadcastOp op = new BroadcastDivOp(orig, vector, orig.dup(), dimension);
        Nd4j.getExecutioner().exec(op);
        for (int i = 0; i < 5; i++) {
            System.out.println(op.z().tensorAlongDimension(i, dimension));
        }

        //Compare expected vs. actual, one tensor at a time:
        for (int i = 0; i < orig.tensorssAlongDimension(dimension); i++) {
            INDArray tad = orig.tensorAlongDimension(i, dimension);
            INDArray zDim = op.z().tensorAlongDimension(i, dimension);
            INDArray assertion = tad.div(vector);
            assertEquals("Failed on tad with original tad " + tad + " at " + i, assertion, zDim);
        }

        //Compare expected vs. actual, one element at a time:
        NdIndexIterator iter = new NdIndexIterator(orig.shape());
        while (iter.hasNext()) {
            int[] next = iter.next();
            double origValue = orig.getDouble(next);
            double vectorValue = vector.getDouble(next[dimension]); //current index in vector
            double exp = origValue / vectorValue;

            double actual = op.z().getDouble(next);
            double relError = Math.abs(exp - actual) / (Math.abs(exp) + Math.abs(actual));
            assertTrue("Failed on shape " + Arrays.toString(shape) + " at index " + Arrays.toString(next),
                    relError < 1e-6);
        }
    }

    /**
     * Broadcast-adds a length-7 linspace vector along dimension 1 of a 5x7
     * linspace matrix and verifies the result twice: tensor by tensor against
     * {@code tad.add(vector)}, and element by element against
     * {@code orig + vector} using a relative-error bound of 1e-6.
     */
    @Test
    public void testFiveBySevenDimOne() {
        INDArray orig = Nd4j.linspace(1, 35, 35).reshape(5, 7);
        INDArray vector = Nd4j.linspace(1, 7, 7);
        int dimension = 1;

        // Debug output: TAD count followed by the first five TADs.
        System.out.println(orig.tensorssAlongDimension(dimension));
        for (int i = 0; i < 5; i++) {
            System.out.println(orig.tensorAlongDimension(i, dimension));
        }
        System.out.println();

        // The result goes into a copy, so orig stays available for the comparisons below.
        BroadcastOp op = new BroadcastAddOp(orig, vector, orig.dup(), dimension);
        Nd4j.getExecutioner().exec(op);

        //Compare expected vs. actual, one tensor at a time:
        for (int i = 0; i < orig.tensorssAlongDimension(dimension); i++) {
            INDArray tad = orig.tensorAlongDimension(i, dimension);
            INDArray zDim = op.z().tensorAlongDimension(i, dimension);
            INDArray assertion = tad.add(vector);
            assertEquals("Failed on tad with original tad " + tad + " at " + i, assertion, zDim);
        }

        //Compare expected vs. actual, one element at a time:
        NdIndexIterator iter = new NdIndexIterator(orig.shape());
        int[] shape = { 5, 7 };
        while (iter.hasNext()) {
            int[] next = iter.next();
            double origValue = orig.getDouble(next);
            double vectorValue = vector.getDouble(next[dimension]); //current index in vector
            double exp = origValue + vectorValue;

            double actual = op.z().getDouble(next);
            double relError = Math.abs(exp - actual) / (Math.abs(exp) + Math.abs(actual));
            assertTrue("Failed on shape " + Arrays.toString(shape) + " at index " + Arrays.toString(next),
                    relError < 1e-6);
        }
    }

    /**
     * Runs BroadcastRDivOp with a length-5 linspace vector along dimension 0 of
     * a 5x7 linspace matrix and compares every tensor along that dimension with
     * {@code tad.rdiv(vector)}.
     */
    @Test
    public void testFiveBySevenRDiv() {
        final int dimension = 0;
        INDArray matrix = Nd4j.linspace(1, 35, 35).reshape(5, 7);
        INDArray vector = Nd4j.linspace(1, 5, 5);

        // Debug output: TAD count followed by the first five TADs.
        System.out.println(matrix.tensorssAlongDimension(dimension));
        int printed = 0;
        while (printed < 5) {
            System.out.println(matrix.tensorAlongDimension(printed, dimension));
            printed++;
        }
        System.out.println();

        // The result goes into a copy, so matrix keeps its original values.
        BroadcastOp op = new BroadcastRDivOp(matrix, vector, matrix.dup(), dimension);
        Nd4j.getExecutioner().exec(op);

        // Each result tensor must match the reverse division computed directly on the input tensor.
        int tadIndex = 0;
        while (tadIndex < matrix.tensorssAlongDimension(dimension)) {
            INDArray inputTad = matrix.tensorAlongDimension(tadIndex, dimension);
            INDArray resultTad = op.z().tensorAlongDimension(tadIndex, dimension);
            INDArray expectedTad = inputTad.rdiv(vector);
            assertEquals("Failed on tad with original tad " + inputTad + " at " + tadIndex, expectedTad,
                    resultTad);
            tadIndex++;
        }
    }

    /**
     * Runs BroadcastDivOp with a length-5 linspace vector along dimension 0 of
     * a 5x7 linspace matrix and compares every tensor along that dimension with
     * {@code tad.div(vector)}.
     */
    @Test
    public void testFiveBySevenDiv() {
        final int dimension = 0;
        INDArray matrix = Nd4j.linspace(1, 35, 35).reshape(5, 7);
        INDArray vector = Nd4j.linspace(1, 5, 5);

        // Debug output: TAD count followed by the first five TADs.
        System.out.println(matrix.tensorssAlongDimension(dimension));
        int printed = 0;
        while (printed < 5) {
            System.out.println(matrix.tensorAlongDimension(printed, dimension));
            printed++;
        }
        System.out.println();

        // The result goes into a copy, so matrix keeps its original values.
        BroadcastOp op = new BroadcastDivOp(matrix, vector, matrix.dup(), dimension);
        Nd4j.getExecutioner().exec(op);

        // Each result tensor must match the division computed directly on the input tensor.
        int tadIndex = 0;
        while (tadIndex < matrix.tensorssAlongDimension(dimension)) {
            INDArray inputTad = matrix.tensorAlongDimension(tadIndex, dimension);
            INDArray resultTad = op.z().tensorAlongDimension(tadIndex, dimension);
            INDArray expectedTad = inputTad.div(vector);
            assertEquals("Failed on tad with original tad " + inputTad + " at " + tadIndex, expectedTad,
                    resultTad);
            tadIndex++;
        }
    }

    /**
     * Broadcast-adds a length-5 linspace vector along dimension 0 of a 5x7
     * linspace matrix and verifies the result twice: tensor by tensor against
     * {@code tad.add(vector)}, and element by element against
     * {@code orig + vector} using a relative-error bound of 1e-6.
     */
    @Test
    public void testFiveBySeven() {
        INDArray orig = Nd4j.linspace(1, 35, 35).reshape(5, 7);
        INDArray vector = Nd4j.linspace(1, 5, 5);
        int dimension = 0;

        // Debug output: TAD count followed by the first five TADs.
        System.out.println(orig.tensorssAlongDimension(dimension));
        for (int i = 0; i < 5; i++) {
            System.out.println(orig.tensorAlongDimension(i, dimension));
        }
        System.out.println();

        // The result goes into a copy, so orig stays available for the comparisons below.
        BroadcastOp op = new BroadcastAddOp(orig, vector, orig.dup(), dimension);
        Nd4j.getExecutioner().exec(op);

        //Compare expected vs. actual, one tensor at a time:
        for (int i = 0; i < orig.tensorssAlongDimension(dimension); i++) {
            INDArray tad = orig.tensorAlongDimension(i, dimension);
            INDArray zDim = op.z().tensorAlongDimension(i, dimension);
            INDArray assertion = tad.add(vector);
            assertEquals("Failed on tad with original tad " + tad + " at " + i, assertion, zDim);
        }

        //Compare expected vs. actual, one element at a time:
        NdIndexIterator iter = new NdIndexIterator(orig.shape());
        int[] shape = { 5, 7 };
        while (iter.hasNext()) {
            int[] next = iter.next();
            double origValue = orig.getDouble(next);
            double vectorValue = vector.getDouble(next[dimension]); //current index in vector
            double exp = origValue + vectorValue;

            double actual = op.z().getDouble(next);
            double relError = Math.abs(exp - actual) / (Math.abs(exp) + Math.abs(actual));
            assertTrue("Failed on shape " + Arrays.toString(shape) + " at index " + Arrays.toString(next),
                    relError < 1e-6);
        }
    }

    /**
     * Broadcast-adds a fixed 7-element vector along dimension 1 of a fixed 5x7
     * matrix and compares the op result with a copy of the matrix to which the
     * vector was added tensor by tensor.
     */
    @Test
    public void testColumnVectorAdd() {
        final int dimension = 1;

        double[] vectorValues = { 0.8183500170707703, 0.5002227425575256, 0.810189425945282,
                0.09596852213144302, 0.2189500331878662, 0.2587190568447113, 0.4681057631969452 };
        // 35 values, listed seven per line to mirror the 5x7 shape passed below.
        double[] matrixValues = {
                1.7479660511016846, 0.8165982961654663, 0.9941082000732422, 0.30052879452705383,
                0.7866750359535217, 0.8542637825012207, 1.4326202869415283,
                1.471527099609375, 1.249129295349121, 1.4637593030929565, 0.8436833620071411,
                1.1802568435668945, 0.26710736751556396, 0.5745501518249512,
                1.935403823852539, 1.6568565368652344, 2.4301915168762207, 1.0641130208969116,
                1.4025475978851318, 1.2411234378814697, 1.5786868333816528,
                2.354153633117676, 1.4680445194244385, 1.9459636211395264, 0.6315816640853882,
                1.1675891876220703, 1.5114526748657227, 1.6130852699279785,
                3.245872735977173, 1.6715824604034424, 2.4574174880981445, 1.0882757902145386,
                1.560572624206543, 0.8008333444595337, 1.8960646390914917 };

        INDArray vector = Nd4j.create(vectorValues);
        INDArray matrix = Nd4j.create(matrixValues, new int[] { 5, 7 });

        // The op writes into a copy of the matrix.
        BroadcastAddOp op = new BroadcastAddOp(matrix, vector, matrix.dup(), dimension);

        // Reference result: add the vector in place to every tensor of a second copy.
        INDArray expected = matrix.dup();
        int tadIndex = 0;
        while (tadIndex < expected.tensorssAlongDimension(dimension)) {
            expected.tensorAlongDimension(tadIndex, dimension).addi(vector);
            tadIndex++;
        }

        Nd4j.getExecutioner().exec(op);
        assertEquals(expected, op.z());
    }

    /**
     * Runs an in-place BroadcastAddOp of a length-7 linspace vector along
     * dimension 1 of a 5x7x9x11x13 linspace array and prints the first five
     * tensors before and after the op. Nothing is asserted.
     */
    @Test
    public void testDimensionOneLengthSeven() {
        final int dimension = 1;
        INDArray vector = Nd4j.linspace(1, 7, 7);
        int[] tensorShape = { 5, 7, 9, 11, 13 };
        int elementCount = ArrayUtil.prod(tensorShape);
        INDArray tensor = Nd4j.linspace(1, elementCount, elementCount).reshape(tensorShape);

        // The op uses the tensor itself as the result array.
        BroadcastAddOp op = new BroadcastAddOp(tensor, vector, tensor, dimension);
        // Copied before the op executes, for the "before" printout below.
        INDArray before = tensor.dup();
        Nd4j.getExecutioner().exec(op);

        int tadIndex = 0;
        while (tadIndex < 5) {
            System.out.println(
                    "Adding vector " + vector + " to tad " + before.tensorAlongDimension(tadIndex, dimension));
            System.out.println("Comparing against  vector " + vector + " to tad "
                    + tensor.tensorAlongDimension(tadIndex, dimension));
            tadIndex++;
        }
        System.out.println(op.z());
    }

    /**
     * Runs each of the seven broadcast vector ops (add, copy, div, mul, rdiv,
     * rsub, sub) along every dimension of random arrays of rank 2 to 5, and
     * compares each op result with a reference built by applying the matching
     * in-place INDArray operation to every tensor along that dimension.
     */
    @Test
    public void testNdVectorOp() {
        //Test 2d, 3d, 4d and 5d vector ops.
        // NOTE(review): maxShape has six entries but the rank loop stops at 5 - confirm
        // whether the 6d case was meant to be covered.

        Nd4j.getRandom().setSeed(12345);
        int[] maxShape = new int[] { 5, 7, 9, 11, 13, 15 };

        // Seven ops are defined in the switches below (cases 0..6); the bound must cover all
        // of them, otherwise the BroadcastSubOp case is never exercised.
        final int numOps = 7;

        for (int opNum = 0; opNum < numOps; opNum++) {
            for (int rank = 2; rank < maxShape.length; rank++) {
                int[] shape = Arrays.copyOfRange(maxShape, 0, rank);
                INDArray orig = Nd4j.rand(shape);

                for (int i = 0; i < rank; i++) { //Test ops for each dimension
                    INDArray arr = orig.dup();
                    INDArray vector = i == 0 ? Nd4j.rand(1, shape[i]) : Nd4j.rand(shape[i], 1);
                    System.out.println("Executed rank " + rank + " and dimension " + i + " with vector " + vector
                            + " and array of shape " + Arrays.toString(arr.shape()));
                    BroadcastOp op;
                    switch (opNum) {
                    case 0:
                        op = new BroadcastAddOp(arr, vector, arr.dup(), i);
                        break;
                    case 1:
                        // The copy op is the only one that writes its result into arr itself.
                        op = new BroadcastCopyOp(arr, vector, arr, i);
                        break;
                    case 2:
                        op = new BroadcastDivOp(arr, vector, arr.dup(), i);
                        break;
                    case 3:
                        op = new BroadcastMulOp(arr, vector, arr.dup(), i);
                        break;
                    case 4:
                        op = new BroadcastRDivOp(arr, vector, arr.dup(), i);
                        break;
                    case 5:
                        op = new BroadcastRSubOp(arr, vector, arr.dup(), i);
                        break;
                    case 6:
                        op = new BroadcastSubOp(arr, vector, arr.dup(), i);
                        break;
                    default:
                        throw new RuntimeException("Unknown op number " + opNum);
                    }

                    StopWatch watch = new StopWatch();
                    watch.start();
                    System.out.println("About to execute op " + op.name());
                    Nd4j.getExecutioner().exec(op);
                    watch.stop();

                    System.out.println("After execution " + watch.getNanoTime() + " nanoseconds with "
                            + op.x().tensorssAlongDimension(i));

                    // Reference result: apply the matching in-place operation to every tensor of a copy.
                    INDArray assertion = arr.dup();
                    for (int j = 0; j < arr.tensorssAlongDimension(i); j++) {
                        switch (opNum) {
                        case 0:
                            assertion.tensorAlongDimension(j, i).addi(vector);
                            break;
                        case 1:
                            assertion.tensorAlongDimension(j, i).assign(vector);
                            break;
                        case 2:
                            assertion.tensorAlongDimension(j, i).divi(vector);
                            break;
                        case 3:
                            assertion.tensorAlongDimension(j, i).muli(vector);
                            break;
                        case 4:
                            assertion.tensorAlongDimension(j, i).rdivi(vector);
                            break;
                        case 5:
                            assertion.tensorAlongDimension(j, i).rsubi(vector);
                            break;
                        case 6:
                            assertion.tensorAlongDimension(j, i).subi(vector);
                            break;
                        default:
                            throw new RuntimeException("Unknown op number " + opNum);
                        }
                    }

                    assertEquals("Op " + op.name() + " failed for rank " + rank + " along dimension " + i,
                            assertion, op.z());
                }
            }
        }
    }

    /**
     * Sets the global data type to FLOAT, then checks cosine similarity for two
     * identical vectors (expected 1) and for two different float vectors
     * (expected about 0.98), each within a tolerance of 0.1.
     */
    @Test
    public void testCosineSim() {
        Nd4j.dtype = DataBuffer.Type.FLOAT;

        // Identical vectors.
        INDArray first = Nd4j.create(new double[] { 1, 2, 3, 4 });
        INDArray second = Nd4j.create(new double[] { 1, 2, 3, 4 });
        double identicalSimilarity = Transforms.cosineSim(first, second);
        assertEquals(1, identicalSimilarity, 1e-1);

        // Different vectors.
        INDArray third = Nd4j.create(new float[] { 0.2f, 0.3f, 0.4f, 0.5f });
        INDArray fourth = Nd4j.create(new float[] { 0.6f, 0.7f, 0.8f, 0.9f });
        double differentSimilarity = Transforms.cosineSim(third, fourth);
        assertEquals(0.98, differentSimilarity, 1e-1);
    }

    /**
     * Smoke test: sums 2d, 3d, 4d and 5d arrays of ones whose leading dimension
     * has size 1 along every dimension. Nothing is asserted; the test passes
     * when no sum throws.
     *
     * <p>Earlier notes on this test recorded an IllegalArgumentException
     * ("Illegal index ... derived from 9 with offset of ... and stride of ...")
     * for the sums along dimension 2 and above of the 3d, 4d and 5d arrays; the
     * 2d sums and the sums along dimensions 0 and 1 were marked as OK.
     */
    @Test
    public void testSumWithRow1() {
        INDArray array2d = Nd4j.ones(1, 10);
        for (int dim = 0; dim < 2; dim++) {
            array2d.sum(dim);
        }

        INDArray array3d = Nd4j.ones(1, 10, 10);
        for (int dim = 0; dim < 3; dim++) {
            array3d.sum(dim);
        }

        INDArray array4d = Nd4j.ones(1, 10, 10, 10);
        for (int dim = 0; dim < 4; dim++) {
            array4d.sum(dim);
        }

        INDArray array5d = Nd4j.ones(1, 10, 10, 10, 10);
        for (int dim = 0; dim < 5; dim++) {
            array5d.sum(dim);
        }
    }

    /**
     * Checks that {@code Shape.toOffsetZero} returns an array equal to its input
     * for a single matrix row, for a pair of matrix rows, and for a nested slice
     * of a 3x3x3 tensor.
     */
    @Test
    public void testToOffsetZero() {
        INDArray matrix = Nd4j.rand(3, 5);

        // Single row.
        INDArray singleRow = matrix.getRow(1);
        INDArray singleRowAtZero = Shape.toOffsetZero(singleRow);
        assertEquals(singleRow, singleRowAtZero);

        // Rows 1 and 2.
        INDArray rowPair = matrix.getRows(1, 2);
        INDArray rowPairAtZero = Shape.toOffsetZero(rowPair);
        assertEquals(rowPair, rowPairAtZero);

        // Slice 1 of slice 1 of a rank-3 tensor.
        INDArray tensor = Nd4j.rand(new int[] { 3, 3, 3 });
        INDArray nestedSlice = tensor.slice(1).slice(1);
        INDArray nestedSliceAtZero = Shape.toOffsetZero(nestedSlice);
        assertEquals(nestedSlice, nestedSliceAtZero);
    }

    /**
     * Runs {@code testSumHelper} on rank-3 to rank-6 shapes that have a size-1
     * dimension at the front, at the back, or at both ends.
     */
    @Test
    public void testSumLeadingTrailingZeros() {
        // For each rank: leading 1, trailing 1, then both.
        int[][] shapes = {
                { 1, 5, 5 }, { 5, 5, 1 }, { 1, 5, 1 },
                { 1, 5, 5, 5 }, { 5, 5, 5, 1 }, { 1, 5, 5, 1 },
                { 1, 5, 5, 5, 5 }, { 5, 5, 5, 5, 1 }, { 1, 5, 5, 5, 1 },
                { 1, 5, 5, 5, 5, 5 }, { 5, 5, 5, 5, 5, 1 }, { 1, 5, 5, 5, 5, 1 } };

        for (int[] shape : shapes) {
            testSumHelper(shape);
        }
    }

    /**
     * For an array of ones with the given shape, fetches every vector along each
     * dimension and then sums along that dimension. Nothing is asserted; the
     * helper succeeds when no call throws.
     *
     * @param shape dimensions of the array to create
     */
    private void testSumHelper(int... shape) {
        INDArray ones = Nd4j.ones(shape);
        for (int dim = 0; dim < shape.length; dim++) {
            int vectorIndex = 0;
            while (vectorIndex < ones.vectorsAlongDimension(dim)) {
                // Fetched only to exercise the lookup; the vector itself is unused.
                INDArray unusedVector = ones.vectorAlongDimension(vectorIndex, dim);
                vectorIndex++;
            }
            ones.sum(dim);
        }
    }

    /**
     * Computes a chain of matrix multiplications and element-wise divisions on a
     * shared 300x300 random matrix from ten tasks submitted to the shared
     * executor service. Each task repeats the computation ten times and counts
     * as correct only if every repetition equals the result computed up front on
     * the calling thread. The test expects all ten tasks to be correct.
     *
     * @throws InterruptedException if interrupted while waiting for the tasks
     */
    @Test
    public void testMultipleThreads() throws InterruptedException {
        final int numThreads = 10;
        final INDArray array = Nd4j.rand(300, 300);
        final INDArray expected = array.dup().mmul(array).mmul(array).div(array).div(array);
        final AtomicInteger correct = new AtomicInteger();
        final CountDownLatch latch = new CountDownLatch(numThreads);
        System.out.println("Running on " + ContextHolder.getInstance().deviceNum());
        ExecutorService executors = ExecutorServiceProvider.getExecutorService();

        // The task keeps no state of its own, so one instance is submitted numThreads times.
        final Runnable worker = new Runnable() {
            @Override
            public void run() {
                try {
                    final int attempts = 10;
                    int matches = 0;
                    for (int attempt = 0; attempt < attempts; attempt++) {
                        StopWatch watch = new StopWatch();
                        watch.start();
                        INDArray actual = array.dup().mmul(array).mmul(array).div(array).div(array);
                        watch.stop();
                        if (expected.equals(actual)) {
                            matches++;
                        }
                    }

                    if (matches == attempts) {
                        correct.incrementAndGet();
                    }
                } finally {
                    // Always count down, so the waiting test thread is released even if the task fails.
                    latch.countDown();
                }
            }
        };

        for (int submitted = 0; submitted < numThreads; submitted++) {
            executors.execute(worker);
        }

        latch.await();

        assertEquals(numThreads, correct.get());
    }

}