Java tutorial
/* * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * */ package org.deeplearning4j.models.word2vec; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import com.google.common.util.concurrent.AtomicDouble; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.List; /** * Intermediate layers of the neural network * * @author Adam Gibson */ public class VocabWord implements Comparable<VocabWord>, Serializable { private static final long serialVersionUID = 2223750736522624256L; //used in comparison when building the huffman tree private AtomicDouble wordFrequency = new AtomicDouble(0); private int index = -1; private List<Integer> codes = new ArrayList<>(); //for my sanity private String word; private INDArray historicalGradient; private List<Integer> points = new ArrayList<>(); private int codeLength = 0; public static VocabWord none() { return new VocabWord(0, "none"); } /** * * @param wordFrequency count of the word */ public VocabWord(double wordFrequency, String word) { this.wordFrequency.set(wordFrequency); if (word == null || word.isEmpty()) throw new IllegalArgumentException("Word must not be null or empty"); this.word = word; } public VocabWord() { } public void write(DataOutputStream dos) throws IOException { dos.writeDouble(wordFrequency.get()); } public VocabWord read(DataInputStream dos) throws IOException { this.wordFrequency.set(dos.readDouble()); return this; } public String getWord() { return word; } public void setWord(String word) { this.word = word; } public void increment() { increment(1); } public void increment(int by) { wordFrequency.getAndAdd(by); } public int getIndex() { return index; } public void setIndex(int index) { this.index = index; } public double getWordFrequency() { if (wordFrequency == null) return 0.0; return wordFrequency.get(); } public List<Integer> getCodes() { return codes; } public void setCodes(List<Integer> codes) { this.codes = codes; } @Override public int compareTo(VocabWord o) { return Double.compare(wordFrequency.get(), o.wordFrequency.get()); } public double getGradient(int index, double g) { if (historicalGradient == null) { historicalGradient = Nd4j.zeros(getCodes().size()); } double pow = Math.pow(g, 2); historicalGradient.putScalar(index, historicalGradient.getDouble(index) + pow); double sqrt = FastMath.sqrt(historicalGradient.getDouble(index)); double abs = FastMath.abs(g) / (sqrt + 1e-6f); double ret = abs * 1e-1f; return ret; } public List<Integer> getPoints() { return points; } public void setPoints(List<Integer> points) { this.points = points; } public int getCodeLength() { return codeLength; } public void setCodeLength(int codeLength) { this.codeLength = codeLength; if (codes.size() < codeLength) { for (int i = 0; i < codeLength; i++) codes.add(0); } if (points.size() < codeLength) { for (int i = 0; i < codeLength; i++) points.add(0); } } @Override public boolean equals(Object o) { if (this == o) return true; VocabWord vocabWord = (VocabWord) o; if (codeLength != vocabWord.codeLength) return false; if (index != vocabWord.index) return false; if (!codes.equals(vocabWord.codes)) return false; if (historicalGradient != null ? !historicalGradient.equals(vocabWord.historicalGradient) : vocabWord.historicalGradient != null) return false; if (!points.equals(vocabWord.points)) return false; if (!word.equals(vocabWord.word)) return false; return wordFrequency.get() == vocabWord.wordFrequency.get(); } @Override public int hashCode() { int result = wordFrequency.hashCode(); result = 31 * result + index; result = 31 * result + codes.hashCode(); result = 31 * result + word.hashCode(); result = 31 * result + (historicalGradient != null ? historicalGradient.hashCode() : 0); result = 31 * result + points.hashCode(); result = 31 * result + codeLength; return result; } @Override public String toString() { return "VocabWord{" + "wordFrequency=" + wordFrequency + ", index=" + index + ", codes=" + codes + ", word='" + word + '\'' + ", historicalGradient=" + historicalGradient + ", points=" + points + ", codeLength=" + codeLength + '}'; } }