// Java tutorial
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package com.anhth12.word2vec;

import com.google.common.base.Strings;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * One entry of a word2vec vocabulary.
 *
 * <p>Holds the word, its accumulated frequency, its vocabulary index (-1 until
 * assigned), the {@code codes}/{@code points} lists together with their logical
 * length, and a lazily allocated per-position history of squared gradients used
 * by {@link #getGradient(int, int)}.
 *
 * <p>NOTE(review): {@code codes}/{@code points} presumably hold the Huffman code
 * bits and inner-node indices used by hierarchical softmax (the usual word2vec
 * convention); confirm against the code that populates them.
 *
 * <p>Only the frequency counter is an {@link AtomicDouble}; the other fields are
 * plain mutable state with no synchronization.
 *
 * @author anhth12
 */
public class VocabWord implements Serializable, Comparable<VocabWord> {

    /** Multiplier applied to the normalized gradient in {@link #getGradient(int, int)}. */
    private static final double GRADIENT_SCALE = 0.1;

    /** Added to the denominator in {@link #getGradient(int, int)} to avoid division by zero. */
    private static final double GRADIENT_EPSILON = 1e-6;

    private AtomicDouble wordFrequency = new AtomicDouble(0);
    private int index = -1;
    private List<Integer> codes = new ArrayList<>();
    private String word;
    private INDArray historicalGradient;
    private List<Integer> points = new ArrayList<>();
    private int codeLength = 0;

    /**
     * Creates a vocabulary entry.
     *
     * @param wordFrequency initial frequency
     * @param word the word; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code word} is null or empty
     */
    public VocabWord(double wordFrequency, String word) {
        if (Strings.isNullOrEmpty(word)) {
            throw new IllegalArgumentException("Word must not be null or empty");
        }
        this.wordFrequency.set(wordFrequency);
        this.word = word;
    }

    /**
     * Returns a fresh placeholder entry with frequency 0 and the word {@code "none"}.
     *
     * @return a new placeholder entry
     */
    public static VocabWord none() {
        return new VocabWord(0, "none");
    }

    /**
     * Creates an empty entry. The word is left {@code null} and the frequency is 0;
     * callers are expected to fill it in (e.g. via {@link #setWord} or {@link #read}).
     */
    public VocabWord() {
    }

    /**
     * Writes this entry to {@code out}. Only the word frequency is written; the word,
     * index, codes, points and gradient history are not.
     *
     * @param out destination stream
     * @throws IOException if writing fails
     */
    public void write(DataOutputStream out) throws IOException {
        out.writeDouble(this.wordFrequency.get());
    }

    /**
     * Reads the word frequency written by {@link #write} into this instance.
     *
     * @param dis source stream
     * @return this instance, for chaining
     * @throws IOException if reading fails
     */
    public VocabWord read(DataInputStream dis) throws IOException {
        this.wordFrequency.set(dis.readDouble());
        return this;
    }

    /** @return the word (may be {@code null} for instances built with the no-arg constructor) */
    public String getWord() {
        return this.word;
    }

    /** @param word the new word; not validated */
    public void setWord(String word) {
        this.word = word;
    }

    /** Adds 1 to the word frequency. */
    public void increment() {
        increment(1);
    }

    /**
     * Atomically adds {@code by} to the word frequency.
     *
     * @param by amount to add
     */
    public void increment(int by) {
        wordFrequency.getAndAdd(by);
    }

    /** @return the vocabulary index, or -1 if not yet assigned */
    public int getIndex() {
        return index;
    }

    /** @param index the vocabulary index */
    public void setIndex(int index) {
        this.index = index;
    }

    /** @return the current word frequency, or 0.0 if the counter is absent */
    public double getWordFrequency() {
        if (wordFrequency == null) {
            return 0.0;
        }
        return wordFrequency.get();
    }

    /** @return the live (mutable) codes list */
    public List<Integer> getCodes() {
        return codes;
    }

    /**
     * Orders entries by ascending word frequency.
     *
     * <p>Note: this ordering is <em>not</em> consistent with {@link #equals}; two
     * different words with the same frequency compare as 0.
     */
    @Override
    public int compareTo(VocabWord o) {
        return Double.compare(getWordFrequency(), o.getWordFrequency());
    }

    /**
     * Records gradient {@code g} at position {@code index} and returns an
     * AdaGrad-style scaled step size.
     *
     * <p>The squared gradient is added to a running per-position sum (allocated on
     * first use with {@code codes.size()} entries), and the result is
     * {@code 0.1 * |g| / (sqrt(sum) + 1e-6)}. The step therefore shrinks as squared
     * gradients accumulate at that position.
     *
     * @param index position in the history; assumed to be in {@code [0, codes.size())}
     *     as of the first call — TODO confirm callers never exceed that
     * @param g the gradient value to record
     * @return the scaled step size (always non-negative)
     */
    public double getGradient(int index, int g) {
        if (historicalGradient == null) {
            historicalGradient = Nd4j.zeros(codes.size());
        }
        double squared = Math.pow(g, 2);
        historicalGradient = historicalGradient.putScalar(index,
                historicalGradient.getDouble(index) + squared);
        // Read the sum back from the array so any storage-precision rounding is honored.
        double rootSum = FastMath.sqrt(historicalGradient.getDouble(index));
        double normalized = FastMath.abs(g) / (rootSum + GRADIENT_EPSILON);
        return normalized * GRADIENT_SCALE;
    }

    /** @return the live (mutable) points list */
    public List<Integer> getPoints() {
        return points;
    }

    /** @param points the new points list (stored by reference, not copied) */
    public void setPoints(List<Integer> points) {
        this.points = points;
    }

    /** @return the logical code length */
    public int getCodeLength() {
        return codeLength;
    }

    /**
     * Sets the logical code length and pads {@code codes} and {@code points} with
     * zeros until each holds at least {@code codeLength} elements. Lists that are
     * already long enough are left untouched (never truncated).
     *
     * @param codeLength the new code length
     */
    public void setCodeLength(int codeLength) {
        this.codeLength = codeLength;
        // Pad only up to codeLength; appending codeLength zeros to a non-empty
        // list would overshoot.
        while (codes.size() < codeLength) {
            codes.add(0);
        }
        while (points.size() < codeLength) {
            points.add(0);
        }
    }

    /**
     * Two entries are equal when every field matches. The frequency is compared by
     * numeric value; {@link AtomicDouble} does not override {@code equals}, so
     * comparing the counters themselves would only ever match the same instance.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof VocabWord)) {
            return false;
        }
        VocabWord other = (VocabWord) o;
        return codeLength == other.codeLength
                && index == other.index
                && Double.compare(getWordFrequency(), other.getWordFrequency()) == 0
                && Objects.equals(word, other.word)
                && Objects.equals(codes, other.codes)
                && Objects.equals(points, other.points)
                && Objects.equals(historicalGradient, other.historicalGradient);
    }

    /**
     * Hash consistent with {@link #equals}. The frequency contributes by value, not by
     * counter identity. The hash depends on mutable fields, so do not mutate an
     * instance while it is a key in a hash-based collection.
     */
    @Override
    public int hashCode() {
        int hash = 7;
        hash = 41 * hash + Double.valueOf(getWordFrequency()).hashCode();
        hash = 41 * hash + this.index;
        hash = 41 * hash + Objects.hashCode(this.codes);
        hash = 41 * hash + Objects.hashCode(this.word);
        hash = 41 * hash + Objects.hashCode(this.historicalGradient);
        hash = 41 * hash + Objects.hashCode(this.points);
        hash = 41 * hash + this.codeLength;
        return hash;
    }

    @Override
    public String toString() {
        return "VocabWord{"
                + "wordFrequency=" + wordFrequency
                + ", index=" + index
                + ", codes=" + codes
                + ", word='" + word + '\''
                + ", historicalGradient=" + historicalGradient
                + ", points=" + points
                + ", codeLength=" + codeLength
                + '}';
    }
}