org.deeplearning4j.patent.DownloadPreprocessPatents.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.patent.DownloadPreprocessPatents.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.patent;

import com.beust.jcommander.Parameter;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.VoidFunction;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.patent.preprocessing.JSoupXmlParser;
import org.deeplearning4j.patent.preprocessing.LegacyFormatPatentParser;
import org.deeplearning4j.patent.preprocessing.Patent;
import org.deeplearning4j.patent.preprocessing.PatentLabelGenerator;
import org.deeplearning4j.patent.utils.JCommanderUtils;
import org.deeplearning4j.patent.utils.WordVectorProvider;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.nodes.Element;
import org.jsoup.select.Elements;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.net.URI;
import java.net.URL;
import java.nio.file.Files;
import java.util.*;

/**
 * This Spark script downloads and preprocesses the patent files. It is part 1 of 2 of this example.
 *
 * See PatentExampleReadme.md for more details.
 *
 * NOTE: It assumes you have an Azure storage account - it should be possible to adapt this for other cloud providers.
 * Run using the provided shell scripts.
 *
 * @author Alex Black
 */
public class DownloadPreprocessPatents {
    private static final Logger log = LoggerFactory.getLogger(DownloadPreprocessPatents.class);

    //---- Command line options: the fields below are populated by JCommander from the program arguments ----

    //Name given to the Spark application
    @Parameter(names = {
            "--sparkAppName" }, description = "App name for spark. Optional - can set it to anything to identify your job")
    private String sparkAppName = "DL4JPatentExamplePreprocessing";

    //Storage account used to build the wasbs:// URIs for both containers. Validated in entryPoint
    @Parameter(names = {
            "--azureStorageAcct" }, description = "Name of the Azure storage account to use for storage", required = true)
    private String azureStorageAcct;

    //Container that receives (and later provides) the downloaded zip files. Validated in entryPoint
    @Parameter(names = {
            "--azureContainerZips" }, description = "Name of the container in the specified storage account for the zip files", required = true)
    private String azureContainerZips;

    //Container that receives the preprocessed output.
    //The default has to be lowercase: entryPoint rejects any container name that is not made of [a-z0-9-] only,
    //so a mixed-case default could never pass validation
    @Parameter(names = {
            "--azureContainerPreproc" }, description = "Name of the container in the specified storage account for the serialized training DataSet files")
    private String azureContainerPreproc = "patentpreprocdata";

    //If false, the download stage is skipped entirely
    @Parameter(names = { "--downloadZips" }, description = "Whether the zips should be downloaded", arity = 1)
    private boolean downloadZips = true;

    //First year (inclusive) of the download loop
    @Parameter(names = { "--dlFirstYear" }, description = "First year to downloaded")
    private int dlFirstYear = 1976;

    //Last year (inclusive) of the download loop
    @Parameter(names = { "--dlLastYear" }, description = "Last year to download")
    private int dlLastYear = 2018;

    //If false, the preprocessing stage is skipped entirely
    @Parameter(names = {
            "--generateTrainingData" }, description = "Whether the training data should be generated", arity = 1)
    private boolean generateTrainingData = true;

    //Years before this one go to the train output directory, this year and later go to the test output directory
    @Parameter(names = {
            "--firstTestYear" }, description = "First test year. For example, if set to 2017, then all of 2017 and 2018 patent data will be the test set", arity = 1)
    private int firstTestYear = 2018;

    //Number of examples written to each output file; also part of the output directory name
    @Parameter(names = { "--minibatch" }, description = "Minibatch size for generated DataSets")
    private int minibatch = 32;

    //Upper limit on the number of word indices kept per example; also part of the output directory name
    @Parameter(names = {
            "--maxSequenceLength" }, description = "Maximum number of words in the sequences for generated DataSets")
    private int maxSequenceLength = 1000;

    //Location of the word vectors, passed on to the preprocessing function
    @Parameter(names = { "--wordVectorsPath" })
    private String wordVectorsPath = "wasbs://resources@deeplearning4jblob.blob.core.windows.net/wordvectors/GoogleNews-vectors-negative300.bin.gz";

    /**
     * Program entry point: builds the job object and passes the raw command line on to
     * {@link #entryPoint(String[])}.
     *
     * @param args Command line arguments
     * @throws Exception if anything in the job fails
     */
    public static void main(String[] args) throws Exception {
        DownloadPreprocessPatents job = new DownloadPreprocessPatents();
        job.entryPoint(args);
    }

