Tsne.java Source code

Java tutorial

Introduction

Here is the source code for Tsne.java, an implementation of t-SNE (t-distributed stochastic neighbor embedding) for dimensionality reduction, from the Deeplearning4j project.

Source

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.indexing.functions.Zero;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

import java.io.*;
import java.util.List;

import static org.nd4j.linalg.factory.Nd4j.*;
import static org.nd4j.linalg.ops.transforms.Transforms.*;

/**
 * Tsne calculation: an exact (non-Barnes-Hut) implementation of
 * t-distributed stochastic neighbor embedding. Given a data matrix it
 * produces a low-dimensional embedding {@code y} by gradient descent on the
 * KL divergence between pairwise similarity distributions, using the
 * standard tricks: perplexity-calibrated Gaussian kernels, early
 * exaggeration ("lying"), momentum switching, and adaptive per-element
 * gains (optionally AdaGrad).
 *
 * <p>Not thread-safe: optimization state ({@code y}, {@code gains},
 * {@code yIncs}, {@code momentum}) is held in mutable fields.</p>
 *
 * @author Adam Gibson
 */
public class Tsne implements Serializable {

    // --- optimization hyper-parameters ---
    protected int maxIter = 1000;                   // maximum gradient-descent iterations
    protected double realMin = Nd4j.EPS_THRESHOLD;  // numerical floor (replaces NaNs, avoids log(0)/div-by-0)
    protected double initialMomentum = 0.5;
    protected double finalMomentum = 0.8;
    protected double minGain = 1e-2;                // lower bound applied to the adaptive gains
    protected double momentum = initialMomentum;    // current momentum; switched to finalMomentum mid-run
    protected int switchMomentumIteration = 100;    // iteration at which momentum switches to finalMomentum
    protected boolean normalize = true;             // min/max-scale and mean-center the input before embedding
    protected boolean usePca = false;               // pre-reduce input to at most 50 dims with PCA
    protected int stopLyingIteration = 250;         // iteration at which early exaggeration (P * 4) is undone
    protected double tolerance = 1e-5;              // tolerance of the perplexity binary search
    protected double learningRate = 500;
    protected AdaGrad adaGrad;                      // lazily created when useAdaGrad is set
    protected boolean useAdaGrad = true;
    protected double perplexity = 30;
    // Adaptive gains and the momentum-smoothed update; lazily initialized in gradient().
    protected INDArray gains, yIncs;
    // The low-dimensional embedding being optimized; lazily initialized in calculate().
    protected INDArray y;
    protected transient IterationListener iterationListener;
    // NOTE(review): this resource is never read anywhere in this class; kept only so the
    // protected static field remains visible to any existing subclasses/callers.
    protected static ClassPathResource r = new ClassPathResource("/scripts/tsne.py");
    protected static final Logger log = LoggerFactory.getLogger(Tsne.class);

    /** Creates a Tsne with the default hyper-parameters above. */
    public Tsne() {
    }

    /**
     * Fully-parameterized constructor; see the field comments for the meaning
     * of each value. Prefer {@link Builder} for readable construction.
     */
    public Tsne(int maxIter, double realMin, double initialMomentum, double finalMomentum, double momentum,
            int switchMomentumIteration, boolean normalize, boolean usePca, int stopLyingIteration,
            double tolerance, double learningRate, boolean useAdaGrad, double perplexity, double minGain) {
        this.tolerance = tolerance;
        this.minGain = minGain;
        this.useAdaGrad = useAdaGrad;
        this.learningRate = learningRate;
        this.stopLyingIteration = stopLyingIteration;
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.normalize = normalize;
        this.initialMomentum = initialMomentum;
        this.usePca = usePca;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.perplexity = perplexity;
    }

    /**
     * Computes a Gaussian kernel over a vector of squared distances and the
     * entropy of the resulting distribution.
     *
     * @param d    a vector of squared distances from one point to the others
     * @param beta the precision (1 / (2 * sigma^2)) of the Gaussian
     * @return a pair of (H, P): the entropy of the kernel and the normalized
     *         kernel probabilities
     */
    public Pair<INDArray, INDArray> hBeta(INDArray d, double beta) {
        INDArray P = exp(d.neg().muli(beta));
        INDArray sum = P.sum(Integer.MAX_VALUE);
        // H = log(sum(P)) + beta * sum(d .* P) / sum(P)
        // NOTE(review): sum(0) here assumes d is laid out so dimension 0 spans the
        // distances (a row vector in this code path) — confirm against caller.
        INDArray H = log(sum).addi(d.mul(P).sum(0).muli(beta).divi(sum));
        P.divi(sum);
        return new Pair<>(H, P);
    }

