com.anhth12.word2vec.VocabWord.java Source code

Java tutorial

Introduction

Here is the source code for com.anhth12.word2vec.VocabWord.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.word2vec;

import com.google.common.base.Strings;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A single vocabulary entry used by the word2vec code in this package.
 *
 * <p>Each entry holds the word text, a mutable frequency counter, its index
 * in the vocabulary ({@code -1} until assigned), two integer lists
 * ({@code codes} and {@code points}) with an associated {@code codeLength},
 * and a lazily created per-index gradient history.
 *
 * <p>NOTE(review): {@code codes}/{@code points} are presumably the Huffman
 * code bits and inner-node indices used for hierarchical softmax; confirm
 * against the code that fills them.
 *
 * <p>Natural ordering ({@link #compareTo}) is by word frequency only and is
 * therefore <em>not</em> consistent with {@link #equals}, which compares all
 * state. {@code equals}/{@code hashCode} depend on mutable fields, so do not
 * mutate an instance while it is stored in a hash-based collection.
 *
 * @author anhth12
 */
public class VocabWord implements Serializable, Comparable<VocabWord> {

    /** Number of times the word has been seen; may be fractional. */
    private AtomicDouble wordFrequency = new AtomicDouble(0);

    /** Position of the word in the vocabulary; {@code -1} means "not assigned". */
    private int index = -1;

    private List<Integer> codes = new ArrayList<>();

    private String word;

    /** Running sum of squared gradients per index; created on first use in {@link #getGradient}. */
    private INDArray historicalGradient;

    private List<Integer> points = new ArrayList<>();

    private int codeLength = 0;

    /**
     * Creates an entry with an initial frequency.
     *
     * @param wordFrequency initial frequency of the word
     * @param word          the word text; must not be {@code null} or empty
     * @throws IllegalArgumentException if {@code word} is {@code null} or empty
     */
    public VocabWord(double wordFrequency, String word) {
        if (Strings.isNullOrEmpty(word)) {
            throw new IllegalArgumentException("Word must not be null or empty");
        }
        this.wordFrequency.set(wordFrequency);
        this.word = word;
    }

    /**
     * Creates a placeholder entry with frequency {@code 0} and the literal word {@code "none"}.
     *
     * @return a new placeholder instance (a fresh object on every call)
     */
    public static VocabWord none() {
        return new VocabWord(0, "none");
    }

    /**
     * Creates an empty entry: frequency {@code 0}, index {@code -1} and a {@code null} word.
     */
    public VocabWord() {
    }

    /**
     * Writes this entry to the stream. Only the word frequency (one {@code double})
     * is written; the word, index, codes and points are not.
     *
     * @param out destination stream
     * @throws IOException if the underlying stream fails
     */
    public void write(DataOutputStream out) throws IOException {
        out.writeDouble(this.wordFrequency.get());
    }

    /**
     * Reads one {@code double} from the stream and stores it as the word frequency.
     * This is the counterpart of {@link #write(DataOutputStream)}.
     *
     * @param dis source stream
     * @return this instance, for chaining
     * @throws IOException if the underlying stream fails
     */
    public VocabWord read(DataInputStream dis) throws IOException {
        this.wordFrequency.set(dis.readDouble());
        return this;
    }

    /**
     * @return the word text, or {@code null} if it was never set
     */
    public String getWord() {
        return this.word;
    }

    /**
     * @param word the new word text (not validated here)
     */
    public void setWord(String word) {
        this.word = word;
    }

    /**
     * Increases the word frequency by one.
     */
    public void increment() {
        increment(1);
    }

    /**
     * Increases the word frequency by the given amount.
     *
     * @param by amount to add (may be negative)
     */
    public void increment(int by) {
        wordFrequency.getAndAdd(by);
    }

    /**
     * @return the vocabulary index, or {@code -1} if not yet assigned
     */
    public int getIndex() {
        return index;
    }

    /**
     * @param index the vocabulary index to assign
     */
    public void setIndex(int index) {
        this.index = index;
    }

    /**
     * @return the current word frequency, or {@code 0.0} if the counter is absent
     */
    public double getWordFrequency() {
        if (wordFrequency == null) {
            return 0.0;
        }
        return wordFrequency.get();
    }

    /**
     * @return the live (mutable) list of codes; changes are reflected in this entry
     */
    public List<Integer> getCodes() {
        return codes;
    }

    /**
     * Orders entries by ascending word frequency only.
     *
     * @param o the entry to compare with; must not be {@code null}
     * @return negative, zero or positive as this frequency is less than,
     *         equal to or greater than the other's
     */
    @Override
    public int compareTo(VocabWord o) {
        return Double.compare(getWordFrequency(), o.getWordFrequency());
    }

    /**
     * Accumulates the squared gradient for {@code index} and returns a scaled,
     * normalised step size.
     *
     * <p>With {@code h} being the running sum of {@code g^2} for this index
     * (after adding the current {@code g}), the result is
     * {@code 0.1 * |g| / (sqrt(h) + 1e-6)}.
     *
     * <p>The history array is created on first use with one slot per element of
     * {@code codes}, so {@code index} must be smaller than {@code codes.size()}
     * at the time of the first call.
     *
     * @param index position in the gradient history to update
     * @param g     the current gradient value
     * @return the adjusted gradient magnitude for this step
     */
    public double getGradient(int index, int g) {
        if (historicalGradient == null) {
            historicalGradient = Nd4j.zeros(codes.size());
        }

        double squared = Math.pow(g, 2);
        double accumulated = historicalGradient.getDouble(index) + squared;
        historicalGradient = historicalGradient.putScalar(index, accumulated);

        double root = FastMath.sqrt(historicalGradient.getDouble(index));
        // The small constant keeps the division finite when the history is zero.
        // The float literals are kept as-is so numeric results match earlier runs.
        double normalised = FastMath.abs(g) / (root + 1e-6f);
        return normalised * 1e-1f;
    }

    /**
     * @return the live (mutable) list of points; changes are reflected in this entry
     */
    public List<Integer> getPoints() {
        return points;
    }

    /**
     * @param points the list to use as points from now on (stored by reference)
     */
    public void setPoints(List<Integer> points) {
        this.points = points;
    }

    /**
     * @return the code length last set via {@link #setCodeLength(int)}, initially {@code 0}
     */
    public int getCodeLength() {
        return codeLength;
    }

    /**
     * Sets the code length and makes sure {@code codes} and {@code points} each
     * hold at least that many elements, appending zeros as needed. Lists that
     * are already long enough are left untouched.
     *
     * @param codeLength the new code length
     */
    public void setCodeLength(int codeLength) {
        this.codeLength = codeLength;
        padWithZeros(codes, codeLength);
        padWithZeros(points, codeLength);
    }

    /** Appends zeros to {@code list} until it has at least {@code targetSize} elements. */
    private static void padWithZeros(List<Integer> list, int targetSize) {
        for (int i = list.size(); i < targetSize; i++) {
            list.add(0);
        }
    }

    /**
     * Two entries are equal when all of their state matches: code length, index,
     * frequency value, word, codes, points and gradient history.
     *
     * <p>The frequency is compared by value, because the underlying
     * {@code AtomicDouble} only has identity equality. {@code null} fields are
     * tolerated.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof VocabWord)) {
            return false;
        }

        VocabWord other = (VocabWord) o;
        return codeLength == other.codeLength
                && index == other.index
                && Double.compare(getWordFrequency(), other.getWordFrequency()) == 0
                && Objects.equals(word, other.word)
                && Objects.equals(codes, other.codes)
                && Objects.equals(points, other.points)
                && Objects.equals(historicalGradient, other.historicalGradient);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: the frequency is hashed
     * by value rather than by the identity of its holder object.
     */
    @Override
    public int hashCode() {
        int hash = 7;
        hash = 41 * hash + Double.valueOf(getWordFrequency()).hashCode();
        hash = 41 * hash + this.index;
        hash = 41 * hash + Objects.hashCode(this.codes);
        hash = 41 * hash + Objects.hashCode(this.word);
        hash = 41 * hash + Objects.hashCode(this.historicalGradient);
        hash = 41 * hash + Objects.hashCode(this.points);
        hash = 41 * hash + this.codeLength;
        return hash;
    }

    /**
     * @return a debug representation listing every field of this entry
     */
    @Override
    public String toString() {
        return "VocabWord{" + "wordFrequency=" + wordFrequency + ", index=" + index + ", codes=" + codes
                + ", word='" + word + '\'' + ", historicalGradient=" + historicalGradient + ", points=" + points
                + ", codeLength=" + codeLength + '}';
    }

}