org.knime.ext.textprocessing.dl4j.util.WordVectorPortObjectUtils.java Source code

Java tutorial

Introduction

Here is the source code for org.knime.ext.textprocessing.dl4j.util.WordVectorPortObjectUtils.java

Source

/*******************************************************************************
 * Copyright by KNIME AG, Zurich, Switzerland
 * Website: http://www.knime.com; Email: contact@knime.com
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License, Version 3, as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses>.
 *
 * Additional permission under GNU GPL version 3 section 7:
 *
 * KNIME interoperates with ECLIPSE solely via ECLIPSE's plug-in APIs.
 * Hence, KNIME and ECLIPSE are both independent programs and are not
 * derived from each other. Should, however, the interpretation of the
 * GNU GPL Version 3 ("License") under any applicable laws result in
 * KNIME and ECLIPSE being a combined program, KNIME AG herewith grants
 * you the additional permission to use and propagate KNIME together with
 * ECLIPSE with only the license terms in place for ECLIPSE applying to
 * ECLIPSE and the GNU GPL Version 3 applying for KNIME, provided the
 * license terms of ECLIPSE themselves allow for the respective use and
 * propagation of ECLIPSE together with KNIME.
 *
 * Additional permission relating to nodes for KNIME that extend the Node
 * Extension (and in particular that are based on subclasses of NodeModel,
 * NodeDialog, and NodeView) and that only interoperate with KNIME through
 * standard APIs ("Nodes"):
 * Nodes are deemed to be separate and independent programs and to not be
 * covered works.  Notwithstanding anything to the contrary in the
 * License, the License does not apply to Nodes, you are not required to
 * license Nodes under the License, and you are granted a license to
 * prepare and propagate Nodes, in each case even if such Nodes are
 * propagated with or for interoperation with KNIME.  The owner of a Node
 * may freely choose the license terms applicable to such Node, including
 * when such Node is propagated with or for interoperation with KNIME.
 *******************************************************************************/
package org.knime.ext.textprocessing.dl4j.util;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.UUID;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.knime.core.util.FileUtil;
import org.knime.ext.textprocessing.dl4j.nodes.embeddings.WordVectorFileStorePortObject;
import org.knime.ext.textprocessing.dl4j.nodes.embeddings.WordVectorPortObject;
import org.knime.ext.textprocessing.dl4j.nodes.embeddings.WordVectorPortObjectSpec;
import org.knime.ext.textprocessing.dl4j.settings.enumerate.WordVectorTrainingMode;

/**
 * Utility class for {@link WordVectorPortObject} and {@link WordVectorPortObjectSpec} Serialization. Also contains
 * utility methods to load/save members of port and spec class.
 *
 * @author David Kolb, KNIME.com GmbH
 */
public class WordVectorPortObjectUtils {

    private WordVectorPortObjectUtils() {
        // Utility class
    }

    /**
     * Serializes a {@link WordVectorPortObject} to zip using the specified {@link ZipOutputStream}. It can be specified
     * if both PortObject and Spec should be written or only one of them.
     *
     * @param portObject the {@link WordVectorPortObject} to write
     * @param writePortObject whether to write PortObject
     * @param writeSpec whether to write Spec
     * @param outStream stream to write to
     * @throws IOException
     */
    @Deprecated
    public static void saveModelToZip(final WordVectorPortObject portObject, final boolean writePortObject,
            final boolean writeSpec, final ZipOutputStream outStream) throws IOException {
        final WordVectorPortObjectSpec spec = (WordVectorPortObjectSpec) portObject.getSpec();

        if (outStream == null) {
            throw new IOException("OutputStream is null");
        }
        if (writePortObject && !writeSpec) {
            savePortObjectOnly(portObject, outStream);
        }
        if (!writePortObject && writeSpec) {
            saveSpecOnly(spec, outStream);
        }
        if (writePortObject && writeSpec) {
            savePortObjectAndSpec(portObject, spec, outStream);
        }
    }

