com.google.api.services.samples.prediction.cmdline.PredictionSample.java Source code

Java tutorial

Introduction

Here is the source code for com.google.api.services.samples.prediction.cmdline.PredictionSample.java

Source

/*
 * Copyright (c) 2013 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */

package com.google.api.services.samples.prediction.cmdline;

import com.google.api.client.auth.oauth2.Credential;
import com.google.api.client.extensions.java6.auth.oauth2.AuthorizationCodeInstalledApp;
import com.google.api.client.extensions.jetty.auth.oauth2.LocalServerReceiver;
import com.google.api.client.googleapis.auth.oauth2.GoogleAuthorizationCodeFlow;
import com.google.api.client.googleapis.auth.oauth2.GoogleClientSecrets;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.JsonParser;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.client.util.DateTime;
import com.google.api.client.util.store.DataStoreFactory;
import com.google.api.client.util.store.FileDataStoreFactory;
import com.google.api.services.prediction.Prediction;
import com.google.api.services.prediction.PredictionScopes;
import com.google.api.services.prediction.model.Input;
import com.google.api.services.prediction.model.Input.InputInput;

import com.google.api.services.prediction.model.Insert;
import com.google.api.services.prediction.model.Insert.TrainingInstances;
import com.google.api.services.prediction.model.Insert2;
import com.google.api.services.prediction.model.Insert2.ModelInfo;
import com.google.api.services.prediction.model.Output;
import com.google.api.client.http.HttpResponse;
import com.google.api.client.http.HttpResponseException;

import org.json.simple.JSONObject;
import org.json.simple.parser.ParseException;
import org.json.simple.parser.JSONParser;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Main class for the Prediction API command line sample.
 *
 * <p>Demonstrates how to make an authenticated API call using OAuth 2 helper classes. The sample
 * trains a model from the classpath resource {@code /language_id.txt} (the data is embedded in the
 * insert request), polls the model until training finishes, and then asks the trained model to
 * label one sample string.
 */
public class PredictionSample {

    /**
     * Be sure to specify the name of your application. If the application name is {@code null} or
     * blank, the application will log a warning. Suggested format is "MyCompany-ProductName/1.0".
     */
    private static final String APPLICATION_NAME = "UU-lihao";

    /**
     * Classpath resource holding the training data, one {@code <label>, <text>} instance per line.
     * Despite the name this is not a Cloud Storage location: the data is sent inline.
     */
    private static final String STORAGE_DATA_LOCATION = "/language_id.txt";

    /** Separator between the output label and the feature text on each training data line. */
    private static final String TRAINING_LINE_SEPARATOR = ", ";

    /** Charset used to decode bundled resources. NOTE(review): assumes the files are UTF-8. */
    private static final String RESOURCE_CHARSET = "UTF-8";

    /** Project (number) that owns the trained model. */
    private static final String PROJECT_NAME = "858822147939";

    /** Randomized model id; presumably so repeated runs do not reuse an existing model id. */
    private static final String MODEL_ID = "modelId" + Math.random();

    /** Maximum number of times the training status is polled before giving up. */
    private static final int MAX_POLL_ATTEMPTS = 1000;

    /** Polling delay step: the n-th wait is n times this value, capped at MAX_POLL_DELAY_MS. */
    private static final long POLL_DELAY_STEP_MS = 5000L;

    /** Upper bound for a single polling delay, so completion is noticed promptly on long jobs. */
    private static final long MAX_POLL_DELAY_MS = 60000L;

    /** Directory to store user credentials. */
    private static final java.io.File DATA_STORE_DIR = new java.io.File(System.getProperty("user.home"),
            ".store/prediction_sample");

    /**
     * Global instance of the {@link DataStoreFactory}. The best practice is to make it a single
     * globally shared instance across your application.
     */
    private static FileDataStoreFactory dataStoreFactory;