    /**
     * JCommander entry point: parses the command line into the option fields, validates the Azure names,
     * then runs the two optional stages - downloading the patent zips into the zips container, and converting
     * the stored zips into word index files (train and test) in the preprocessing container.
     *
     * @param args Command line arguments
     * @throws IllegalStateException if the storage account name or one of the container names is invalid
     * @throws Exception             if the download or the preprocessing fails
     */
    protected void entryPoint(String[] args) throws Exception {
        JCommanderUtils.parseArgs(this, args);

        //Azure storage account naming rules: https://blogs.msdn.microsoft.com/jmstall/2014/06/12/azure-storage-naming-rules/
        //The default exceptions aren't helpful, we'll validate this here
        if (!azureStorageAcct.matches("^[a-z0-9]+$") || azureStorageAcct.length() < 3
                || azureStorageAcct.length() > 24) {
            throw new IllegalStateException("Invalid storage account name: must be alphanumeric, lowercase, "
                    + "3 to 24 characters. Got option azureStorageAcct=\"" + azureStorageAcct + "\"");
        }
        if (!azureContainerZips.matches("^[a-z0-9-]+$") || azureContainerZips.length() < 3
                || azureContainerZips.length() > 63) {
            throw new IllegalStateException(
                    "Invalid Azure container name: must be alphanumeric or dash, lowercase, "
                            + "3 to 63 characters. Got option azureContainerZips=\"" + azureContainerZips + "\"");
        }
        if (!azureContainerPreproc.matches("^[a-z0-9-]+$") || azureContainerPreproc.length() < 3
                || azureContainerPreproc.length() > 63) {
            throw new IllegalStateException(
                    "Invalid Azure container name: must be alphanumeric or dash, lowercase, "
                            + "3 to 63 characters. Got option azureContainerPreproc=\"" + azureContainerPreproc
                            + "\"");
        }

        SparkConf sparkConf = new SparkConf();
        sparkConf.setAppName(sparkAppName);
        //The configuration has to be handed to the context, otherwise the --sparkAppName option is never applied
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        try {
            String baseDirZips = "wasbs://" + azureContainerZips + "@" + azureStorageAcct
                    + ".blob.core.windows.net/";

            if (downloadZips) {
                log.info("*** Starting download of patent data ***");
                log.info("--- WARNING: This requires approximately 92GB of storage for the patents in zip format! ---"); //Approx 464GB uncompressed - but we process directly from zipped format
                long start = System.currentTimeMillis();
                List<String> downloadUrls = new ArrayList<>();
                String format = "https://bulkdata.uspto.gov/data/patent/grant/redbook/fulltext/";
                for (int i = dlFirstYear; i <= dlLastYear; i++) {
                    String url = format + i + "/";
                    downloadUrls.addAll(getZipUrlsFromPage(url));
                }

                //Make sure the container exists; create it if not
                Path p = new Path(URI.create(baseDirZips));
                FileSystem fs = FileSystem.get(URI.create(baseDirZips), new Configuration());
                if (!fs.exists(p)) {
                    fs.create(p);
                }

                JavaRDD<String> rdd = sc.parallelize(downloadUrls);
                //foreach is an action: the downloads have completed (or failed) when it returns
                rdd.foreach(new DownloadToAzureFn(baseDirZips, true));
                long end = System.currentTimeMillis();
                log.info("*** Completed download of patent data in {} sec ***", (end - start) / 1000);
            }

            if (generateTrainingData) {
                log.info("*** Starting preprocessing of patent data ***");
                log.info("--- WARNING: This requires approximately 20 GB of storage for the processed data! ---");
                long start = System.currentTimeMillis();
                String dirName = "seqLength" + maxSequenceLength + "_mb" + minibatch;
                String containerRoot = "wasbs://" + azureContainerPreproc + "@" + azureStorageAcct
                        + ".blob.core.windows.net/";
                String baseOutPath = containerRoot + dirName;
                String baseOutputPathTrain = baseOutPath + "/train/";
                String baseOutputPathTest = baseOutPath + "/test/";

                //Seems that we need the container to exist before we can create anything in a subdirectory...
                Configuration config = new Configuration();
                FileSystem fs;
                try {
                    fs = FileSystem.get(URI.create(containerRoot), config);
                } catch (Throwable t) {
                    throw new RuntimeException("Error getting filesystem for container root: " + containerRoot, t);
                }
                if (!fs.exists(new Path(containerRoot))) {
                    fs.create(new Path(containerRoot));
                }

                //Use the configured year range rather than fixed years, so that a partial download
                // (--dlFirstYear / --dlLastYear) can be preprocessed without listing directories that don't exist
                long[] resultTrain = generateIndexFiles(sc,
                        listPaths(dlFirstYear, firstTestYear - 1, sc, baseDirZips), baseOutputPathTrain);
                log.info("Finished generating training datasets");

                long[] resultTest = generateIndexFiles(sc,
                        listPaths(firstTestYear, dlLastYear, sc, baseDirZips), baseOutputPathTest);
                log.info("Finished generating testing datasets");

                log.info("Train - Total datasets: {}", resultTrain[0]);
                log.info("Train - Total examples: {}", resultTrain[1]);
                log.info("Train - Total bytes (DataSets on disk): {}", resultTrain[2]);
                log.info("Train - Total skipped due to length: {}", resultTrain[3]);
                log.info("Test - Total datasets: {}", resultTest[0]);
                log.info("Test - Total examples: {}", resultTest[1]);
                log.info("Test - Total bytes (DataSets on disk): {}", resultTest[2]);
                log.info("Test - Total skipped due to length: {}", resultTest[3]);

                long end = System.currentTimeMillis();
                log.info("*** Completed preprocessing of patent data in {} sec ***", (end - start) / 1000);
            }
        } finally {
            //Release the Spark resources whether or not the job succeeded
            sc.stop();
        }

        log.info("-- Completed All Preprocessing Steps --");
    }