    /**
     * Reads {@link WordVectorPortObject} from specified {@link ZipInputStream}.
     *
     * @param inStream stream to read from
     * @param mode the type of WordVector model to expect
     * @return WordVectorPortObject
     * @throws IOException
     */
    @Deprecated
    public static WordVectorPortObject loadPortFromZip(final ZipInputStream inStream,
            final WordVectorTrainingMode mode) throws IOException {
        return new WordVectorPortObject(loadWordVectors(inStream, mode), null);
    }

    /**
     * Reads {@link WordVectors} from the specified {@link ZipInputStream}. The method expects the ZipInputStream to
     * contain an ZipEntry with name "word_vectors", which contains the WordVector model to load.
     *
     * @param in stream to read from
     * @param mode the type of WordVector model to expect
     * @return {@link WordVectors} loaded from stream
     * @throws IOException
     */
    public static WordVectors loadWordVectors(final ZipInputStream in, final WordVectorTrainingMode mode)
            throws IOException {
        ZipEntry entry;
        while ((entry = in.getNextEntry()) != null) {
            if (entry.getName().matches("word_vectors")) {
                switch (mode) {
                case DOC2VEC:
                    return WordVectorSerializer.readParagraphVectors(in);
                case WORD2VEC:
                    /* Need to copy stream to temp file because API does not support Word2VecModel reading
                     * with InputStreams. */
                    Word2Vec model = null;
                    File tmp = null;
                    try {
                        tmp = copyInputStreamToTmpFile(in);
                        model = WordVectorSerializer.readWord2VecModel(tmp);
                    } catch (Exception e) {
                        throw e;
                    } finally {
                        if (tmp != null && tmp.exists()) {
                            tmp.delete();
                        }
                    }

                    return model;
                default:
                    throw new IllegalStateException(
                            "No deserialization method defined for WordVectors of type: " + mode);
                }
            }
        }
        throw new IllegalArgumentException(
                "WordVectors entry not found. ZipInputStream seems not to contain ZipEntry "
                        + "with name 'word_vectors'!");
    }

    /**
     * Reads {@link WordVectors} from the specified {@link URL}.
     *
     * @param url the URL to read from
     * @param mode the type of WordVector model to expect
     * @return {@link WordVectors} loaded from URL
     * @throws IOException
     * @throws URISyntaxException
     */
    public static WordVectors loadWordVectors(final URL url, final WordVectorTrainingMode mode)
            throws IOException, URISyntaxException {
        switch (mode) {
        case DOC2VEC:
            return WordVectorSerializer.readParagraphVectors(url.openStream());
        case WORD2VEC:
            boolean isLocalFile = FileUtil.resolveToPath(url) != null;
            File wvFile = null;
            try {
                wvFile = isLocalFile ? FileUtil.getFileFromURL(url) : copyURLToTmpFile(url);
                return WordVectorSerializer.readWord2VecModel(wvFile);
            } catch (Exception e) {
                // error is handled outside
                throw e;
            } finally {
                // file has been temporarily downloaded
                if (!isLocalFile && wvFile != null) {
                    wvFile.delete();
                }
            }

        default:
            throw new IllegalStateException("No deserialization method defined for WordVectors of type: " + mode);
        }
    }

    /**
     * Writes port object without spec. Calling this will close the stream.
     *
     * @param portObject
     * @param out
     * @throws IOException
     */
    @Deprecated
    private static void savePortObjectOnly(final WordVectorPortObject portObject, final ZipOutputStream out)
            throws IOException {
        writeWordVectors(portObject.getWordVectors(), out);
    }

    /**
     * Writes port spec without object.
     *
     * @param spec
     * @param out
     * @throws IOException
     */
    private static void saveSpecOnly(final WordVectorPortObjectSpec spec, final ZipOutputStream out)
            throws IOException {
        final WordVectorTrainingMode mode = spec.getWordVectorTrainingsMode();

        writeWordVectorTrainingsMode(mode, out);
    }

