org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.java

Source

/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.models.embeddings.inmemory;

import com.google.common.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.BarnesHutTsne;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.net.HttpURLConnection;
import java.net.URI;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Default word lookup table
 *
 * @author Adam Gibson
 */
public class InMemoryLookupTable<T extends SequenceElement> implements WeightLookupTable<T> {

    private static final Logger log = LoggerFactory.getLogger(InMemoryLookupTable.class);

    protected INDArray syn0, syn1;
    protected int vectorLength;
    protected transient Random rng = Nd4j.getRandom();
    protected AtomicDouble lr = new AtomicDouble(25e-3);
    protected double[] expTable;
    protected static double MAX_EXP = 6;
    protected long seed = 123;
    //negative sampling table
    protected INDArray table, syn1Neg;
    protected boolean useAdaGrad;
    protected double negative = 0;
    protected boolean useHS = true;
    protected VocabCache<T> vocab;
    protected Map<Integer, INDArray> codes = new ConcurrentHashMap<>();

    protected AdaGrad adaGrad;

    @Getter
    @Setter
    protected Long tableId;

    public InMemoryLookupTable() {
    }

    public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
            double negative, boolean useHS) {
        this(vocab, vectorLength, useAdaGrad, lr, gen, negative);
        this.useHS = useHS;
    }

    public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
            double negative) {
        this.vocab = vocab;
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr.set(lr);
        this.rng = gen;
        this.negative = negative;
        initExpTable();

        if (useAdaGrad) {
            initAdaGrad();
        }
    }

    protected void initAdaGrad() {
        adaGrad = new AdaGrad(new int[] { vocab.numWords() + 1, vectorLength }, lr.get());
    }

    public double[] getExpTable() {
        return expTable;
    }

    public void setExpTable(double[] expTable) {
        this.expTable = expTable;
    }

    public double getGradient(int column, double gradient) {
        if (adaGrad == null)
            initAdaGrad();
        return adaGrad.getGradient(gradient, column, syn0.shape());
    }

    @Override
    public int layerSize() {
        return vectorLength;
    }

    @Override
    public void resetWeights(boolean reset) {
        if (this.rng == null)
            this.rng = Nd4j.getRandom();

        this.rng.setSeed(seed);

        if (syn0 == null || reset) {
            syn0 = Nd4j.rand(new int[] { vocab.numWords(), vectorLength }, rng).subi(0.5).divi(vectorLength);
            //            INDArray randUnk = Nd4j.rand(1, vectorLength, rng).subi(0.5).divi(vectorLength);
            //            putVector(Word2Vec.UNK, randUnk);
        }
        if ((syn1 == null || reset) && useHS) {
            log.info("Initializing syn1...");
            syn1 = Nd4j.create(syn0.shape());
        }

        initNegative();
    }

    private List<String> fitTnseAndGetLabels(final BarnesHutTsne tsne, final int numWords) {
        INDArray array = Nd4j.create(numWords, vectorLength);
        List<String> labels = new ArrayList<>();
        for (int i = 0; i < numWords && i < vocab.numWords(); i++) {
            labels.add(vocab.wordAtIndex(i));
            array.putRow(i, syn0.slice(i));
        }
        tsne.fit(array);
        return labels;
    }

    @Override
    public void plotVocab(BarnesHutTsne tsne, int numWords, File file) {
        final List<String> labels = fitTnseAndGetLabels(tsne, numWords);
        try {
            tsne.saveAsFile(labels, file.getAbsolutePath());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Render the words via tsne
     */
    @Override
    public void plotVocab(int numWords, File file) {
        BarnesHutTsne tsne = new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.8f).numDimension(2)
                .setMaxIter(1000).build();
        plotVocab(tsne, numWords, file);
    }

    /**
     * Render the words via tsne
     */
    @Override
    public void plotVocab(int numWords, UiConnectionInfo connectionInfo) {
        BarnesHutTsne tsne = new BarnesHutTsne.Builder().normalize(false).setFinalMomentum(0.8f).numDimension(2)
                .setMaxIter(1000).build();
        plotVocab(tsne, numWords, connectionInfo);
    }

    /**
     * Render the words via TSNE
     *
     * @param tsne           the tsne to use
     * @param numWords
     * @param connectionInfo
     */
    @Override
    public void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo) {
        try {
            final List<String> labels = fitTnseAndGetLabels(tsne, numWords);
            final INDArray reducedData = tsne.getData();
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < reducedData.rows() && i < numWords; i++) {
                String word = labels.get(i);
                INDArray wordVector = reducedData.getRow(i);
                for (int j = 0; j < wordVector.length(); j++) {
                    sb.append(String.valueOf(wordVector.getDouble(j))).append(",");
                }
                sb.append(word);
            }

            String address = connectionInfo.getFirstPart() + "/tsne/post/" + connectionInfo.getSessionId();
            //            System.out.println("ADDRESS: " + address);
            URI uri = new URI(address);

            HttpURLConnection connection = (HttpURLConnection) uri.toURL().openConnection();
            connection.setRequestMethod("POST");
            connection.setRequestProperty("User-Agent", "Mozilla/5.0");
            //            connection.setRequestProperty("Content-Type", "application/json");
            connection.setRequestProperty("Content-Type", "multipart/form-data");
            connection.setDoOutput(true);

            DataOutputStream dos = new DataOutputStream(connection.getOutputStream());
            dos.writeBytes(sb.toString());
            dos.flush();
            dos.close();

            try {
                int responseCode = connection.getResponseCode();
                System.out.println("RESPONSE CODE: " + responseCode);

                if (responseCode != 200) {
                    BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream()));
                    String inputLine;
                    StringBuilder response = new StringBuilder();

                    while ((inputLine = in.readLine()) != null) {
                        response.append(inputLine);
                    }
                    in.close();

                    log.warn("Error posting to remote UI - received response code {}\tContent: {}", response,
                            response.toString());
                }
            } catch (IOException e) {
                log.warn("Error posting to remote UI at {}", uri, e);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @param codeIndex
     * @param code
     */
    @Override
    public void putCode(int codeIndex, INDArray code) {
        codes.put(codeIndex, code);
    }

    /**
     * Loads the co-occurrences for the given codes
     *
     * @param codes the codes to load
     * @return an ndarray of code.length by layerSize
     */
    @Override
    public INDArray loadCodes(int[] codes) {
        return syn1.getRows(codes);
    }

    public synchronized void initNegative() {
        if (negative > 0 && syn1Neg == null) {
            syn1Neg = Nd4j.zeros(syn0.shape());
            makeTable(Math.max(expTable.length, 100000), 0.75);
        }
    }

    protected void initExpTable() {
        expTable = new double[100000];
        for (int i = 0; i < expTable.length; i++) {
            double tmp = FastMath.exp((i / (double) expTable.length * 2 - 1) * MAX_EXP);
            expTable[i] = tmp / (tmp + 1.0);
        }
    }

    /**
     * Iterate on the given 2 vocab words
     *
     * @param w1 the first word to iterate on
     * @param w2 the second word to iterate on
     * @param nextRandom next random for sampling
     */
    @Override
    @Deprecated
    public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
        if (w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex() || w1.getLabel().equals("STOP")
                || w2.getLabel().equals("STOP") || w1.getLabel().equals("UNK") || w2.getLabel().equals("UNK"))
            return;
        //current word vector
        INDArray l1 = this.syn0.slice(w2.getIndex());

        //error for current word and context
        INDArray neu1e = Nd4j.create(vectorLength);

        for (int i = 0; i < w1.getCodeLength(); i++) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point >= syn0.rows() || point < 0)
                throw new IllegalStateException("Illegal point " + point);
            //other word vector

            INDArray syn1 = this.syn1.slice(point);

            double dot = Nd4j.getBlasWrapper().dot(l1, syn1);

            if (dot < -MAX_EXP || dot >= MAX_EXP)
                continue;

            int idx = (int) ((dot + MAX_EXP) * ((double) expTable.length / MAX_EXP / 2.0));
            if (idx >= expTable.length)
                continue;

            //score
            double f = expTable[idx];
            //gradient
            double g = useAdaGrad ? w1.getGradient(i, (1 - code - f), lr.get()) : (1 - code - f) * alpha;

            Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, syn1, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, l1, syn1);

        }

        int target = w1.getIndex();
        int label;
        //negative sampling
        if (negative > 0)
            for (int d = 0; d < negative + 1; d++) {
                if (d == 0)
                    label = 1;
                else {
                    nextRandom.set(nextRandom.get() * 25214903917L + 11);
                    int idx = Math.abs((int) (nextRandom.get() >> 16) % table.length());

                    target = table.getInt(idx);
                    if (target <= 0)
                        target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1;

                    if (target == w1.getIndex())
                        continue;
                    label = 0;
                }

                if (target >= syn1Neg.rows() || target < 0)
                    continue;

                double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg.slice(target));
                double g;
                if (f > MAX_EXP)
                    g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha;
                else if (f < -MAX_EXP)
                    g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
                else
                    g = useAdaGrad ? w1.getGradient(target,
                            label - expTable[(int) ((f + MAX_EXP) * (expTable.length / MAX_EXP / 2))], alpha)
                            : (label - expTable[(int) ((f + MAX_EXP) * (expTable.length / MAX_EXP / 2))]) * alpha;
                if (syn0.data().dataType() == DataBuffer.Type.DOUBLE)
                    Nd4j.getBlasWrapper().axpy(g, syn1Neg.slice(target), neu1e);
                else
                    Nd4j.getBlasWrapper().axpy((float) g, syn1Neg.slice(target), neu1e);

                if (syn0.data().dataType() == DataBuffer.Type.DOUBLE)
                    Nd4j.getBlasWrapper().axpy(g, l1, syn1Neg.slice(target));
                else
                    Nd4j.getBlasWrapper().axpy((float) g, l1, syn1Neg.slice(target));
            }

        if (syn0.data().dataType() == DataBuffer.Type.DOUBLE)
            Nd4j.getBlasWrapper().axpy(1.0, neu1e, l1);

        else
            Nd4j.getBlasWrapper().axpy(1.0f, neu1e, l1);

    }

    public boolean isUseAdaGrad() {
        return useAdaGrad;
    }

    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    public double getNegative() {
        return negative;
    }

    public void setUseHS(boolean useHS) {
        this.useHS = useHS;
    }

    public void setNegative(double negative) {
        this.negative = negative;
    }

    /**
     * Iterate on the given 2 vocab words
     *
     * @param w1 the first word to iterate on
     * @param w2 the second word to iterate on
     */
    @Deprecated
    @Override
    public void iterate(T w1, T w2) {

    }

    /**
     * Reset the weights of the cache
     */
    @Override
    public void resetWeights() {
        resetWeights(true);
    }

    protected void makeTable(int tableSize, double power) {
        int vocabSize = syn0.rows();
        table = Nd4j.create(tableSize);
        double trainWordsPow = 0.0;
        for (String word : vocab.words()) {
            trainWordsPow += Math.pow(vocab.wordFrequency(word), power);
        }

        int wordIdx = 0;
        String word = vocab.wordAtIndex(wordIdx);
        double d1 = Math.pow(vocab.wordFrequency(word), power) / trainWordsPow;
        for (int i = 0; i < tableSize; i++) {
            table.putScalar(i, wordIdx);
            double mul = i * 1.0 / (double) tableSize;
            if (mul > d1) {
                if (wordIdx < vocabSize - 1)
                    wordIdx++;
                word = vocab.wordAtIndex(wordIdx);
                String wordAtIndex = vocab.wordAtIndex(wordIdx);
                if (word == null)
                    continue;
                d1 += Math.pow(vocab.wordFrequency(wordAtIndex), power) / trainWordsPow;
            }
        }
    }

    /**
     * Inserts a word vector
     *
     * @param word   the word to insert
     * @param vector the vector to insert
     */
    @Override
    public void putVector(String word, INDArray vector) {
        if (word == null)
            throw new IllegalArgumentException("No null words allowed");
        if (vector == null)
            throw new IllegalArgumentException("No null vectors allowed");
        int idx = vocab.indexOf(word);
        syn0.slice(idx).assign(vector);

    }

    public INDArray getTable() {
        return table;
    }

    public void setTable(INDArray table) {
        this.table = table;
    }

    public INDArray getSyn1Neg() {
        return syn1Neg;
    }

    public void setSyn1Neg(INDArray syn1Neg) {
        this.syn1Neg = syn1Neg;
    }

    /**
     * @param word
     * @return
     */
    @Override
    public INDArray vector(String word) {
        if (word == null)
            return null;
        int idx = vocab.indexOf(word);
        if (idx < 0) {
            idx = vocab.indexOf(Word2Vec.DEFAULT_UNK);
            if (idx < 0)
                return null;
        }
        return syn0.getRow(idx);
    }

    @Override
    public void setLearningRate(double lr) {
        this.lr.set(lr);
    }

    @Override
    public Iterator<INDArray> vectors() {
        return new WeightIterator();
    }

    @Override
    public INDArray getWeights() {
        return syn0;
    }

    protected class WeightIterator implements Iterator<INDArray> {
        protected int currIndex = 0;

        @Override
        public boolean hasNext() {
            return currIndex < syn0.rows();
        }

        @Override
        public INDArray next() {
            INDArray ret = syn0.slice(currIndex);
            currIndex++;
            return ret;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    public INDArray getSyn0() {
        return syn0;
    }

    public void setSyn0(INDArray syn0) {
        this.syn0 = syn0;
    }

    public INDArray getSyn1() {
        return syn1;
    }

    public void setSyn1(INDArray syn1) {
        this.syn1 = syn1;
    }

    @Override
    public VocabCache<T> getVocabCache() {
        return vocab;
    }

    public void setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
    }

    /**
     * This method is deprecated, since all logic was pulled out from this class and is not used anymore.
     * However this method will be around for a while, due to backward compatibility issues.
     * @return initial learning rate
     */
    @Deprecated
    public AtomicDouble getLr() {
        return lr;
    }

    public void setLr(AtomicDouble lr) {
        this.lr = lr;
    }

    public VocabCache getVocab() {
        return vocab;
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    public Map<Integer, INDArray> getCodes() {
        return codes;
    }

    public void setCodes(Map<Integer, INDArray> codes) {
        this.codes = codes;
    }

    public static class Builder<T extends SequenceElement> {
        protected int vectorLength = 100;
        protected boolean useAdaGrad = false;
        protected double lr = 0.025;
        protected Random gen = Nd4j.getRandom();
        protected long seed = 123;
        protected double negative = 0;
        protected VocabCache<T> vocabCache;
        protected boolean useHS = true;

        public Builder<T> useHierarchicSoftmax(boolean reallyUse) {
            this.useHS = reallyUse;
            return this;
        }

        public Builder<T> cache(@NonNull VocabCache<T> vocab) {
            this.vocabCache = vocab;
            return this;
        }

        public Builder<T> negative(double negative) {
            this.negative = negative;
            return this;
        }

        public Builder<T> vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        public Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * This method is deprecated, since all logic was pulled out from this class
         * @param lr
         * @return
         */
        @Deprecated
        public Builder<T> lr(double lr) {
            this.lr = lr;
            return this;
        }

        public Builder<T> gen(Random gen) {
            this.gen = gen;
            return this;
        }

        public Builder<T> seed(long seed) {
            this.seed = seed;
            return this;
        }

        public InMemoryLookupTable<T> build() {
            if (vocabCache == null)
                throw new IllegalStateException("Vocab cache must be specified");

            InMemoryLookupTable<T> table = new InMemoryLookupTable<>(vocabCache, vectorLength, useAdaGrad, lr, gen,
                    negative, useHS);
            table.seed = seed;

            return table;
        }
    }

    @Override
    public String toString() {
        return "InMemoryLookupTable{" + "syn0=" + syn0 + ", syn1=" + syn1 + ", vectorLength=" + vectorLength
                + ", rng=" + rng + ", lr=" + lr + ", expTable=" + Arrays.toString(expTable) + ", seed=" + seed
                + ", table=" + table + ", syn1Neg=" + syn1Neg + ", useAdaGrad=" + useAdaGrad + ", negative="
                + negative + ", vocab=" + vocab + ", codes=" + codes + '}';
    }

    /**
     * This method consumes weights of a given InMemoryLookupTable
     *
     * PLEASE NOTE: this method explicitly resets current weights
     *
     * @param srcTable
     */
    public void consume(InMemoryLookupTable<T> srcTable) {
        if (srcTable.vectorLength != this.vectorLength)
            throw new IllegalStateException("You can't consume lookupTable with different vector lengths");

        if (srcTable.syn0 == null)
            throw new IllegalStateException("Source lookupTable Syn0 is NULL");

        this.resetWeights(true);

        AtomicInteger cntHs = new AtomicInteger(0);
        AtomicInteger cntNg = new AtomicInteger(0);

        if (srcTable.syn0.rows() > this.syn0.rows())
            throw new IllegalStateException(
                    "You can't consume lookupTable with built for larger vocabulary without updating your vocabulary first");

        for (int x = 0; x < srcTable.syn0.rows(); x++) {
            this.syn0.putRow(x, srcTable.syn0.getRow(x));

            if (this.syn1 != null && srcTable.syn1 != null)
                this.syn1.putRow(x, srcTable.syn1.getRow(x));
            else if (cntHs.incrementAndGet() == 1)
                log.info("Skipping syn1 merge");

            if (this.syn1Neg != null && srcTable.syn1Neg != null) {
                this.syn1Neg.putRow(x, srcTable.syn1Neg.getRow(x));
            } else if (cntNg.incrementAndGet() == 1)
                log.info("Skipping syn1Neg merge");

            if (cntHs.get() > 0 && cntNg.get() > 0)
                throw new ND4JIllegalStateException("srcTable has no syn1/syn1neg");
        }
    }
}