    /** Global instance of the JSON factory. */
    private static final JsonFactory JSON_FACTORY = JacksonFactory.getDefaultInstance();

    /** Global instance of the HTTP transport. */
    private static HttpTransport httpTransport;

    /** Prediction API client, built in {@link #main}. */
    private static Prediction client;

    /**
     * Opens a classpath resource for reading as text.
     *
     * @param resourceName absolute classpath resource name, e.g. {@code "/client_secrets.json"}
     * @return a reader over the resource; the caller is responsible for closing it
     * @throws IOException if the resource is not on the classpath
     */
    private static InputStreamReader openResource(String resourceName) throws IOException {
        java.io.InputStream in = PredictionSample.class.getResourceAsStream(resourceName);
        if (in == null) {
            throw new IOException("Resource not found on classpath: " + resourceName);
        }
        return new InputStreamReader(in, RESOURCE_CHARSET);
    }

    /**
     * Authorizes the installed application to access user's protected data.
     *
     * <p>Client secrets are read from the classpath resource {@code /client_secrets.json}. The
     * resulting credential is persisted through {@link #dataStoreFactory}. If the secrets file still
     * holds the "Enter ..." placeholder values, this prints instructions and exits the JVM with
     * status 1.
     *
     * @return the credential authorized for the local user id {@code "user"}
     * @throws Exception if the secrets cannot be read or the OAuth 2 flow fails
     */
    private static Credential authorize() throws Exception {
        // load client secrets, always releasing the underlying stream
        GoogleClientSecrets clientSecrets;
        InputStreamReader secretsReader = openResource("/client_secrets.json");
        try {
            clientSecrets = GoogleClientSecrets.load(JSON_FACTORY, secretsReader);
        } finally {
            secretsReader.close();
        }

        if (clientSecrets.getDetails().getClientId().startsWith("Enter")
                || clientSecrets.getDetails().getClientSecret().startsWith("Enter")) {
            System.out.println(
                    "Overwrite the src/main/resources/client_secrets.json file with the client secrets file "
                            + "you downloaded from the Quickstart tool or manually enter your Client ID and Secret "
                            + "from https://code.google.com/apis/console/?api=prediction#project:858822147939 "
                            + "into src/main/resources/client_secrets.json");
            System.exit(1);
        }

        // Set up authorization code flow.
        // Ask for only the permissions you need. Asking for more permissions will
        // reduce the number of users who finish the process for giving you access
        // to their accounts. This sample embeds its training data in the request
        // instead of reading it from Cloud Storage, so only the Prediction scope is
        // requested; add the devstorage scopes back if you train from a storage location.
        Set<String> scopes = Collections.singleton(PredictionScopes.PREDICTION);

        GoogleAuthorizationCodeFlow flow = new GoogleAuthorizationCodeFlow.Builder(httpTransport, JSON_FACTORY,
                clientSecrets, scopes).setDataStoreFactory(dataStoreFactory).build();
        // authorize
        return new AuthorizationCodeInstalledApp(flow, new LocalServerReceiver()).authorize("user");
    }