    /**
     * Run the preprocessing function over the given zip files and sum the per-file statistics.
     *
     * @param sc         Spark context used to distribute the paths
     * @param zipPaths   Paths of the zip files to preprocess
     * @param outputPath Directory the generated word index files are written to
     * @return 4 values: number of files written, number of examples, bytes written, examples skipped due to length
     */
    private long[] generateIndexFiles(JavaSparkContext sc, List<String> zipPaths, String outputPath) {
        if (zipPaths.isEmpty()) {
            //reduce fails on an empty RDD, so report "nothing done" instead of failing the whole job
            log.warn("No zip files found to preprocess for output path {}", outputPath);
            return new long[4];
        }
        JavaRDD<long[]> stats = sc.parallelize(zipPaths).map(new PatentsToIndexFilesFunction(maxSequenceLength,
                wordVectorsPath, minibatch, PatentLabelGenerator.classLabelToIndex(), outputPath, 10));
        //map is lazy: the files are actually generated when reduce runs
        return stats.reduce(new ReduceArrayFunction());
    }

    /**
     * Get a list of all URLs in a page for zip files
     *
     * @param url Address of the HTML page to scan
     * @return One absolute URL for every link on the page whose href ends with ".zip"; empty if there are none
     * @throws RuntimeException if the page cannot be fetched (the IOException is kept as the cause)
     */
    public static List<String> getZipUrlsFromPage(String url) {
        Document doc;
        try {
            doc = Jsoup.connect(url).get();
        } catch (IOException e) {
            throw new RuntimeException("Error getting zip URLs from page: " + url, e);
        }

        List<String> out = new ArrayList<>();
        for (Element link : doc.select("a[href]")) {
            String href = link.attr("href");
            if (!href.endsWith(".zip")) {
                continue;
            }
            //Let jsoup resolve the link against the page's base URI: this covers absolute links, plain relative
            // links and forms such as "/path/file.zip" or "../file.zip" that simple concatenation gets wrong
            String absolute = link.absUrl("href");
            if (absolute.isEmpty()) {
                //Could not be resolved - fall back to simple concatenation
                absolute = link.baseUri() + href;
            }
            out.add(absolute);
        }
        return out;
    }

    /**
     * List the files under the specified path.
     *
     * @param sc        Not used by this method
     * @param path      Directory to list; must not be the root of an Azure container
     * @param recursive Whether files in subdirectories should be included
     * @return The path of every file found, as a String
     * @throws IllegalStateException if the path is the root of an Azure blob storage container
     * @throws IOException           if the file system cannot be accessed
     */
    public static List<String> listPaths(JavaSparkContext sc, String path, boolean recursive) throws IOException {
        if (path.endsWith(".blob.core.windows.net/") || path.endsWith(".blob.core.windows.net")) {
            //Azure library bug: seems that we get an infinite loop if we try to list paths on the
            // root directory, for some versions of the Azure Hadoop library - deadlocks on fileIter.hasNext()
            throw new IllegalStateException("Cannot list paths from root directory due to Azure library bug");
        }

        List<String> paths = new ArrayList<>();
        Configuration config = new Configuration();
        FileSystem hdfs = FileSystem.get(URI.create(path), config);
        //Typed iterator: no raw type and no cast needed on each element
        RemoteIterator<LocatedFileStatus> fileIter = hdfs.listFiles(new Path(path), recursive);

        while (fileIter.hasNext()) {
            paths.add(fileIter.next().getPath().toString());
        }

        return paths;
    }

