com.cyc.tool.distributedrepresentations.Word2VecSpaceFromFile.java Source code

Java tutorial

Introduction

Here is the source code for com.cyc.tool.distributedrepresentations.Word2VecSpaceFromFile.java

Source

package com.cyc.tool.distributedrepresentations;

/*
 * #%L
 * DistributedRepresentations
 * %%
 * Copyright (C) 2015 Cycorp, Inc
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.zip.GZIPInputStream;
import org.apache.commons.io.EndianUtils;
import org.mapdb.DBMaker;

/**
 * Word2Vec distributed representation space from Google Format file.
 *
 * <p>
 * This class represents any distributed representation computed using word2vec and initially
 * loaded from a Google word2vec formatted file. The loader in this class reads a gzip-compressed
 * stream made of a text header line ({@code "<wordCount> <vectorSize>\n"}) followed by one record
 * per word: the word's bytes, a single whitespace delimiter, then {@code vectorSize} floats read
 * with {@code EndianUtils.readSwappedFloat}.
 */
public abstract class Word2VecSpaceFromFile extends Word2VecSpace {

    /**
     * Scratch buffer shared by {@link #getWordsAndSize} and {@link #getVocabString}. It is reset
     * at the start of each call and is not synchronized.
     */
    final StringBuilder sb = new StringBuilder();

    /**
     * Constructor for Word2VecSpaceFromFile.
     *
     * <p>
     * Builds the file-backed MapDB store at the path returned by {@code Config.getW2vDBFile()} and
     * asks MapDB to close it on JVM shutdown.
     *
     * @throws IOException declared for subclasses and callers; propagated from configuration or
     *                     database set-up
     */
    public Word2VecSpaceFromFile() throws IOException {
        db = DBMaker.newFileDB(new File(Config.getW2vDBFile())).closeOnJvmShutdown()
                //      .encryptionEnable("password")
                .make();

    }

    /**
     * Create a W2V space in a DB.
     *
     * <p>
     * Reads the header of the given file and, unless {@code vectors} already holds exactly the
     * advertised number of entries, reads every word/vector record, normalises the vector with
     * {@code normVector} and stores it under the word. The database is committed periodically
     * during the load, then committed and compacted at the end.
     *
     * <p>
     * Any read failure (including a truncated file) aborts the load with an {@link IOException}
     * rather than silently storing zero-filled vectors.
     *
     * @param w2vZipFile path of the gzip-compressed word2vec binary file
     * @throws FileNotFoundException if {@code w2vZipFile} cannot be opened
     * @throws IOException if the file is truncated, malformed or otherwise unreadable
     */
    protected final void createW2VinDB(String w2vZipFile) throws FileNotFoundException, IOException {
        try (DataInputStream data_in = new DataInputStream(
                new GZIPInputStream(new FileInputStream(new File(w2vZipFile))))) {
            getWordsAndSize(data_in);
            if (vectors.size() == words) {
                System.out.println("Word2Vec is in DB");
                return;
            }
            System.out.println("DB Size:" + vectors.size());

            System.out.println("Want to read Word Count: " + words);
            System.out.println("Size:" + getSize());
            for (int w = 0; w < words; w++) {
                float[] v = new float[getSize()];
                String key = getVocabString(data_in);
                System.out.println(w + ":\t" + key);

                // Plain loop (not a lambda) so that an IOException from the stream propagates
                // to the caller instead of being swallowed inside the pipeline.
                for (int i = 0; i < v.length; i++) {
                    v[i] = getFloat(data_in);
                }
                vectors.put(key, normVector(v));
                // Commit in batches so a very large load does not sit in one transaction.
                if (w % 100000 == 1) {
                    db.commit();
                }
            }
            db.commit();
            db.compact();
        }
    }

    /**
     * Reads one float component from the stream.
     *
     * @param s stream positioned at the first byte of a float
     * @return the decoded value
     * @throws IOException if the stream ends early or cannot be read; previously such failures
     *                     were logged and replaced by {@code 0.0f}, which corrupted the stored
     *                     vectors without stopping the load
     */
    private float getFloat(DataInputStream s) throws IOException {
        return EndianUtils.readSwappedFloat(s);
    }

    /**
     * Reads the next vocabulary entry: all bytes up to (and consuming) the next whitespace byte.
     *
     * <p>
     * Newline bytes in front of the word are skipped, because files written by the reference
     * word2vec tool terminate every vector with {@code '\n'}; without the skip the key would be
     * empty and all following reads would be shifted by one byte.
     *
     * <p>
     * NOTE(review): each byte is widened to a {@code char} individually, so multi-byte (e.g.
     * UTF-8) words are not decoded. Left unchanged to stay compatible with keys already stored.
     *
     * @param s stream positioned at (or at newlines just before) the start of a word
     * @return the word, without its trailing delimiter
     * @throws IOException if the stream ends before the delimiter is found
     */
    private String getVocabString(DataInputStream s) throws IOException {
        sb.setLength(0);
        // Keep the value as int so that the EOF marker (-1) is not lost in a char cast.
        int b = s.read();
        while (b == '\n') {
            b = s.read();
        }
        while (b != -1 && !Character.isWhitespace(b)) {
            sb.append((char) b);
            b = s.read();
        }
        if (b == -1) {
            throw new IOException(
                    "Unexpected end of stream while reading vocabulary entry \"" + sb + "\"");
        }
        return sb.toString();
    }

    /**
     * Reads the header line and stores the advertised word count in {@code words} and the vector
     * dimensionality via {@code setSize}.
     *
     * @param s stream positioned at the very start of the word2vec data
     * @throws IOException if the stream ends before the end of the header line, or if the header
     *                     does not contain two whitespace-separated integers
     */
    private void getWordsAndSize(DataInputStream s) throws IOException {
        sb.setLength(0);
        int b;
        while ((b = s.read()) != '\n') {
            // Without this check an EOF (-1) cast to char never equals '\n' and loops forever.
            if (b == -1) {
                throw new IOException("Unexpected end of stream while reading word2vec header");
            }
            sb.append((char) b);
        }
        String header = sb.toString().trim();
        String[] parts = header.split("\\s+");
        if (parts.length < 2) {
            throw new IOException("Malformed word2vec header: \"" + header + "\"");
        }
        long wordCount;
        int vectorSize;
        try {
            // Parse both fields before assigning so a bad header leaves no partial state.
            wordCount = Long.parseLong(parts[0]);
            vectorSize = (int) Long.parseLong(parts[1]);
        } catch (NumberFormatException ex) {
            throw new IOException("Malformed word2vec header: \"" + header + "\"", ex);
        }
        words = wordCount;
        setSize(vectorSize);
    }

}