    /**
     * Writes port object and spec.
     *
     * @param portObject
     * @param spec
     * @param out
     * @throws IOException
     */
    @Deprecated
    private static void savePortObjectAndSpec(final WordVectorPortObject portObject,
            final WordVectorPortObjectSpec spec, final ZipOutputStream out) throws IOException {
        saveSpecOnly(spec, out);
        //writing of port object needs to be done at last in this case because the DL4J WordVectorSerializer closes the stream
        savePortObjectOnly(portObject, out);
    }

    /**
     * Writes {@link WordVectorTrainingMode} to specified {@link ZipInputStream} .
     *
     * @param mode the {@link WordVectorTrainingMode} to write
     * @param out stream to write to
     * @throws IOException
     */
    public static void writeWordVectorTrainingsMode(final WordVectorTrainingMode mode, final ZipOutputStream out)
            throws IOException {
        final ZipEntry entry = new ZipEntry("word_vector_trainings_mode");
        out.putNextEntry(entry);
        out.write(mode.toString().getBytes(Charset.forName("UTF-8")));
    }

    /**
     * Serializes the specified port object to the specified stream including port object spec.
     *
     * @param port
     * @param out
     * @throws IOException
     */
    public static void writeFileStorePortObject(final WordVectorFileStorePortObject port, final ZipOutputStream out)
            throws IOException {
        WordVectorPortObjectSpec spec = (WordVectorPortObjectSpec) port.getSpec();
        saveSpecOnly(spec, out);
        writeWordVectors(port.getWordVectors(), out);
    }

    /**
     * Writes {@link WordVectors} to specified {@link ZipInputStream}. Calling this method will close the stream (closed
     * by WordVectorSerializer).
     *
     * @param wordVectors {@link WordVectors} to write
     * @param out stream to write to
     * @throws IOException
     */
    public static void writeWordVectors(final WordVectors wordVectors, final ZipOutputStream out)
            throws IOException {
        final ZipEntry entry = new ZipEntry("word_vectors");
        out.putNextEntry(entry);
        //first check if ParagraphVectors because it extends Word2Vec
        if (wordVectors instanceof ParagraphVectors) {
            ParagraphVectors d2v = (ParagraphVectors) wordVectors;
            WordVectorSerializer.writeParagraphVectors(d2v, out);
        } else if (wordVectors instanceof Word2Vec) {
            Word2Vec w2v = (Word2Vec) wordVectors;
            WordVectorSerializer.writeWord2VecModel(w2v, out);
        } else {
            throw new IllegalStateException("No serialization method defined for WordVectors of type: "
                    + wordVectors.getClass().getSimpleName());
        }
    }

    /**
     * Copies the content of the specified InputStream to a temp file.
     *
     * @param is stream to copy
     * @return file pointing to stream content
     * @throws IOException
     */
    public static File copyInputStreamToTmpFile(final InputStream is) throws IOException {
        File tmpFile = FileUtil.createTempFile(UUID.randomUUID().toString(), null);
        FileUtils.copyInputStreamToFile(is, tmpFile);
        return tmpFile;
    }

    /**
     * Copies the content of the specified URL to a temp file.
     *
     * @param url
     * @return file pointing to URL content
     * @throws IOException
     */
    public static File copyURLToTmpFile(final URL url) throws IOException {
        File tmpFile = FileUtil.createTempFile(UUID.randomUUID().toString(), null);
        FileUtils.copyURLToFile(url, tmpFile);
        return tmpFile;
    }

    /**
     * Converts wordVectors to {@link Word2Vec}. Sets {@link WeightLookupTable} and {@link VocabCache}. Depending on
     * specified word vector type this may lead to information loss. E.g. labels for {@link ParagraphVectors}.
     *
     * @param wordVectors
     * @return Word2Vec containing vocab and lookup table
     */
    public static Word2Vec wordVectorsToWord2Vec(final WordVectors wordVectors) {
        final Word2Vec w2v = new Word2Vec();
        w2v.setLookupTable(wordVectors.lookupTable());
        w2v.setVocab(wordVectors.vocab());
        return w2v;
    }

}