    /**
     * Converts squared distances to probability co-occurrences: for each row a
     * binary search over the kernel precision is run until the kernel's
     * perplexity matches {@code u}, then the result is symmetrized and
     * normalized.
     *
     * @param d the pairwise squared-distance matrix (n x n)
     * @param u the target perplexity of the model
     * @return the symmetrized, normalized probabilities of co-occurrence
     */
    public INDArray computeGaussianPerplexity(final INDArray d, double u) {
        int n = d.rows();
        final INDArray p = zeros(n, n);
        final INDArray beta = ones(n, 1);    // per-row kernel precisions
        final double logU = Math.log(u);     // target entropy (log perplexity)

        log.info("Calculating probabilities of data similarities..");
        for (int i = 0; i < n; i++) {
            if (i % 500 == 0 && i > 0)
                log.info("Handled " + i + " records");

            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            // All columns of row i except the diagonal entry i itself.
            NDArrayIndex[] range = new NDArrayIndex[] {
                    NDArrayIndex.concat(NDArrayIndex.interval(0, i), NDArrayIndex.interval(i + 1, d.columns())) };

            INDArray row = d.slice(i).get(range);
            Pair<INDArray, INDArray> pair = hBeta(row, beta.getDouble(i));
            INDArray hDiff = pair.getFirst().sub(logU);
            int tries = 0;

            // Binary search on beta until |H - log(u)| <= tolerance (max 50 tries).
            while (BooleanIndexing.and(abs(hDiff), Conditions.greaterThan(tolerance)) && tries < 50) {
                if (BooleanIndexing.and(hDiff, Conditions.greaterThan(0))) {
                    // Entropy too high: increase precision (narrow the kernel).
                    if (Double.isInfinite(betaMax))
                        beta.putScalar(i, beta.getDouble(i) * 2.0);
                    else
                        beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0);
                    betaMin = beta.getDouble(i);
                } else {
                    // Entropy too low: decrease precision (widen the kernel).
                    if (Double.isInfinite(betaMin))
                        beta.putScalar(i, beta.getDouble(i) / 2.0);
                    else
                        beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0);
                    betaMax = beta.getDouble(i);
                }

                pair = hBeta(row, beta.getDouble(i));
                hDiff = pair.getFirst().subi(logU);
                tries++;
            }

