org.deeplearning4j.examples.multigpu.w2vsentiment.DataSetsBuilder.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.examples.multigpu.w2vsentiment.DataSetsBuilder.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.multigpu.w2vsentiment;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;

import java.io.*;
import java.net.URL;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Generates binary datasets out of raw text for further use on gpu for faster tuning sessions.
 * <p>
 * The idea behind this approach is simple: if you are going to run over the same corpus over and over, it is
 * more efficient to do all preprocessing once and save the datasets as binary files, instead of repeating the
 * same work on the fly.
 * <p>
 * {@link #run(String[])} downloads and unpacks the archive at {@link #DATA_URL} into {@link #DATA_PATH} (if it is
 * not there yet), loads the word vectors from {@link #WORD_VECTORS_PATH}, and writes one {@code dataset-N.bin}
 * file per batch into {@link #TEST_PATH} and {@link #TRAIN_PATH}.
 */
public class DataSetsBuilder {
    private static final Logger log = org.slf4j.LoggerFactory.getLogger(DataSetsBuilder.class);

    /** Data URL for downloading */
    public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
    /** Location to save and extract the training/testing data */
    public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"),
            "dl4j_w2vSentiment/");
    /** Location (local file system) for the Google News vectors. Set this manually. */
    //public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
    public static final String WORD_VECTORS_PATH = "/home/raver119/develop/GoogleNews-vectors-negative300.bin.gz";

    /** Directory that receives the serialized training datasets */
    public static final String TRAIN_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"),
            "dl4j_w2vSentiment_train/");
    /** Directory that receives the serialized test datasets */
    public static final String TEST_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"),
            "dl4j_w2vSentiment_test/");

    /** Size in bytes of the copy buffer used while unpacking the archive */
    private static final int BUFFER_SIZE = 4096;

    @Parameter(names = { "-b", "--batch" }, description = "BatchSize")
    private int batchSize = 64;

    @Parameter(names = { "-l", "--length" }, description = "Truncate max review length to")
    private int truncateReviewsToLength = 256;

    /**
     * Parses the command line, makes sure the raw corpus is available locally and serializes the test and
     * train iterators to {@link #TEST_PATH} and {@link #TRAIN_PATH}.
     * <p>
     * On an invalid command line the usage text is printed and the JVM exits with status 1.
     *
     * @param args command line arguments ({@code -b/--batch}, {@code -l/--length})
     * @throws FileNotFoundException if no file exists at {@link #WORD_VECTORS_PATH}
     * @throws Exception if downloading, extracting or loading fails
     */
    public void run(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provides invalid input -> print the usage info
            jcmdr.usage();
            try {
                // Short pause before exiting, kept from the original implementation.
                Thread.sleep(500);
            } catch (InterruptedException interrupted) {
                // We are about to exit anyway; just keep the interrupt flag intact.
                Thread.currentThread().interrupt();
            }
            System.exit(1);
        }

        if (WORD_VECTORS_PATH.startsWith("/PATH/TO/YOUR/VECTORS/")) {
            throw new RuntimeException("Please set the WORD_VECTORS_PATH before running this example");
        }
        // Fail before the corpus download rather than after it when the vectors are missing.
        File wordVectorsFile = new File(WORD_VECTORS_PATH);
        if (!wordVectorsFile.isFile()) {
            throw new FileNotFoundException("Word vectors file not found at " + wordVectorsFile.getAbsolutePath()
                    + "; please set the WORD_VECTORS_PATH before running this example");
        }

        downloadData();

        WordVectors wordVectors = WordVectorSerializer.loadStaticModel(wordVectorsFile);
        // NOTE(review): the trailing boolean presumably selects the train (true) / test (false) split - confirm
        // against SentimentExampleIterator.
        SentimentExampleIterator train = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize,
                truncateReviewsToLength, true);
        SentimentExampleIterator test = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize,
                truncateReviewsToLength, false);

        // This call used to be commented out, so the log line announced a save that never happened and
        // TEST_PATH was never populated.
        log.info("Saving test data...");
        saveDatasets(test, TEST_PATH);

        log.info("Saving train data...");
        saveDatasets(train, TRAIN_PATH);
    }

    /**
     * Downloads the archive from {@link #DATA_URL} into {@link #DATA_PATH} unless it is already present, and
     * extracts it. A freshly downloaded archive is always extracted; an existing archive is extracted only
     * when the {@code aclImdb} directory does not exist yet.
     *
     * @throws Exception if the data directory cannot be created, or the download or extraction fails
     */
    private static void downloadData() throws Exception {
        //Create directory if required
        File directory = new File(DATA_PATH);
        ensureDirectory(directory);

        File archiveFile = new File(directory, "aclImdb_v1.tar.gz");
        File extractedFile = new File(directory, "aclImdb");

        if (!archiveFile.exists()) {
            System.out.println("Starting data download (80MB)...");
            FileUtils.copyURLToFile(new URL(DATA_URL), archiveFile);
            System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
            //Extract tar.gz file to output directory
            extractTarGz(archiveFile.getPath(), DATA_PATH);
        } else {
            System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
            if (!extractedFile.exists()) {
                //Archive is present but was never unpacked
                extractTarGz(archiveFile.getPath(), DATA_PATH);
            } else {
                System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
            }
        }
    }

    /**
     * Unpacks a gzip-compressed tar archive into {@code outputPath}, printing a progress dot for every 1000
     * extracted files and a summary line at the end.
     *
     * @param filePath   path of the {@code .tar.gz} archive
     * @param outputPath directory the entries are extracted into
     * @throws IOException if reading or writing fails, a directory cannot be created, or an entry name
     *                     resolves to a location outside {@code outputPath}
     */
    private static void extractTarGz(String filePath, String outputPath) throws IOException {
        int fileCount = 0;
        int dirCount = 0;
        // Canonical form so that the containment check below compares like with like.
        File outputDir = new File(outputPath).getCanonicalFile();
        System.out.print("Extracting files");
        try (TarArchiveInputStream tais = new TarArchiveInputStream(
                new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(filePath))))) {
            TarArchiveEntry entry;
            byte[] data = new byte[BUFFER_SIZE];

            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
                File target = resolveInside(outputDir, entry.getName());

                if (entry.isDirectory()) {
                    ensureDirectory(target);
                    dirCount++;
                    continue;
                }

                // Archives are not required to list a directory before the files it contains.
                ensureDirectory(target.getParentFile());
                try (OutputStream dest = new BufferedOutputStream(new FileOutputStream(target), BUFFER_SIZE)) {
                    int count;
                    while ((count = tais.read(data, 0, BUFFER_SIZE)) != -1) {
                        dest.write(data, 0, count);
                    }
                }
                fileCount++;
                // Only count progress on files; directory entries must not emit extra dots.
                if (fileCount % 1000 == 0) {
                    System.out.print(".");
                }
            }
        }

        System.out
                .println("\n" + fileCount + " files and " + dirCount + " directories extracted to: " + outputPath);
    }

    /**
     * Resolves an archive entry name against the extraction directory and rejects names that would end up
     * outside of it (for example through {@code ../} segments).
     *
     * @param baseDir   extraction directory, already in canonical form
     * @param entryName entry name as stored in the archive
     * @return the canonical target file inside {@code baseDir}
     * @throws IOException if the entry resolves to a location outside {@code baseDir}
     */
    private static File resolveInside(File baseDir, String entryName) throws IOException {
        File target = new File(baseDir, entryName).getCanonicalFile();
        String basePath = baseDir.getPath();
        String prefix = basePath.endsWith(File.separator) ? basePath : basePath + File.separator;
        if (!target.equals(baseDir) && !target.getPath().startsWith(prefix)) {
            throw new IOException("Archive entry escapes the output directory: " + entryName);
        }
        return target;
    }

    /**
     * Creates {@code dir} (and any missing parents) unless it already exists as a directory.
     *
     * @param dir directory to create; {@code null} is ignored
     * @throws IOException if the directory does not exist and cannot be created
     */
    private static void ensureDirectory(File dir) throws IOException {
        if (dir != null && !dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Could not create directory " + dir.getAbsolutePath());
        }
    }

    /**
     * Drains the iterator and saves every dataset it yields as {@code dataset-N.bin} (N counting up from 0)
     * inside {@code dir}, logging progress every 500 datasets.
     *
     * @param iterator source of the datasets; consumed until {@code hasNext()} returns false
     * @param dir      output directory, created if missing
     * @throws IllegalStateException if the output directory cannot be created
     */
    protected void saveDatasets(DataSetIterator iterator, String dir) {
        File targetDir = new File(dir);
        if (!targetDir.mkdirs() && !targetDir.isDirectory()) {
            throw new IllegalStateException("Could not create output directory " + targetDir.getAbsolutePath());
        }

        int saved = 0;
        while (iterator.hasNext()) {
            String path = FilenameUtils.concat(dir, "dataset-" + saved + ".bin");
            iterator.next().save(new File(path));
            saved++;

            if (saved % 500 == 0) {
                log.info("{} datasets saved so far...", saved);
            }
        }
    }

    /**
     * Command line entry point; delegates to {@link #run(String[])}.
     *
     * @param args command line arguments
     * @throws Exception propagated from {@link #run(String[])}
     */
    public static void main(String[] args) throws Exception {
        new DataSetsBuilder().run(args);
    }
}