    /**
     * Gather the zip files found (recursively) under the per-year directories {@code baseDirZips + year},
     * for every year in the given range. The total number of paths is logged before returning.
     *
     * @param firstYearInclusive First year to list
     * @param lastYearInclusive  Last year to list
     * @param sc                 Spark context, passed through to the per-directory listing
     * @param baseDirZips        Prefix that the year is appended to in order to get each directory
     * @return Paths of all files ending with ".zip"
     * @throws IOException if a directory cannot be listed
     */
    public List<String> listPaths(int firstYearInclusive, int lastYearInclusive, JavaSparkContext sc,
            String baseDirZips) throws IOException {
        List<String> zipPaths = new ArrayList<>();
        int year = firstYearInclusive;
        while (year <= lastYearInclusive) {
            for (String candidate : listPaths(sc, baseDirZips + year, true)) {
                if (!candidate.endsWith(".zip")) {
                    //Only the zip archives are of interest
                    continue;
                }
                zipPaths.add(candidate);
            }
            year++;
        }
        log.info("Number of paths: {}", zipPaths.size());
        return zipPaths;
    }

    /**
     * Spark function that copies the content of one URL to {@code rootDir + year + "/" + filename}, where the
     * year is the 4 characters that follow "fulltext/" in the URL and the filename is the last part of the URL.
     */
    public static class DownloadToAzureFn implements VoidFunction<String> {
        //Static, so it is not part of the serialized state of the function
        private static final Configuration conf = new Configuration();
        //The 4 characters following this marker in the URL are used as the year
        private static final String YEAR_MARKER = "fulltext/";
        private static final int YEAR_LENGTH = 4;
        //Existing files smaller than this (1MB) are considered corrupt and are downloaded again
        private static final long MIN_VALID_LENGTH = 1024 * 1024;

        private final String rootDir;
        private final boolean skipExisting;

        /**
         * @param rootDir      Prefix of the output location; the year directory and filename are appended to it
         * @param skipExisting If true, an existing file of at least 1MB is kept and not downloaded again.
         *                     If false, an existing file is always deleted and downloaded again
         */
        public DownloadToAzureFn(String rootDir, boolean skipExisting) {
            this.rootDir = rootDir;
            this.skipExisting = skipExisting;
        }

        /**
         * Download a single file.
         *
         * @param url URL to download; must contain "fulltext/" followed by a 4 character year
         * @throws IllegalArgumentException if the year cannot be located in the URL
         * @throws RuntimeException         if the download fails (the original error is kept as the cause)
         */
        @Override
        public void call(String url) throws Exception {
            int idx = url.indexOf(YEAR_MARKER);
            int yearStart = idx + YEAR_MARKER.length();
            if (idx < 0 || url.length() < yearStart + YEAR_LENGTH) {
                //Without this check a missing marker would silently yield a garbage "year" directory
                throw new IllegalArgumentException(
                        "Unable to determine year from URL (expected \"fulltext/<year>/\"): " + url);
            }
            String year = url.substring(yearStart, yearStart + YEAR_LENGTH);
            String filename = FilenameUtils.getName(url);

            URI outUri = new URI(rootDir + year + "/" + filename);

            FileSystem fs = FileSystem.get(outUri, conf);
            URL source = new URL(url);

            Path p = new Path(outUri);
            if (fs.exists(p)) {
                if (skipExisting) {
                    long length = fs.getFileStatus(p).getLen();
                    if (length >= MIN_VALID_LENGTH) {
                        log.info("Skipping existing file: {} - {}", url, outUri);
                        return;
                    }
                    //Assume it must be corrupt somehow if it's < 1MB
                    log.info("Re-downloading file of length {}: {} - {}", length, url, outUri);
                }
                fs.delete(p, false);
            }

            try (FSDataOutputStream out = fs.create(p);
                    InputStream is = new BufferedInputStream(source.openStream())) {
                IOUtils.copy(is, out);
            } catch (Throwable t) {
                //Best effort removal of a partial file: one larger than 1MB would otherwise be
                // skipped as "existing" the next time the download is run
                try {
                    fs.delete(p, false);
                } catch (Throwable t2) {
                    t.addSuppressed(t2);
                }
                throw new RuntimeException("Error downloading: " + url, t);
            }
            log.info("Downloaded: {} to {}", url, outUri);
        }
    }