            p.slice(i).put(range, pair.getSecond());

        }

        // beta = 1/(2*sigma^2), so sigma = sqrt(1/beta); logged for diagnostics.
        log.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));
        BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(realMin));

        // Symmetrize: P = P + P^T, then renormalize to a probability distribution.
        INDArray permute = p.transpose();

        INDArray pOut = p.add(permute);

        pOut.divi(pOut.sum(Integer.MAX_VALUE));
        // Floor tiny values so later log(P) stays finite.
        BooleanIndexing.applyWhere(pOut, Conditions.lessThan(Nd4j.EPS_THRESHOLD), new Value(Nd4j.EPS_THRESHOLD));
        return pOut;

    }

    /**
     * Runs the full t-SNE optimization on the given data.
     *
     * @param X          the input data (rows = samples)
     * @param nDims      the target embedding dimensionality (capped at
     *                   {@code X.columns()})
     * @param perplexity the perplexity used to calibrate the Gaussian kernels
     * @return the embedding {@code y} after {@code maxIter} iterations
     */
    public INDArray calculate(INDArray X, int nDims, double perplexity) {
        if (usePca)
            X = PCA.pca(X, Math.min(50, X.columns()), normalize);
        // Normalization (don't normalize again after pca): min/max scale to [0,1],
        // then mean-center each column.
        if (normalize) {
            X.subi(X.min(Integer.MAX_VALUE));
            X = X.divi(X.max(Integer.MAX_VALUE));
            X = X.subiRowVector(X.mean(0));
        }

        if (nDims > X.columns())
            nDims = X.columns();

        // Pairwise squared Euclidean distances: D = -2*X*X^T + sumX + sumX^T.
        INDArray sumX = pow(X, 2).sum(1);

        INDArray D = X.mmul(X.transpose()).muli(-2).addRowVector(sumX).transpose().addRowVector(sumX);

        // Output embedding, initialized with small Gaussian noise unless preset via setY().
        if (y == null)
            y = randn(X.rows(), nDims, Nd4j.getRandom()).muli(1e-3f);

        INDArray p = computeGaussianPerplexity(D, perplexity);

        // Early exaggeration: inflate P for better local minima, undone at stopLyingIteration.
        p.muli(4);

        // Init adagrad where needed.
        if (useAdaGrad) {
            if (adaGrad == null) {
                adaGrad = new AdaGrad(y.shape());
                adaGrad.setMasterStepSize(learningRate);
            }
        }

        for (int i = 0; i < maxIter; i++) {
            step(p, i);

            if (i == switchMomentumIteration)
                momentum = finalMomentum;
            if (i == stopLyingIteration)
                p.divi(4);

            // NOTE(review): listener receives null as the model reference.
            if (iterationListener != null)
                iterationListener.iterationDone(null, i);

        }

        return y;
    }

    /**
     * Computes the KL-divergence cost and the momentum-smoothed update for the
     * current embedding {@code y}, given the high-dimensional probabilities.
     *
     * @param p the (symmetrized, normalized) similarity probabilities
     * @return a pair of (cost, update) where update is the in-place-maintained
     *         {@code yIncs}
     */
    protected Pair<Double, INDArray> gradient(INDArray p) {
        INDArray sumY = pow(y, 2).sum(1);
        if (yIncs == null)
            yIncs = zeros(y.shape());
        if (gains == null)
            gains = ones(y.shape());

        // Student-t kernel in the embedding space (the unnormalized Q):
        // qu = 1 / (1 + squared distances between embedding points).
        INDArray qu = y.mmul(y.transpose()).muli(-2).addiRowVector(sumY).transpose().addiRowVector(sumY).addi(1)
                .rdivi(1);

        int n = y.rows();

        // A point is not its own neighbor: zero the diagonal.
        doAlongDiagonal(qu, new Zero());

        // Normalize to get probabilities.
        INDArray q = qu.div(qu.sum(Integer.MAX_VALUE));

        BooleanIndexing.applyWhere(q, Conditions.lessThan(realMin), new Value(realMin));

        INDArray PQ = p.sub(q);

        INDArray yGrads = getYGradient(n, PQ, qu);

        // Classic t-SNE gain update: bump gain by 0.2 where the gradient sign
        // flipped relative to the current velocity, shrink by x0.8 where it agrees.
        gains = gains.add(.2)
                .muli(yGrads.cond(Conditions.greaterThan(0)).neqi(yIncs.cond(Conditions.greaterThan(0))))
                .addi(gains.mul(0.8)
                        .muli(yGrads.cond(Conditions.greaterThan(0)).eqi(yIncs.cond(Conditions.greaterThan(0)))));

        BooleanIndexing.applyWhere(gains, Conditions.lessThan(minGain), new Value(minGain));

        INDArray gradChange = gains.mul(yGrads);

        // Scale the step either adaptively (AdaGrad) or by the fixed learning rate.
        if (useAdaGrad)
            gradChange = adaGrad.getGradient(gradChange);
        else
            gradChange.muli(learningRate);

        // Momentum-smoothed velocity update (mutates yIncs in place).
        yIncs.muli(momentum).subi(gradChange);

        // Cost = KL(P || Q) = sum(P .* log(P ./ Q)).
        double cost = p.mul(log(p.div(q), false)).sum(Integer.MAX_VALUE).getDouble(0);
        return new Pair<>(cost, yIncs);
    }

    /**
     * Computes the raw t-SNE gradient, one embedding row at a time:
     * grad_i = sum_j (P - Q)_ij * qu_ij * (y_i - y_j).
     *
     * @param n  the number of rows (points)
     * @param PQ the elementwise difference P - Q
     * @param qu the unnormalized Student-t kernel
     * @return the gradient matrix, same shape as {@code y}
     */
    public INDArray getYGradient(int n, INDArray PQ, INDArray qu) {
        INDArray yGrads = Nd4j.create(y.shape());
        for (int i = 0; i < n; i++) {
            INDArray sum1 = Nd4j.tile(PQ.getRow(i).mul(qu.getRow(i)), new int[] { y.columns(), 1 }).transpose()
                    .mul(y.getRow(i).broadcast(y.shape()).sub(y)).sum(0);
            yGrads.putRow(i, sum1);
        }

        return yGrads;
    }

    /**
     * An individual gradient-descent iteration: computes the update for the
     * current embedding, applies it once, then re-centers the embedding on
     * the origin.
     *
     * @param p the probabilities that certain points are near each other
     * @param i the iteration (primarily for debugging purposes)
     */
    public void step(INDArray p, int i) {
        Pair<Double, INDArray> costGradient = gradient(p);
        INDArray yIncs = costGradient.getSecond();
        log.info("Cost at iteration {} was {}", i, costGradient.getFirst());
        // BUGFIX: the update was previously applied twice (two consecutive
        // y.addi(yIncs) calls) and the mean was subtracted twice (once via
        // subiRowVector and again via a tiled subtraction). Apply the update
        // once and center once, as in standard t-SNE.
        y.addi(yIncs);
        y.subiRowVector(y.mean(0));
    }

    /**
     * Plot tsne (write the coordinates file) to the default path
     * {@code coords.csv}.
     *
     * @param matrix the matrix to plot
     * @param nDims  the number of dimensions
     * @param labels one label per row; rows beyond the label list are skipped
     * @throws IOException if the coordinates file cannot be written
     */
    public void plot(INDArray matrix, int nDims, List<String> labels) throws IOException {
        plot(matrix, nDims, labels, "coords.csv");
    }

    /**
     * Runs the embedding and writes one CSV line per labeled row:
     * the coordinates followed by the label.
     *
     * @param matrix the matrix to plot
     * @param nDims  the number of embedding dimensions
     * @param labels one label per row; null labels are skipped
     * @param path   the path to write
     * @throws IOException if the file cannot be written
     */
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {

        calculate(matrix, nDims, perplexity);

        // Append mode (true) preserved from the original: repeated calls
        // accumulate coordinates in the same file. try-with-resources ensures
        // the writer is closed even if a write throws.
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), true))) {
            for (int i = 0; i < y.rows(); i++) {
                if (i >= labels.size())
                    break;
                String word = labels.get(i);
                if (word == null)
                    continue;
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = y.getRow(i);
                for (int j = 0; j < wordVector.length(); j++) {
                    sb.append(wordVector.getDouble(j));
                    if (j < wordVector.length() - 1)
                        sb.append(",");
                }

                sb.append(",");
                sb.append(word);
                sb.append(" ");

                sb.append("\n");
                write.write(sb.toString());

            }

            write.flush();
        }
    }

    /** @return the current embedding (null until {@link #calculate} has run or {@link #setY} is called) */
    public INDArray getY() {
        return y;
    }

    /** Presets the embedding, e.g. to resume or warm-start an optimization. */
    public void setY(INDArray y) {
        this.y = y;
    }

    public IterationListener getIterationListener() {
        return iterationListener;
    }

    public void setIterationListener(IterationListener iterationListener) {
        this.iterationListener = iterationListener;
    }

    /**
     * Fluent builder for {@link Tsne}. Note that the builder defaults differ
     * from the class field defaults (e.g. learningRate, stopLyingIteration,
     * useAdaGrad); this mirrors the original behavior.
     */
    public static class Builder {
        protected int maxIter = 1000;
        protected double realMin = 1e-12f;
        protected double initialMomentum = 5e-1f;
        protected double finalMomentum = 8e-1f;
        protected double momentum = 5e-1f;
        protected int switchMomentumIteration = 100;
        protected boolean normalize = true;
        protected boolean usePca = false;
        protected int stopLyingIteration = 100;
        protected double tolerance = 1e-5f;
        protected double learningRate = 1e-1f;
        protected boolean useAdaGrad = false;
        protected double perplexity = 30;
        protected double minGain = 1e-1f;

        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        public Builder usePca(boolean usePca) {
            this.usePca = usePca;
            return this;
        }

        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        public Tsne build() {
            return new Tsne(maxIter, realMin, initialMomentum, finalMomentum, momentum, switchMomentumIteration,
                    normalize, usePca, stopLyingIteration, tolerance, learningRate, useAdaGrad, perplexity,
                    minGain);
        }

    }
}