org.deeplearning4j.ui.nearestneighbors.NearestNeighborsResource.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.ui.nearestneighbors.NearestNeighborsResource.java

Source

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.ui.nearestneighbors;

import io.dropwizard.views.View;

import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.util.*;

import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.ui.api.UrlResource;
import org.deeplearning4j.ui.uploads.FileResource;
import org.deeplearning4j.util.SerializationUtils;

/**
 * Nearest neighbors
 *
 * @author Adam Gibson
 */
@Path("/nearestneighbors")
public class NearestNeighborsResource extends FileResource {
    private VPTree tree;
    private List<VocabWord> words;
    private Map<Integer, VocabWord> theVocab;
    private VocabCache vocab;
    private WordVectors wordVectors;
    private File localFile;

    /**
     * The file path for uploads
     *y
     * @param filePath the file path for uploads
     */
    public NearestNeighborsResource(String filePath) {
        super(filePath);
    }

    @GET
    public View get() {
        return new NearestNeighborsView();
    }

    @POST
    @Path("/update")
    @Produces(MediaType.APPLICATION_JSON)
    public Response updateFilePath(UrlResource resource) {
        if (!resource.getUrl().startsWith("http")) {
            this.localFile = new File(".", resource.getUrl());
            handleUpload(localFile);
        } else {
            File dl = new File(filePath, UUID.randomUUID().toString());
            try {
                FileUtils.copyURLToFile(new URL(resource.getUrl()), dl);
            } catch (Exception e) {
                e.printStackTrace();
            }

            handleUpload(dl);

        }

        return Response.ok(Collections.singletonMap("message", "Uploaded file")).build();
    }

    @POST
    @Path("/vocab")
    @Produces(MediaType.APPLICATION_JSON)
    public Response getVocab() {
        List<String> words = new ArrayList<>();

        if (wordVectors != null) {
            words.addAll(wordVectors.vocab().words());
        } else {
            for (VocabWord word : this.words) {
                words.add(word.getWord());
            }
        }

        return Response.ok((new ArrayList<>(words))).build();
    }

    @POST
    @Produces(MediaType.APPLICATION_JSON)
    @Path("/words")
    public Response getWords(NearestNeighborsQuery query) {
        Map<String, Double> map = new HashMap<>();

        if (wordVectors != null) {
            Collection<String> words = wordVectors.wordsNearest(query.getWord(), query.getNumWords());
            for (String word : words) {
                map.put(word, wordVectors.similarity(query.getWord(), word));
            }
        } else {
            List<DataPoint> results = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            tree.search(tree.getItems().get(vocab.indexOf(query.getWord())), query.getNumWords(), results,
                    distances);
            for (int i = 0; i < results.size(); i++) {
                map.put(theVocab.get(results.get(i).getIndex()).getWord(), distances.get(i));
            }
        }

        return Response.ok(map).build();
    }

    @Override
    public void handleUpload(File path) {
        try {
            if (path.getAbsolutePath().endsWith(".ser")) {
                WordVectors vectors = SerializationUtils.readObject(path);
                InMemoryLookupTable table = (InMemoryLookupTable) vectors.lookupTable();
                tree = new VPTree(table.getSyn0(), "dot", true);
                words = new ArrayList<>(vectors.vocab().vocabWords());
                theVocab = new HashMap<>();

                for (VocabWord word : words) {
                    theVocab.put(word.getIndex(), word);
                }
                this.vocab = vectors.vocab();

            } else if (path.getAbsolutePath().contains("Google")) {
                WordVectors vectors = WordVectorSerializer.loadGoogleModel(path, true);
                this.wordVectors = vectors;
            }

            else {
                Pair<InMemoryLookupTable, VocabCache> vocab = WordVectorSerializer.loadTxt(path);
                this.wordVectors = WordVectorSerializer.fromPair(vocab);

            }

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}