    /**
     * Reduction function that adds the second array onto the first, position by position.
     * The first array is modified in place and is also the returned value.
     */
    private static class ReduceArrayFunction implements Function2<long[], long[], long[]> {
        @Override
        public long[] call(long[] p1, long[] p2) {
            int pos = 0;
            while (pos < p1.length) {
                p1[pos] = p1[pos] + p2[pos];
                pos++;
            }
            return p1;
        }
    }

    /**
     * Spark function that converts one zip file of patents into "word index" files.
     * <p>
     * The zip is copied to a local temporary file and parsed; the patents are visited in random order. Each
     * patent with a known label and enough tokens becomes one example: the word vector indices of its tokens
     * (at most maxSequenceLength of them) plus its label index. Examples are written out in groups of
     * minibatchSize to files named "wordIndices_&lt;UUID&gt;.bin" under the base output path.
     * <p>
     * File content: for each example, every word index as an int, then the negated label index as an int.
     * NOTE(review): a label index of 0 is written as 0, which the sign alone cannot tell apart from a word
     * index of 0 - confirm how the code that reads these files handles this.
     * <p>
     * The returned array has 4 values: number of files written, number of examples written, number of bytes
     * written, and number of patents skipped for having too few tokens.
     */
    public static class PatentsToIndexFilesFunction implements Function<String, long[]> {
        //Static, so it is not part of the serialized state of the function
        private static final Configuration config = new Configuration();

        private final int maxSequenceLength;
        private final String wordVectorsPath;
        private final int minibatchSize;
        private final Map<String, Integer> teir2WordVectorsLabelIdx;
        private final String baseOutputPath;
        private final int minTokens;

        /**
         * @param maxSequenceLength        Maximum number of word indices kept for one example
         * @param wordVectorsPath          Location the word vectors are loaded from
         * @param minibatchSize            Number of examples written to each output file
         * @param teir2WordVectorsLabelIdx Map from label to label index; patents with other labels are ignored
         * @param baseOutputPath           Directory the output files are written to
         * @param minTokens                Patents with fewer tokens (or fewer known words) than this are skipped
         */
        public PatentsToIndexFilesFunction(int maxSequenceLength, String wordVectorsPath, int minibatchSize,
                Map<String, Integer> teir2WordVectorsLabelIdx, String baseOutputPath, int minTokens) {
            this.maxSequenceLength = maxSequenceLength;
            this.wordVectorsPath = wordVectorsPath;
            this.minibatchSize = minibatchSize;
            this.teir2WordVectorsLabelIdx = teir2WordVectorsLabelIdx;
            this.baseOutputPath = baseOutputPath;
            this.minTokens = minTokens;
        }