    /**
     * Loads the training instances from {@link #STORAGE_DATA_LOCATION}.
     *
     * <p>Each non-blank line is split at the first {@code ", "}. The text before it becomes the
     * instance's output label, and the remainder becomes its single CSV feature.
     *
     * @return the parsed instances, in file order
     * @throws IOException if the resource is missing or unreadable, or a non-blank line lacks the
     *     separator
     */
    private static List<TrainingInstances> getTrainingData() throws IOException {
        List<TrainingInstances> instances = new ArrayList<TrainingInstances>();

        BufferedReader reader = new BufferedReader(openResource(STORAGE_DATA_LOCATION));
        try {
            int lineNumber = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().length() == 0) {
                    continue; // tolerate blank lines such as a trailing empty line
                }
                int partition = line.indexOf(TRAINING_LINE_SEPARATOR);
                if (partition < 0) {
                    throw new IOException("Malformed training data in " + STORAGE_DATA_LOCATION
                            + " at line " + lineNumber + ": expected \"<label>"
                            + TRAINING_LINE_SEPARATOR + "<text>\"");
                }
                String output = line.substring(0, partition);
                List<Object> features = new ArrayList<Object>();
                features.add(line.substring(partition + TRAINING_LINE_SEPARATOR.length()));

                instances.add(new TrainingInstances().setOutput(output).setCsvInstance(features));
            }
        } finally {
            reader.close();
        }

        return instances;
    }

    /**
     * Parses a {@code trainedmodels.get} JSON response body into an {@link Insert2}.
     *
     * <p>The created time, id, kind, self link and training status are always copied. The training
     * completion time and model info are read only when {@code trainingComplete} is present, and
     * model info only if that object exists.
     *
     * @param jsonString raw response body
     * @return the parsed model, or {@code null} if the body is not valid JSON or contains a
     *     malformed timestamp or number (the stack trace is printed)
     */
    private static Insert2 responseToObject(String jsonString) {
        try {
            JSONObject obj = (JSONObject) new JSONParser().parse(jsonString);

            Insert2 res = new Insert2();
            res.setCreated(parseTimestamp(obj.get("created")));
            res.setId((String) obj.get("id"));
            res.setKind((String) obj.get("kind"));
            res.setSelfLink((String) obj.get("selfLink"));
            res.setTrainingStatus((String) obj.get("trainingStatus"));

            if (obj.get("trainingComplete") != null) {
                res.setTrainingComplete(parseTimestamp(obj.get("trainingComplete")));
                JSONObject ml = (JSONObject) obj.get("modelInfo");
                if (ml != null) {
                    ModelInfo modelInfo = new ModelInfo();
                    modelInfo.setNumberInstances(toLong(ml.get("numberInstances")));
                    modelInfo.setModelType((String) ml.get("modelType"));
                    modelInfo.setNumberLabels(toLong(ml.get("numberLabels")));
                    modelInfo.setClassificationAccuracy(toStringOrNull(ml.get("classificationAccuracy")));
                    res.setModelInfo(modelInfo);
                }
            }
            return res;
        } catch (ParseException e) {
            e.printStackTrace();
            return null;
        } catch (NumberFormatException e) {
            // thrown by DateTime.parseRfc3339 and Long.valueOf on malformed values
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Parses an RFC 3339 timestamp, honouring its zone suffix. A SimpleDateFormat pattern with a
     * quoted 'Z' would instead interpret a UTC time in the JVM's default time zone.
     *
     * @param value JSON value, or {@code null}
     * @return the parsed time, or {@code null} when {@code value} is {@code null}
     */
    private static DateTime parseTimestamp(Object value) {
        return value == null ? null : DateTime.parseRfc3339(value.toString());
    }

    /**
     * Converts a JSON value to a {@link Long}. int64 values may arrive either as JSON strings or as
     * JSON numbers, so both forms are accepted.
     *
     * @param value JSON value, or {@code null}
     * @return the numeric value, or {@code null} when {@code value} is {@code null}
     * @throws NumberFormatException if a string value is not a valid long
     */
    private static Long toLong(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            return Long.valueOf(((Number) value).longValue());
        }
        return Long.valueOf(value.toString());
    }

    /**
     * Returns the string form of a JSON value without mapping {@code null} to {@code "null"}.
     *
     * @param value JSON value, or {@code null}
     * @return {@code value.toString()}, or {@code null} when {@code value} is {@code null}
     */
    private static String toStringOrNull(Object value) {
        return value == null ? null : value.toString();
    }

    /**
     * Asks the trained model {@link #MODEL_ID} to label {@code text} and prints the result.
     *
     * @param prediction authorized Prediction API client
     * @param text the single CSV feature to classify
     * @throws IOException if the predict request fails
     */
    private static void predict(Prediction prediction, String text) throws IOException {
        InputInput inputInput = new InputInput();
        inputInput.setCsvInstance(Collections.<Object>singletonList(text));
        Input input = new Input();
        input.setInput(inputInput);

        Output output = prediction.trainedmodels().predict(PROJECT_NAME, MODEL_ID, input).execute();

        System.out.println("Text: " + text);
        System.out.println("Predicted language: " + output.getOutputLabel());
    }

    /**
     * Inserts model {@link #MODEL_ID} with the bundled training data and blocks until training
     * finishes.
     *
     * <p>The model status is polled up to {@link #MAX_POLL_ATTEMPTS} times with a linearly
     * growing, capped delay. On {@code DONE}, this prints the model info and returns. If training
     * reports an {@code ERROR...} status, the response cannot be parsed, a poll fails, the thread
     * is interrupted, or the attempts run out, it prints an error and exits the JVM with status 1.
     *
     * @param prediction authorized Prediction API client
     * @throws IOException if the training data cannot be read or the insert request fails
     */
    private static void train(Prediction prediction) throws IOException {
        // Start training; the sample data is embedded in the request itself.
        List<TrainingInstances> instances = getTrainingData();
        Insert insert = new Insert().setTrainingInstances(instances);
        insert.setFactory(JSON_FACTORY);
        insert.setId(MODEL_ID);
        prediction.trainedmodels().insert(PROJECT_NAME, insert).execute();

        for (int attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt++) {
            try {
                HttpResponse httpResponse = prediction.trainedmodels().get(PROJECT_NAME, MODEL_ID)
                        .executeUnparsed();

                if (httpResponse.getStatusCode() == 200) {
                    Insert2 res = responseToObject(httpResponse.parseAsString());
                    if (res == null) {
                        System.err.println("ERROR: could not parse the training status response");
                        break;
                    }
                    String status = res.getTrainingStatus();
                    if ("DONE".equals(status)) {
                        System.out.println("training complete");
                        System.out.println(res.getModelInfo());
                        return;
                    }
                    if (status != null && status.startsWith("ERROR")) {
                        // a failed job will never reach DONE, so stop polling
                        System.err.println("ERROR: training failed with status: " + status);
                        break;
                    }
                } else {
                    httpResponse.ignore();
                }

                Thread.sleep(Math.min(POLL_DELAY_STEP_MS * (attempt + 1), MAX_POLL_DELAY_MS));
                System.out.print(".");
                System.out.flush();
            } catch (InterruptedException e) {
                // restore the interrupt flag for anyone further up the stack
                Thread.currentThread().interrupt();
                e.printStackTrace();
                break;
            } catch (IOException e) {
                e.printStackTrace();
                break;
            }
        }

        System.err.println("ERROR: training not complete");
        System.exit(1);
    }

    /**
     * Entry point: authorizes, trains the model, runs one prediction and exits the JVM with status
     * 0 on success or 1 on any failure.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int exitCode = 1;
        try {
            // initialize the transport
            httpTransport = GoogleNetHttpTransport.newTrustedTransport();

            // initialize the data store factory
            dataStoreFactory = new FileDataStoreFactory(DATA_STORE_DIR);

            // authorization
            Credential credential = authorize();

            // set up global Prediction instance
            client = new Prediction.Builder(httpTransport, JSON_FACTORY, credential)
                    .setApplicationName(APPLICATION_NAME).build();

            train(client);

            String sample = "This version of the simple language";
            predict(client, sample);

            exitCode = 0;
        } catch (IOException e) {
            System.err.println(e.getMessage());
        } catch (Throwable t) {
            t.printStackTrace();
        }
        // Exit explicitly. NOTE(review): presumably this keeps helper threads (e.g. the local OAuth
        // receiver) from holding the JVM open — confirm before removing.
        System.exit(exitCode);
    }
}