w2v.WordToVec.java Source code

Java tutorial

Introduction

Here is the source code for w2v.WordToVec.java

Source

/** Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package w2v;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.PosixParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WordToVec {
    private static final Logger log = LoggerFactory.getLogger(WordToVec.class);

    class VocabWord implements Comparable<VocabWord> {
        VocabWord(String word) {
            this.word = word;
        }

        int cn = 0;
        int codelen;
        int[] point = new int[MAX_CODE_LENGTH];
        long[] code = new long[MAX_CODE_LENGTH];
        String word;

        @Override
        public int compareTo(VocabWord that) {
            if (that == null) {
                return 1;
            }

            return that.cn - this.cn;
        }

        @Override
        public String toString() {
            return this.cn + ": " + this.word;
        }
    }

    private static final int MAX_STRING = 100;
    private static final int EXP_TABLE_SIZE = 1000;
    private static final int MAX_EXP = 6;
    private static final int MAX_SENTENCE_LENGTH = 1000;
    private static final int MAX_CODE_LENGTH = 40;
    private static final int TABLE_SIZE = 100000000;

    // Maximum 30 * 0.7 = 21M words in the vocabulary
    private static final int VOCAB_HASH_SIZE = 30000000;

    private final int layerOneSize;
    private final File trainFile;
    private final File outputFile;
    private final File saveVocabFile;
    private final File readVocabFile;
    private final int window;
    private final int negative;
    private final int minCount;
    private final int numThreads;
    private final int classes;
    private final boolean binary;
    private final boolean cbow;
    private final boolean noHs;
    private final float startingAlpha;
    private final float sample;
    private final float[] expTable;

    private int minReduce = 1;
    private int vocabMaxSize = 1000;
    private VocabWord[] vocabWords = new VocabWord[vocabMaxSize];
    private int[] vocabHash = new int[VOCAB_HASH_SIZE];
    private Byte ungetc = null;

    private int vocabSize = 0;
    private long trainWords = 0;
    private long wordCountActual = 0;
    private int[] table;

    private float alpha;

    private float[] syn0;
    private float[] syn1;
    private float[] syn1neg;

    private long start;

    public WordToVec(Builder b) {
        this.trainFile = b.trainFile;
        this.outputFile = b.outputFile;
        this.saveVocabFile = b.saveVocabFile;
        this.readVocabFile = b.readVocabFile;
        this.binary = b.binary;
        this.cbow = b.cbow;
        this.noHs = b.noHs;
        this.startingAlpha = b.startingAlpha;
        this.sample = b.sample;
        this.window = b.window;
        this.negative = b.negative;
        this.minCount = b.minCount;
        this.numThreads = b.numThreads;
        this.classes = b.classes;
        this.layerOneSize = b.layerOneSize;

        float[] tempExpTable = new float[EXP_TABLE_SIZE];
        for (int i = 0; i < tempExpTable.length; i++) {
            // Precompute the exp() table
            tempExpTable[i] = (float) Math.exp((i / (float) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
            // Precompute f(x) = x / (x + 1)
            tempExpTable[i] = tempExpTable[i] / (tempExpTable[i] + 1);
        }
        expTable = tempExpTable;
    }

    private void readVocab() throws IOException {
        vocabSize = 0;
        try (DataInputStream is = new DataInputStream(new FileInputStream(readVocabFile))) {
            String word;
            while ((word = readWord(is)) != null) {
                int a = addWordToVocab(word);
                vocabWords[a].cn = is.readInt();
                is.readChar();
            }
            sortVocab();
            log.debug("Vocab size: {}", vocabSize);
            log.debug("Words in train file: {}", trainWords);
        } catch (IOException ioe) {
            throw ioe;
        }
    }

    private void learnVocabFromTrainFile() throws IOException {
        for (int a = 0; a < VOCAB_HASH_SIZE; a++) {
            vocabHash[a] = -1;
        }
        vocabSize = 0;
        addWordToVocab("</s>");
        try (DataInputStream is = new DataInputStream(new FileInputStream(trainFile))) {
            while (true) {
                String word = readWord(is);
                if (word == null) {
                    break;
                }
                trainWords++;
                if (log.isTraceEnabled() && trainWords % 100000 == 0) {
                    log.trace("{}K training words processed.", (trainWords / 1000));
                }
                int i = searchVocab(word);
                if (i == -1) {
                    i = addWordToVocab(word);
                    vocabWords[i].cn = 1;
                } else
                    vocabWords[i].cn++;
                if (vocabSize > VOCAB_HASH_SIZE * 0.7) {
                    reduceVocab();
                }
            }
        } catch (IOException ioe) {
            throw ioe;
        }
        sortVocab();
        log.debug("Vocab size: {}", vocabSize);
        log.debug("Words in train file: {}", trainWords);
    }

    private void saveVocab() throws IOException {
        saveVocabFile.delete();
        try (FileWriter fw = new FileWriter(saveVocabFile)) {
            //Don't output the </s>, at element zero.
            for (int i = 1; i < vocabSize; i++) {
                fw.write(vocabWords[i].word);
                fw.write(" ");
                fw.write("" + vocabWords[i].cn);
                fw.write("\n");
            }
        }
    }

    private void initNet() {
        syn0 = new float[vocabSize * layerOneSize];
        if (!noHs) {
            syn1 = new float[vocabSize * layerOneSize];
            for (int b = 0; b < layerOneSize; b++) {
                for (int a = 0; a < vocabSize; a++) {
                    syn1[a * layerOneSize + b] = 0;
                }
            }
        }
        if (negative > 0) {
            syn1neg = new float[vocabSize * layerOneSize];
            for (int b = 0; b < layerOneSize; b++) {
                for (int a = 0; a < vocabSize; a++) {
                    syn1neg[a * layerOneSize + b] = 0;
                }
            }
        }
        for (int b = 0; b < layerOneSize; b++) {
            for (int a = 0; a < vocabSize; a++) {
                syn0[a * layerOneSize + b] = (float) (Math.random() - 0.5) / layerOneSize;
            }
        }
        createBinaryTree();
    }

    // Create binary Huffman tree using the word counts
    // Frequent words will have short uniqe binary codes
    private void createBinaryTree() {
        //TODO: vocabSize.length cannot be longer than 1.2b .  Maybe use 2 arrays to allow this to be 2.4b?
        long[] count = new long[vocabSize * 2 + 1];
        long[] binary = new long[vocabSize * 2 + 1];
        int[] parentNode = new int[vocabSize * 2 + 1];
        for (int a = 0; a < vocabSize; a++) {
            count[a] = vocabWords[a].cn;
        }
        for (int a = vocabSize; a < vocabSize * 2; a++) {
            count[a] = 1_000_000_000_000_000L; //1e15
        }
        int pos1 = vocabSize - 1;
        int pos2 = vocabSize;
        int min1i;
        int min2i;

        // Following algorithm constructs the Huffman tree by adding one node at a time
        for (int a = 0; a < vocabSize - 1; a++) {
            // First, find two smallest nodes 'min1, min2'
            if (pos1 >= 0) {
                if (count[pos1] < count[pos2]) {
                    min1i = pos1;
                    pos1--;
                } else {
                    min1i = pos2;
                    pos2++;
                }
            } else {
                min1i = pos2;
                pos2++;
            }
            if (pos1 >= 0) {
                if (count[pos1] < count[pos2]) {
                    min2i = pos1;
                    pos1--;
                } else {
                    min2i = pos2;
                    pos2++;
                }
            } else {
                min2i = pos2;
                pos2++;
            }
            count[vocabSize + a] = count[min1i] + count[min2i];
            parentNode[min1i] = vocabSize + a;
            parentNode[min2i] = vocabSize + a;
            binary[min2i] = 1;
        }

        // Now assign binary code to each vocabulary word
        long[] code = new long[MAX_CODE_LENGTH];
        int[] point = new int[MAX_CODE_LENGTH];
        for (int a = 0; a < vocabSize; a++) {
            int b = a;
            int i = 0;
            while (true) {
                code[i] = binary[b];
                point[i] = b;
                i++;
                b = parentNode[b];
                if (b == vocabSize * 2 - 2)
                    break;
            }
            vocabWords[a].codelen = i;
            vocabWords[a].point[0] = vocabSize - 2;
            for (b = 0; b < i; b++) {
                vocabWords[a].code[i - b - 1] = code[b];
                vocabWords[a].point[i - b] = point[b] - vocabSize;
            }
        }
    }

    private void initUnigramTable() {
        long trainWordsPow = 0;
        float power = 0.75F;
        for (int a = 0; a < vocabSize; a++) {
            trainWordsPow += Math.pow(vocabWords[a].cn, power);
        }
        int i = 0;
        float d1 = (float) Math.pow(vocabWords[i].cn, power) / (float) trainWordsPow;
        for (int a = 0; a < TABLE_SIZE; a++) {
            table[a] = i;
            if (a / (float) TABLE_SIZE > d1) {
                i++;
                d1 += Math.pow(vocabWords[i].cn, power) / (float) trainWordsPow;
            }
            if (i >= vocabSize) {
                i = vocabSize - 1;
            }
        }
    }

    //DataOutputStream#writeFloat writes the high byte first
    //but let's write the low byte first to give ourselves a better chance of
    //compatibility with the original c++ code
    private void writeFloat(float f, DataOutputStream out) throws IOException {
        int v = Float.floatToIntBits(f);
        out.write((v >>> 0) & 0xFF);
        out.write((v >>> 8) & 0xFF);
        out.write((v >>> 16) & 0xFF);
        out.write((v >>> 24) & 0xFF);
    }

    public void trainModel() {
        if (trainFile == null && readVocabFile == null) {
            throw new IllegalStateException("You must supply either a trainFile or a readVocabFile.");
        }
        alpha = startingAlpha;
        if (readVocabFile != null) {
            try {
                log.info("Reading vocabulary from file {}.", readVocabFile);
                readVocab();
            } catch (IOException ioe) {
                log.error("There was a problem reading the vocabulary file.", ioe);
                return;
            }
        } else {
            log.info("Starting training using file {}.", trainFile);
            try {
                learnVocabFromTrainFile();
            } catch (IOException ioe) {
                log.error("There was a problem reading the training file.", ioe);
                return;
            }
        }
        if (saveVocabFile != null) {
            try {
                saveVocab();
            } catch (IOException ioe) {
                log.error("There was a problem writing the vocabulary file.", ioe);
                return;
            }
        }
        if (outputFile == null) {
            return;
        }
        initNet();
        if (negative > 0) {
            initUnigramTable();
        }
        start = System.nanoTime();
        //TODO: theads
        try {
            trainModelThread(0);
        } catch (IOException ioe) {
            log.error("There was a problem reading the training file.", ioe);
            return;
        }
        outputFile.delete();
        NumberFormat vectorTextFormat = new DecimalFormat("#.######");
        try (DataOutputStream os = new DataOutputStream(new FileOutputStream(outputFile))) {
            if (classes == 0) {
                // Save the word vectors
                os.writeBytes("" + vocabSize + " " + layerOneSize + "\n");
                for (int a = 0; a < vocabSize; a++) {
                    os.writeBytes(vocabWords[a].word);
                    os.writeBytes(" ");
                    if (binary) {
                        for (int b = 0; b < layerOneSize; b++) {
                            writeFloat(syn0[a * layerOneSize + b], os);
                        }
                    } else {
                        for (int b = 0; b < layerOneSize; b++) {
                            int index = a * layerOneSize + b;
                            float value = syn0[index];
                            os.writeBytes(vectorTextFormat.format(value) + " ");
                        }
                    }
                    os.writeBytes("\n");

                }
                os.writeBytes("\n");
            } else {
                // Run K-means on the word vectors
                if (classes * layerOneSize > Integer.MAX_VALUE) {
                    throw new RuntimeException(
                            "Number of classes times the size of Layer One cannot be greater than "
                                    + Integer.MAX_VALUE + " (" + classes + " * " + layerOneSize + ")");
                }
                int[] cl = new int[vocabSize];
                float[] cent = new float[classes * layerOneSize];
                int[] centcn = new int[classes];
                int numIterations = 10;

                for (int a = 0; a < vocabSize; a++) {
                    cl[a] = a % classes;
                }
                for (int a = 0; a < numIterations; a++) {
                    for (int b = 0; b < (classes * layerOneSize); b++) {
                        cent[b] = 0;
                    }
                    for (int b = 0; b < classes; b++) {
                        centcn[b] = 1;
                    }
                    for (int c = 0; c < vocabSize; c++) {
                        for (int d = 0; d < layerOneSize; d++) {
                            cent[layerOneSize * cl[c] + d] += syn0[c * layerOneSize + d];
                        }
                        centcn[cl[c]]++;
                    }
                    for (int b = 0; b < classes; b++) {
                        float closev = 0;
                        for (int c = 0; c < layerOneSize; c++) {
                            cent[layerOneSize * b + c] /= centcn[b];
                            closev += cent[layerOneSize * b + c] * cent[layerOneSize * b + c]; //TODO: ^2 ??
                        }
                        closev = (float) Math.sqrt(closev);
                        for (int c = 0; c < layerOneSize; c++) {
                            cent[layerOneSize * b + c] /= closev;
                        }
                    }
                    for (int c = 0; c < vocabSize; c++) {
                        float closev = -10;
                        int closeid = 0;
                        for (int d = 0; d < classes; d++) {
                            float x = 0;
                            for (int b = 0; b < layerOneSize; b++) {
                                x += cent[layerOneSize * d * b] * syn0[c * layerOneSize + b];
                            }
                            if (x > closev) {
                                closev = x;
                                closeid = d;
                            }
                        }
                        cl[c] = closeid;
                    }
                }
                // Save the K-means classes
                for (int a = 0; a < vocabSize; a++) {
                    os.writeBytes(vocabWords[a].word);
                    os.writeBytes(" ");
                    os.writeInt(cl[a]);
                }
            }
        } catch (IOException ioe) {
            log.error("There was a problem writing the output file", ioe);
            return;
        }
    }

    private void trainModelThread(int id) throws IOException {
        try (RandomAccessFile raf = new RandomAccessFile(trainFile, "rw")) {
            if (id > 0) {
                raf.seek(raf.length() / (numThreads * id));
            }
            long wordCount = 0;
            long lastWordCount = 0;
            int word = 0;
            int target = 0;
            int label = 0;
            int sentenceLength = 0;
            int sentencePosition = 0;
            int nextRandom = id;
            int[] sen = new int[MAX_SENTENCE_LENGTH + 1];
            float[] neu1 = new float[layerOneSize];
            float[] neu1e = new float[layerOneSize];

            NumberFormat alphaFormat = new DecimalFormat("0.000000");
            NumberFormat logPercentFormat = new DecimalFormat("#0.00%");
            NumberFormat wordsPerSecondFormat = new DecimalFormat("00.00k");
            long now = System.nanoTime();
            while (true) {
                if (wordCount - lastWordCount > 10000) {
                    wordCountActual += wordCount - lastWordCount;
                    lastWordCount = wordCount;
                    if (log.isTraceEnabled()) {
                        now = System.nanoTime();
                        log.trace("Alpha: {}", alphaFormat.format(alpha));
                        log.trace("Progress: {} ",
                                logPercentFormat.format((float) wordCountActual / (trainWords + 1)));
                        log.trace("Words/thread/sec: {}\n", wordsPerSecondFormat
                                .format((float) wordCountActual / (float) (now - start + 1) * 1000000));
                    }
                    alpha = startingAlpha * (1 - wordCountActual / (float) (trainWords + 1));
                    if (alpha < startingAlpha * 0.0001F) {
                        alpha = startingAlpha * 0.0001F;
                    }
                }
                if (sentenceLength == 0) {
                    while (true) {
                        word = readWordIndex(raf);
                        if (word == -1) {
                            break;
                        }
                        wordCount++;
                        if (word == 0) {
                            break;
                        }
                        // The subsampling randomly discards frequent words while keeping the ranking same
                        if (sample > 0) {
                            float ran = (float) (Math.sqrt(vocabWords[word].cn / (sample * trainWords)) + 1)
                                    * (sample * trainWords) / vocabWords[word].cn;
                            nextRandom = (int) (nextRandom * 25214903917L + 11);
                            if (ran < ((nextRandom & 0xFFFF) / (float) 65536)) {
                                continue;
                            }
                        }
                        sen[sentenceLength] = word;
                        sentenceLength++;
                        if (sentenceLength >= MAX_SENTENCE_LENGTH) {
                            break;
                        }
                    }
                    sentencePosition = 0;
                }
                if (raf.getFilePointer() == raf.length()) {
                    break;
                }
                if (wordCount > trainWords / numThreads) {
                    break;
                }
                word = sen[sentencePosition];
                for (int c = 0; c < layerOneSize; c++) {
                    neu1[c] = 0;
                    neu1e[c] = 0;
                }
                nextRandom = (int) (nextRandom * 25214903917L + 11);
                int b = nextRandom % window;
                if (cbow) {
                    // in -> hidden
                    for (int a = b; b < window * 2 + 1 - b; a++) {
                        if (a != window) {
                            int c = sentencePosition - window + a;
                            if (c < 0) {
                                continue;
                            }
                            if (c >= sentenceLength) {
                                continue;
                            }
                            int lastWord = sen[c];
                            for (c = 0; c < layerOneSize; c++) {
                                neu1[c] += syn0[c + lastWord * layerOneSize];
                            }
                        }
                    }
                    if (!noHs) {
                        for (int d = 0; d < vocabWords[word].codelen; d++) {
                            float f = 0;
                            int l2 = vocabWords[word].point[d] * layerOneSize;
                            // Propagate hidden -> output
                            for (int c = 0; c < layerOneSize; c++) {
                                f += neu1[c] * syn1[c + l2];
                            }
                            if (f <= -1 * MAX_EXP || f >= MAX_EXP) {
                                continue;
                            }
                            f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
                            // 'g' is the gradient multiplied by the learning rate
                            float g = (1 - vocabWords[word].code[d] - f) * alpha;
                            // Propagate errors output -> hidden
                            for (int c = 0; c < layerOneSize; c++) {
                                neu1e[c] += g * syn1[c + l2];
                            }
                            // Learn weights hidden -> output 
                            for (int c = 0; c < layerOneSize; c++) {
                                syn1[c + l2] += g * neu1[c];
                            }
                        }
                    }
                    // NEGATIVE SAMPLING
                    if (negative > 0) {
                        for (int d = 0; d < negative + 1; d++) {
                            if (d == 0) {
                                target = word;
                                label = 1;
                            } else {
                                nextRandom = (int) (nextRandom * 25214903917L + 11);
                                target = table[(nextRandom >> 16) % TABLE_SIZE];
                                if (target == 0) {
                                    target = nextRandom % (vocabSize - 1) + 1;
                                }
                                if (target == word) {
                                    continue;
                                }
                                label = 0;
                            }
                            int l2 = target * layerOneSize;
                            int f = 0;
                            for (int c = 0; c < layerOneSize; c++) {
                                f += neu1[c] * syn1neg[c + l2];
                            }
                            float g;
                            if (f > MAX_EXP) {
                                g = (label - 1) * alpha;
                            } else if (f < -MAX_EXP) {
                                g = (label - 0) * alpha;
                            } else {
                                g = (label - expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))])
                                        * alpha;
                            }
                            for (int c = 0; c < layerOneSize; c++) {
                                neu1e[c] += g * syn1neg[c + l2];
                            }
                            for (int c = 0; c < layerOneSize; c++) {
                                syn1neg[c + l2] += g * neu1[c];
                            }
                        }
                    }
                    // hidden -> in
                    for (int a = b; a < window * 2 + 1 - b; a++) {
                        if (a != window) {
                            int c = sentencePosition - window + a;
                            if (c < 0 || c >= sentenceLength) {
                                continue;
                            }
                            int lastWord = sen[c];
                            for (c = 0; c < layerOneSize; c++) {
                                syn0[c + lastWord * layerOneSize] += neu1e[c];
                            }
                        }
                    }
                } else { //train skip-gram
                    for (int a = b; a < window * 2 + 1 - b; a++) {
                        if (a != window) {
                            int lastWordIndex = sentencePosition - window + a;
                            if (lastWordIndex < 0 || lastWordIndex >= sentenceLength) {
                                continue;
                            }
                            int lastWord = sen[lastWordIndex];
                            int l1 = lastWord * layerOneSize;
                            for (int c = 0; c < layerOneSize; c++) {
                                neu1e[c] = 0;
                            }
                            // HIERARCHICAL SOFTMAX
                            if (!noHs) {
                                for (int d = 0; d < vocabWords[word].codelen; d++) {
                                    float f = 0;
                                    int l2 = vocabWords[word].point[d] * layerOneSize;
                                    // Propagate hidden -> output
                                    for (int c = 0; c < layerOneSize; c++) {
                                        f += syn0[c + l1] * syn1[c + l2];
                                    }
                                    if (f <= -MAX_EXP || f >= MAX_EXP) {
                                        continue;
                                    }
                                    f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
                                    // 'g' is the gradient multiplied by the learning rate
                                    float g = (1 - vocabWords[word].code[d] - f) * alpha;
                                    // Propagate errors output -> hidden
                                    for (int c = 0; c < layerOneSize; c++) {
                                        neu1e[c] += g * syn1[c + l2];
                                    }
                                    // Learn weights hidden -> output
                                    for (int c = 0; c < layerOneSize; c++) {
                                        syn1[c + l2] += g * syn0[c + l1];
                                    }
                                }
                            }
                            // NEGATIVE SAMPLING
                            if (negative > 0) {
                                for (int d = 0; d < negative + 1; d++) {
                                    if (d == 0) {
                                        target = word;
                                        label = 1;
                                    } else {
                                        nextRandom = (int) (nextRandom * 25214903917L + 11);
                                        target = table[(nextRandom >> 16) % TABLE_SIZE];
                                        if (target == 0) {
                                            target = nextRandom % (vocabSize - 1) + 1;
                                        }
                                        if (target == word) {
                                            continue;
                                        }
                                        label = 0;
                                    }
                                    int l2 = target * layerOneSize;
                                    int f = 0;
                                    for (int c = 0; c < layerOneSize; c++) {
                                        f += syn0[c + l1] * syn1neg[c + l2];
                                    }
                                    float g;
                                    if (f > MAX_EXP) {
                                        g = (label - 1) * alpha;
                                    } else if (f < -MAX_EXP) {
                                        g = (label - 0) * alpha;
                                    } else {
                                        g = (label
                                                - expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))])
                                                * alpha;
                                    }
                                    for (int c = 0; c < layerOneSize; c++) {
                                        neu1e[c] += g * syn1neg[c + l2];
                                    }
                                    for (int c = 0; c < layerOneSize; c++) {
                                        syn1neg[c + l2] += g * syn0[c + l1];
                                    }
                                }
                            }
                            // Learn weights input -> hidden
                            for (int c = 0; c < layerOneSize; c++) {
                                syn0[c + l1] += neu1e[c];
                            }
                        }
                    }
                }
                sentencePosition++;
                if (sentencePosition >= sentenceLength) {
                    sentenceLength = 0;
                    continue;
                }
            }
        } catch (IOException ioe) {
            throw ioe;
        }
    }

    // Reduces the vocabulary by removing infrequent tokens
    private void reduceVocab() {
        int b = 0;
        for (int a = 0; a < vocabSize; a++) {
            if (vocabWords[a].cn > minReduce) {
                vocabWords[b].cn = vocabWords[a].cn;
                vocabWords[b].word = vocabWords[a].word;
                b++;
            }
        }
        vocabSize = b;
        for (int a = 0; a < VOCAB_HASH_SIZE; a++) {
            vocabHash[a] = -1;
        }
        for (int a = 0; a < vocabSize; a++) {
            // Hash will be re-computed, as it is not actual
            int hash = getWordHash(vocabWords[a].word);
            while (vocabHash[hash] != -1) {
                hash = (hash + 1) % VOCAB_HASH_SIZE;
                hash = Math.abs(hash);
            }
            vocabHash[hash] = a;
        }
        minReduce++;
    }

    // Sorts the vocabulary by frequency using word counts
    private void sortVocab() {
        // Sort the vocabulary and keep </s> at the first position
        Arrays.sort(vocabWords, 1, vocabSize - 1);
        for (int a = 0; a < vocabHash.length; a++) {
            vocabHash[a] = -1;
        }

        trainWords = 0;
        int originalVocabSize = vocabSize;
        List<VocabWord> wordList = new ArrayList<VocabWord>(originalVocabSize);
        int aa = 0;
        for (int a = 0; a < originalVocabSize; a++) {
            VocabWord vw = vocabWords[a];
            // Words occurring less than min_count times will be discarded from the vocab
            if (vw.cn < minCount && vw.cn > 0) {
                vocabSize--;
            } else {
                // Hash will be re-computed, as after the sorting it is not actual
                int hash = getWordHash(vw.word);
                while (vocabHash[hash] != -1) {
                    hash = (hash + 1) % VOCAB_HASH_SIZE;
                    hash = Math.abs(hash);
                }
                vocabHash[hash] = aa;
                trainWords += vw.cn;
                wordList.add(vw);
                aa++;
            }
        }
        vocabWords = wordList.toArray(new VocabWord[wordList.size()]);
    }

    private int addWordToVocab(String word) {
        int length = word.length() + 1;
        if (length > MAX_STRING) {
            length = MAX_STRING;
        }
        vocabWords[vocabSize] = new VocabWord(word);
        vocabSize++;

        // Reallocate memory if needed
        if (vocabSize + 2 >= vocabMaxSize) {
            vocabMaxSize += 1000;
            VocabWord[] vocabWords1 = new VocabWord[vocabMaxSize];
            System.arraycopy(vocabWords, 0, vocabWords1, 0, vocabWords.length);
            vocabWords = vocabWords1;
        }
        int hash = getWordHash(word);
        while (vocabHash[hash] != -1) {
            hash = (hash + 1) % VOCAB_HASH_SIZE;
            hash = Math.abs(hash);
        }
        vocabHash[hash] = vocabSize - 1;
        return vocabSize - 1;
    }

    private int getWordHash(String word) {
        int hash = 0;
        for (int a = 0; a < word.length(); a++) {
            hash = hash * 257 + word.charAt(a);
        }
        hash = hash % VOCAB_HASH_SIZE;
        return Math.abs(hash);
    }

    // Returns position of a word in the vocabulary; if the word is not found, returns -1
    private int searchVocab(String word) {
        int hash = getWordHash(word);
        while (true) {
            if (vocabHash[hash] == -1) {
                return -1;
            }
            if (word.equals(vocabWords[vocabHash[hash]].word)) {
                return vocabHash[hash];
            }
            hash = (hash + 1) % VOCAB_HASH_SIZE;
            hash = Math.abs(hash);
        }
    }

    private int readWordIndex(RandomAccessFile raf) throws IOException {
        String word = readWord(raf);
        if (raf.length() == raf.getFilePointer()) {
            return -1;
        }
        return searchVocab(word);
    }

    private String readWord(DataInput dataInput) throws IOException {
        StringBuilder sb = new StringBuilder();
        while (true) {
            byte ch;
            if (ungetc != null) {
                ch = ungetc;
                ungetc = null;
            } else {
                try {
                    ch = dataInput.readByte();
                } catch (EOFException eofe) {
                    break;
                }
            }
            if (ch == '\r') {
                continue;
            }
            if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
                if (sb.length() > 0) {
                    if (ch == '\n') {
                        ungetc = ch;
                    }
                    break;
                }
                if (ch == '\n') {
                    return "</s>";
                } else {
                    continue;
                }
            }
            sb.append((char) ch);

            // Truncate too long words
            if (sb.length() >= MAX_STRING - 1) {
                sb.deleteCharAt(sb.length() - 1);
            }
        }
        String word = sb.length() == 0 ? null : sb.toString();
        return word;
    }

    public static class Builder {
        private File trainFile = null;
        private File outputFile = null;
        private File saveVocabFile = null;
        private File readVocabFile = null;
        private boolean binary = false;
        private boolean cbow = false;
        private boolean noHs = false;
        private float startingAlpha = 0.025F;
        private float sample = 0.0F;
        private int window = 5;
        private int negative = 0;
        private int minCount = 5;
        private int numThreads = 1;
        private int classes = 0;
        private int layerOneSize = 100;

        public Builder trainFile(String trainFile) {
            this.trainFile = new File(trainFile);
            return this;
        }

        public Builder outputFile(String outputFile) {
            this.outputFile = new File(outputFile);
            return this;
        }

        public Builder saveVocabFile(String saveVocabFile) {
            this.saveVocabFile = new File(saveVocabFile);
            return this;
        }

        public Builder readVocabFile(String readVocabFile) {
            this.readVocabFile = new File(readVocabFile);
            return this;
        }

        public Builder binary() {
            this.binary = true;
            return this;
        }

        public Builder cbow() {
            this.cbow = true;
            return this;
        }

        public Builder noHs() {
            this.noHs = true;
            return this;
        }

        public Builder startingAlpha(float startingAlpha) {
            this.startingAlpha = startingAlpha;
            return this;
        }

        public Builder sample(float sample) {
            this.sample = sample;
            return this;
        }

        public Builder window(int window) {
            this.window = window;
            return this;
        }

        public Builder negative(int negative) {
            this.negative = negative;
            return this;
        }

        public Builder minCount(int minCount) {
            this.minCount = minCount;
            return this;
        }

        public Builder numThreads(int numThreads) {
            this.numThreads = numThreads;
            return this;
        }

        public Builder classes(int classes) {
            this.classes = classes;
            return this;
        }

        public Builder layerOneSize(int layerOneSize) {
            this.layerOneSize = layerOneSize;
            return this;
        }
    }

    @SuppressWarnings("static-access")
    public static void main(String[] args) {
        Builder b = new Builder();
        Options options = new Options();
        options.addOption(OptionBuilder.hasArg().withArgName("file")
                .withDescription("Use text data from <file> to train the model").create("train"));
        options.addOption(OptionBuilder.hasArg().withArgName("file")
                .withDescription("Use <file> to save the resulting word vectors / word clusters").create("output"));
        options.addOption(OptionBuilder.hasArg().withArgName("int")
                .withDescription("Set size of word vectors; default is " + b.layerOneSize).create("size"));
        options.addOption(OptionBuilder.hasArg().withArgName("int")
                .withDescription("Set max skip length between words; default is " + b.window).create("window"));
        options.addOption(OptionBuilder.hasArg().withArgName("int").withDescription(
                "Set threshold for occurrence of words (0=off). Those that appear with higher frequency in the training data will be randomly down-sampled; default is "
                        + b.sample + ", useful value is 1e-5")
                .create("sample"));
        options.addOption(new Option("noHs", false,
                "Disable use of Hierarchical Softmax; " + (b.noHs ? "off" : "on") + " by default"));
        options.addOption(
                OptionBuilder.hasArg().withArgName("int").withDescription("Number of negative examples; default is "
                        + b.negative + ", common values are 5 - 10 (0 = not used)").create("negative"));
        options.addOption(OptionBuilder.hasArg().withArgName("int")
                .withDescription("Use <int> threads (default " + b.numThreads + ")").create("threads"));
        options.addOption(OptionBuilder.hasArg().withArgName("int")
                .withDescription(
                        "This will discard words that appear less than <int> times; default is " + b.minCount)
                .create("minCount"));
        options.addOption(OptionBuilder.hasArg().withArgName("float")
                .withDescription("Set the starting learning rate; default is " + b.startingAlpha)
                .create("startingAlpha"));
        options.addOption(OptionBuilder.hasArg().withArgName("int")
                .withDescription(
                        "Number of word classes to output, or 0 to output word vectors; default is " + b.classes)
                .create("classes"));
        options.addOption(new Option("binary", false,
                "Save the resulting vectors in binary moded; " + (b.binary ? "on" : "off") + " by default"));
        options.addOption(OptionBuilder.hasArg().withArgName("file")
                .withDescription("The vocabulary will be saved to <file>").create("saveVocab"));
        options.addOption(OptionBuilder.hasArg().withArgName("file")
                .withDescription("The vocabulary will be read from <file>, not constructed from the training data")
                .create("readVocab"));
        options.addOption(new Option("cbow", false, "Use the continuous bag of words model; "
                + (b.cbow ? "on" : "off") + " by default (skip-gram model)"));

        CommandLineParser parser = new PosixParser();
        try {
            CommandLine cl = parser.parse(options, args);
            if (cl.getOptions().length == 0) {
                new HelpFormatter().printHelp(WordToVec.class.getSimpleName(), options);
                System.exit(0);
            }
            if (cl.hasOption("size")) {
                b.layerOneSize = Integer.parseInt(cl.getOptionValue("size"));
            }
            if (cl.hasOption("train")) {
                b.trainFile = new File(cl.getOptionValue("train"));
            }
            if (cl.hasOption("saveVocab")) {
                b.saveVocabFile = new File(cl.getOptionValue("saveVocab"));
            }
            if (cl.hasOption("readVocab")) {
                b.readVocabFile = new File(cl.getOptionValue("readVocab"));
            }
            if (cl.hasOption("binary")) {
                b.binary = true;
            }
            if (cl.hasOption("cbow")) {
                b.cbow = true;
            }
            if (cl.hasOption("startingAlpha")) {
                b.startingAlpha = Float.parseFloat(cl.getOptionValue("startingAlpha"));
            }
            if (cl.hasOption("output")) {
                b.outputFile = new File(cl.getOptionValue("output"));
            }
            if (cl.hasOption("window")) {
                b.window = Integer.parseInt(cl.getOptionValue("window"));
            }
            if (cl.hasOption("sample")) {
                b.sample = Float.parseFloat(cl.getOptionValue("sample"));
            }
            if (cl.hasOption("noHs")) {
                b.noHs = true;
            }
            if (cl.hasOption("negative")) {
                b.negative = Integer.parseInt(cl.getOptionValue("negative"));
            }
            if (cl.hasOption("threads")) {
                b.numThreads = Integer.parseInt(cl.getOptionValue("threads"));
            }
            if (cl.hasOption("minCount")) {
                b.minCount = Integer.parseInt(cl.getOptionValue("minCount"));
            }
            if (cl.hasOption("classes")) {
                b.classes = Integer.parseInt(cl.getOptionValue("classes"));
            }
        } catch (Exception e) {
            System.err.println("Parsing command-line arguments failed. Reason: " + e.getMessage());
            new HelpFormatter().printHelp("word2vec", options);
            System.exit(1);
        }
        WordToVec word2vec = new WordToVec(b);
        word2vec.trainModel();
        System.exit(0);
    }
}