        /**
         * Process one zip file.
         *
         * @param s Path of the zip file; the name of its parent directory must be the year
         * @return Statistics: files written, examples written, bytes written, patents skipped due to length.
         *         All zeros if the zip file could not be copied to local storage
         * @throws Exception if parsing the patents or writing the output fails
         */
        @Override
        public long[] call(String s) throws Exception {

            URI u = new URI(s);
            FileSystem fs = FileSystem.get(u, config);
            String name = FilenameUtils.getBaseName(s);
            File temp = Files.createTempFile(name, ".zip").toFile();
            temp.deleteOnExit();

            long dataSetCount = 0;
            long exampleCount = 0;
            long totalSize = 0;
            long countSkippedOnSize = 0;
            try {
                //Copy the zip to local storage. A zip that can't be copied is skipped, not treated as fatal
                try (InputStream in = new BufferedInputStream(fs.open(new Path(u)));
                        OutputStream os = new BufferedOutputStream(new FileOutputStream(temp))) {
                    IOUtils.copy(in, os);
                } catch (Throwable t) {
                    log.warn("Patent failed, skipping: {}", s, t);
                    return new long[] { 0, 0, 0, 0 };
                }

                //The directory containing the zip is named after the year, which selects the parser
                String[] split = s.split("/");
                int year = Integer.parseInt(split[split.length - 2]);
                List<Patent> patents;
                if (year <= 2001) {
                    patents = new LegacyFormatPatentParser().parsePatentZip(temp);
                } else {
                    patents = new JSoupXmlParser().parsePatentZip(temp);
                }

                log.info("Finished loading {} patents for path {}", patents.size(), s);

                //Visit the patents in random order
                int[] order = new int[patents.size()];
                for (int i = 0; i < order.length; i++) {
                    order[i] = i;
                }
                MathUtils.shuffleArray(order, new Random());

                List<int[]> toMerge = new ArrayList<>();
                IntArrayList labelToMerge = new IntArrayList();
                IntArrayList scratch = new IntArrayList();
                TokenizerFactory tf = new DefaultTokenizerFactory();
                WordVectors wv = WordVectorProvider.getWordVectors(config, wordVectorsPath);
                for (int pos = 0; pos < order.length; pos++) {
                    Patent p = patents.get(order[pos]);

                    Integer labelIdx = labelIndexFor(p);
                    if (labelIdx != null) {
                        List<String> tokens = tf.create(p.getAllText()).getTokens();
                        int[] wordIndices = null;
                        if (tokens.size() >= minTokens) {
                            wordIndices = toWordIndices(tokens, wv, scratch);
                        }

                        if (wordIndices == null || wordIndices.length < minTokens) {
                            countSkippedOnSize++;
                        } else {
                            toMerge.add(wordIndices);
                            labelToMerge.add(labelIdx.intValue());
                        }
                    }

                    //No early 'continue' above: this check must run for every patent - including skipped ones -
                    // otherwise the pending examples are lost whenever the final patent happens to be skipped
                    boolean lastPatent = (pos == order.length - 1);
                    if (toMerge.size() >= minibatchSize || (lastPatent && !toMerge.isEmpty())) {
                        totalSize += writeMinibatch(toMerge, labelToMerge);
                        exampleCount += toMerge.size();
                        dataSetCount++;

                        toMerge.clear();
                        labelToMerge.clear();
                    }
                }
            } catch (Throwable t) {
                log.error("Error parsing: {}", s);
                throw t;
            } finally {
                temp.delete();
            }

            return new long[] { dataSetCount, exampleCount, totalSize, countSkippedOnSize };
        }

        /**
         * Look up the label index for a patent.
         *
         * @return The label index, or null if the label can't be generated or is not in the label map
         */
        private Integer labelIndexFor(Patent p) {
            String tier2;
            try {
                tier2 = PatentLabelGenerator.tier2Label(p.getClassificationUS());
            } catch (Throwable t) {
                log.warn("Skipping bad patent label: {}", p.getClassificationUS());
                return null;
            }
            return tier2 == null ? null : teir2WordVectorsLabelIdx.get(tier2);
        }

        /**
         * Convert tokens to word vector indices, ignoring tokens the word vectors don't contain and stopping
         * once maxSequenceLength indices have been collected.
         *
         * @param scratch Reusable buffer; cleared before use
         */
        private int[] toWordIndices(List<String> tokens, WordVectors wv, IntArrayList scratch) {
            scratch.clear();
            for (String token : tokens) {
                if (wv.hasWord(token)) {
                    scratch.add(wv.indexOf(token));
                }
                if (scratch.size() >= maxSequenceLength) {
                    break;
                }
            }
            return scratch.toIntArray();
        }

        /**
         * Write the given examples to a new, uniquely named file under the base output path.
         *
         * @return Number of bytes written
         */
        private long writeMinibatch(List<int[]> toMerge, IntArrayList labelToMerge) throws Exception {
            String filename = "wordIndices_" + UUID.randomUUID().toString() + ".bin";
            URI uri = new URI(this.baseOutputPath
                    + (!this.baseOutputPath.endsWith("/") && !this.baseOutputPath.endsWith("\\") ? "/" : "")
                    + filename);
            FileSystem outFs = FileSystem.get(uri, config);

            long writtenBytes = 0;
            try (FSDataOutputStream out = outFs.create(new Path(uri))) {
                for (int i = 0; i < toMerge.size(); i++) {
                    for (int wordIdx : toMerge.get(i)) {
                        out.writeInt(wordIdx);
                        writtenBytes += 4;
                    }
                    //The label is written negated, directly after the word indices of its example
                    out.writeInt(-labelToMerge.getInt(i));
                    writtenBytes += 4;
                }
            } catch (Throwable t) {
                throw new RuntimeException("Error saving file: path \"" + uri + "\"", t);
            }
            return writtenBytes;
        }
    }
}