com.ibm.watson.developer_cloud.professor_languo.pipeline.RnrMergerAndRanker.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.watson.developer_cloud.professor_languo.pipeline.RnrMergerAndRanker.java

Source

/*
 * Copyright IBM Corp. 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package com.ibm.watson.developer_cloud.professor_languo.pipeline;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.Properties;
import java.util.Set;

import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.Credentials;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpDelete;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import rx.Observable;

import org.apache.wink.json4j.JSON;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

import com.ibm.watson.developer_cloud.professor_languo.api.AnswerMergerAndRanker;
import com.ibm.watson.developer_cloud.professor_languo.configuration.Messages;
import com.ibm.watson.developer_cloud.professor_languo.configuration.RetrieveAndRankConstants;
import com.ibm.watson.developer_cloud.professor_languo.data_model.CandidateAnswer;
import com.ibm.watson.developer_cloud.professor_languo.data_model.Question;
import com.ibm.watson.developer_cloud.professor_languo.data_model.QuestionAnswerSet.CorrectAnswer;
import com.ibm.watson.developer_cloud.professor_languo.exception.PipelineException;
import com.ibm.watson.developer_cloud.professor_languo.ingestion.RankerCreationUtil;
import com.ibm.watson.developer_cloud.professor_languo.pipeline.primary_search.RetrieveAndRankSearcherConstants;

/**
 * An {@link AnswerMergerAndRanker} backed by a ranker from the Retrieve and Rank service.
 *
 * <p>
 * During the training phase the feature vectors of the candidate answers are accumulated in an
 * in-memory CSV buffer. {@link #finishTraining()} sends that buffer to the service to train a
 * ranker and also writes it to the file located at {@code trainingFilePath}. The CSV is formatted
 * as below:
 *
 * <pre>
 * question_id,{list of feature values separated by commas},ground_truth
 * </pre>
 *
 * The header of this file must be consistent during training and testing. This class is also
 * designed to work with the qa-framework.
 *
 * <p>
 * NOTE(review): {@code ranker_url} is a static field that is assigned by the instance method
 * {@link #initialize(Properties)}; the static {@link #deleteRanker(CloseableHttpClient, String)}
 * therefore only has a usable URL after some instance has been initialized.
 */

public class RnrMergerAndRanker implements AnswerMergerAndRanker {
    private final static Logger logger = LogManager.getLogger(RnrMergerAndRanker.class.getCanonicalName());

    private static final AtomicInteger QIDGenerator = new AtomicInteger(0);

    /** Delay between two status polls while the ranker is training, in milliseconds. */
    private static final long TRAINING_POLL_INTERVAL_MS = 30000L;

    /** Delay before a rank request is retried, in milliseconds. */
    private static final long QUERY_RETRY_DELAY_MS = 3000L;

    private static final String STATUS_AVAILABLE = "Available";
    private static final String STATUS_FAILED = "Failed";

    public String trainingFilePath;
    public String base_url;
    private String cluster, collection, rankerName;
    private Credentials creds;
    private int rows;
    private static String ranker_url;
    private String current_ranker_id;
    private boolean addHeader;

    private int count;
    private int goodRecallCount;
    private int retryLimit;
    private StringBuffer trainingData;

    /**
     * Initialize the AnswerMergerAndRanker with the {@link Properties} object obtained from the
     * properties file. Any previously accumulated training data and recall counters are reset.
     *
     * @param properties configuration read from the properties file
     * @throws NumberFormatException if the candidate-answer number or the retry limit is missing or
     *         not an integer
     */
    @Override
    public void initialize(Properties properties) {
        creds = new UsernamePasswordCredentials(properties.getProperty(RetrieveAndRankConstants.USERNAME),
                properties.getProperty(RetrieveAndRankConstants.PASSWORD));
        cluster = properties.getProperty(RetrieveAndRankConstants.SOLR_CLUSTER_ID);
        collection = properties.getProperty(RetrieveAndRankConstants.COLLECTION);
        rankerName = properties.getProperty(RetrieveAndRankConstants.RANKER_NAME);
        rows = Integer.parseInt(properties.getProperty(RetrieveAndRankConstants.CANDIDATE_ANSWER_NUM));
        base_url = properties.getProperty(RetrieveAndRankConstants.RNR_ENDPOINT);
        trainingFilePath = properties.getProperty(RetrieveAndRankConstants.TRAINING_DATA_PATH);
        retryLimit = Integer.parseInt(properties.getProperty(RetrieveAndRankConstants.QUERY_RETRY_LIMIT));

        // Reset all per-training state up front so that it is also reset when the
        // method returns early below.
        addHeader = true;
        trainingData = new StringBuffer();
        count = 0;
        goodRecallCount = 0;

        ranker_url = base_url + RetrieveAndRankConstants.RNR_ENDPOINT_VERSION
                + RetrieveAndRankSearcherConstants.RANKER_REQUEST_HANDLER;

        if (cluster == null || collection == null) {
            logger.info(Messages.getString("RetrieveAndRank.MISSING_PROPERTY")); //$NON-NLS-1$
            return;
        }

        logger.info(Messages.getString("RetrieveAndRank.SOLAR_CLUSTER_ID") + cluster); //$NON-NLS-1$
        logger.info(Messages.getString("RetrieveAndRank.SOLAR_COLLECTION_ID") + collection); //$NON-NLS-1$
        logger.info(Messages.getString("RetrieveAndRank.RANKER_NAME") + rankerName); //$NON-NLS-1$
        logger.info(Messages.getString("RetrieveAndRank.SOLAR_ROWS") + rows); //$NON-NLS-1$
    }

    /**
     * If {@code correctAnswers != null}, this method adds the feature vectors of the
     * CandidateAnswers to the training data and returns {@code answers} unchanged. Else, it will
     * apply the ranker created in the current invocation of this class.
     *
     * @param question {@link Question} from which {@code answers} are generated
     * @param answers List of {@link CandidateAnswer}s
     * @param correctAnswers List of {@link CorrectAnswer}s to the {@code question}. There should only
     *        be one correct answer for each question.
     * @return {@code answers} in the training phase, the answers selected by the ranker otherwise
     */
    @Override
    public Observable<CandidateAnswer> mergeAndRankAnswers(Question question, Observable<CandidateAnswer> answers,
            Collection<CorrectAnswer> correctAnswers) {

        // No HTTP client is created here: training only fills the in-memory buffer
        // and apply() manages its own client.
        if (correctAnswers != null) { // TRAINING PHASE
            train(question, answers, correctAnswers);
            return answers;
        }
        // TESTING PHASE
        return apply(question, answers);
    }

    /**
     * Runs at the end of the training phase, invoked by a QuestionAnswerer. Sends the accumulated
     * training data to the service, writes it to {@code trainingFilePath}, and then polls the ranker
     * status until it is "Available".
     *
     * <p>
     * I/O, JSON and null-related failures are logged and end the method without an exception.
     *
     * @throws RuntimeException if the ranker reports the status "Failed" or the waiting thread is
     *         interrupted
     */
    @Override
    public void finishTraining() {
        long startTime = System.currentTimeMillis();

        try {
            // Attempt to train the ranker with the training data
            logger.info(Messages.getString("RetrieveAndRank.RANKER_ATTEPT_CREATE")); //$NON-NLS-1$
            String result;
            // NOTE(review): trainRanker may already close the client it is given (deleteRanker
            // in this class does); closing it again here is expected to be harmless - confirm.
            try (CloseableHttpClient trainingClient = RankerCreationUtil.createHttpClient(AuthScope.ANY, creds)) {
                result = RankerCreationUtil.trainRanker(ranker_url, rankerName, trainingClient, trainingData);
            }

            logger.info(Messages.getString("RetrieveAndRank.RANKER_WRITE_TO_DISK_START"));
            try (BufferedWriter bw = new BufferedWriter(new FileWriter(trainingFilePath))) {
                bw.write(trainingData.toString());
            }
            logger.info(Messages.getString("RetrieveAndRank.RANKER_WRITE_TO_DISK_END"));

            // Obtain ranker_id
            JSONObject jsonResult = (JSONObject) JSON.parse(result);
            String ranker_id = (String) jsonResult.get("ranker_id");
            current_ranker_id = ranker_id;

            // Ping the ranker for its status
            String status = getRankerStatus(ranker_id);
            logger.info(Messages.getString("RetrieveAndRank.RANKER_TRAINING_KICKOFF") + trainingFilePath); //$NON-NLS-1$

            // waiting for training to complete
            while (!STATUS_AVAILABLE.equals(status)) {
                if (status == null) {
                    // getRankerStatus could not read a status; give up instead of
                    // dereferencing null. Handled (logged) by the catch below.
                    throw new IOException("Unable to determine the status of ranker " + ranker_id);
                }
                if (STATUS_FAILED.equals(status)) {
                    throw new RuntimeException(Messages.getString("RetrieveAndRank.RANKER_TRAINING_FAIL"));
                }
                logger.info(Messages.getString("RetrieveAndRank.RANKER_STATUS") + status); //$NON-NLS-1$
                try {
                    Thread.sleep(TRAINING_POLL_INTERVAL_MS);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag for code further up the stack
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(Messages.getString("RetrieveAndRank.RANKER_TRAINING_INTERRUPTED"),
                            e);
                }
                status = getRankerStatus(ranker_id);
            }
            long timeTaken = System.currentTimeMillis() - startTime;
            logger.info(MessageFormat.format(Messages.getString("RetrieveAndRank.RANKER_TRAINING_FINISH"), //$NON-NLS-1$
                    ranker_id, DurationFormatUtils.formatDurationWords(timeTaken, true, false)));
        } catch (IOException | NullPointerException | JSONException e) {
            logger.error(e.getMessage(), e);
        }
        logger.info("Ranker ID: " + current_ranker_id);
    }

    /**
     * Adds the feature vectors of the candidate answers of one question to the training data that
     * is later sent to the ranker. Nothing is added when none of the candidates is a correct
     * answer. Every question passed to this method counts towards the recall statistics.
     *
     * @param question the question the candidates belong to
     * @param answers candidate answers of {@code question}
     * @param correctAnswers correct answers of {@code question}
     */
    private void train(Question question, Observable<CandidateAnswer> answers,
            Collection<CorrectAnswer> correctAnswers) {

        List<CandidateAnswer> candidate_answers = answers.toList().toBlocking().first();

        // Count the question before anything is logged so that the recall log line
        // reports the number of questions seen so far, including this one.
        count++;

        // There may be no candidate answers if the query is completely
        // unrelated to the corpus
        if (candidate_answers.isEmpty()) {
            logger.info(Messages.getString("RetrieveAndRank.RANKER_NO_CANDIDATE_ANSWER")); //$NON-NLS-1$
            return;
        }

        if (!containsCorrectAnswer(candidate_answers, correctAnswers)) {
            return;
        }

        // Get list of features
        CandidateAnswer sampleAnswer = candidate_answers.get(0);
        Object[] objArray = sampleAnswer.getFeatures().toArray();
        String[] features = Arrays.copyOf(objArray, objArray.length, String[].class);

        if (addHeader) {
            List<String> headers = new ArrayList<String>();
            headers.add(RetrieveAndRankConstants.QUESTION_ID_HEADER);
            for (Entry<String, Double> entry : sampleAnswer.getFeatureValuePairs()) {
                headers.add(entry.getKey());
            }
            writeHeader(headers, true);
        }

        writeAnswers(candidate_answers, correctAnswers, question, features);

        addHeader = false;
        goodRecallCount++;
        if (goodRecallCount % 10 == 0) {
            logger.info(MessageFormat.format(Messages.getString("RetrieveAndRank.RANKER_RECALL_NUMBER"), //$NON-NLS-1$
                    goodRecallCount, count));
        }
    }

    /**
     * Checks whether at least one of the candidate answers is a correct answer.
     *
     * @param candidates candidate answers to check
     * @param correctAnswers correct answers to compare against
     * @return {@code true} if any candidate is correct
     */
    private boolean containsCorrectAnswer(List<CandidateAnswer> candidates,
            Collection<CorrectAnswer> correctAnswers) {
        for (CandidateAnswer candidate : candidates) {
            if (CorrectAnswer.isCorrect(candidate, correctAnswers)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rank {@code answers} to {@code question}
     *
     * <p>
     * The answers are collected from the Observable exactly once. On any failure the collected
     * answers are returned unranked; after an {@link IOException} other than a
     * {@link ClientProtocolException} their confidence is set to 0 first.
     *
     * @param question {@link Question} from which {@code answers} are formed
     * @param answers list of {@link CandidateAnswer}'s
     * @return the answers returned by the ranker with their confidence set, in the order of
     *         {@code answers}
     */
    private Observable<CandidateAnswer> apply(Question question, Observable<CandidateAnswer> answers) {

        // Materialize once: subscribing repeatedly may re-run the upstream and may
        // produce different instances than the ones that get returned.
        final List<CandidateAnswer> answerList;
        try {
            answerList = answers.toList().toBlocking().first();
        } catch (RuntimeException e) {
            logger.error(e.getMessage(), e);
            return answers;
        }

        // Create authorized HttpClient; closed automatically when ranking is done
        try (CloseableHttpClient rankClient = RankerCreationUtil.createHttpClient(AuthScope.ANY, creds)) {

            // Build feature vector data for candidate answers in csv format
            String csvAnswerData = RankerCreationUtil.getCsvAnswerData(answerList, null);

            // Send rank request
            String rank_request_url = ranker_url + "/" + current_ranker_id
                    + RetrieveAndRankSearcherConstants.RANK_REQUEST_HANDLER;
            JSONObject responseJSON = RankerCreationUtil.rankAnswers(rankClient, rank_request_url, csvAnswerData);
            JSONArray rankedAnswerArray = (JSONArray) responseJSON.get("answers");

            // If there is an error with the service, wait a moment
            // and retry up to retry limit
            int retryAttempt = 1;
            while (rankedAnswerArray == null) {
                if (retryAttempt > retryLimit) {
                    throw new PipelineException(MessageFormat
                            .format(Messages.getString("RetrieveAndRank.QUERY_RETRY_FAILED"), retryAttempt)); //$NON-NLS-1$
                }
                Thread.sleep(QUERY_RETRY_DELAY_MS);
                logger.info(MessageFormat.format(Messages.getString("RetrieveAndRank.QUERY_RETRY"), retryAttempt)); //$NON-NLS-1$
                responseJSON = RankerCreationUtil.rankAnswers(rankClient, rank_request_url, csvAnswerData);
                rankedAnswerArray = (JSONArray) responseJSON.get("answers");
                retryAttempt++;
            }

            // Set confidence to the top answers chosen by the ranker,
            // ignore the rest
            List<CandidateAnswer> rankedAnswerList = new ArrayList<CandidateAnswer>();

            for (CandidateAnswer candidate : answerList) {
                for (int j = 0; j < rankedAnswerArray.size(); j++) {
                    JSONObject ans = (JSONObject) rankedAnswerArray.get(j);
                    // Get the answer_id
                    String answer_id = (String) ans.get(RetrieveAndRankConstants.ANSWER_ID_HEADER);
                    if (candidate.getAnswerLabel().equals(answer_id)) {
                        // Go through Number: a confidence without a fractional part may
                        // not be parsed as a Double.
                        double confidence = ((Number) ans.get(RetrieveAndRankConstants.CONFIDENCE_HEADER))
                                .doubleValue();
                        candidate.setConfidence(confidence);
                        rankedAnswerList.add(candidate);
                    }
                }
            }

            return Observable.from(rankedAnswerList);

        } catch (ClientProtocolException e) {
            logger.error(e.getMessage(), e);
        } catch (IOException e) {
            logger.error(e.getMessage(), e);
            // Something wrong with the service. Set all confidence to 0
            for (CandidateAnswer answer : answerList) {
                answer.setConfidence(0);
            }
        } catch (InterruptedException e) {
            // Keep the interrupt visible to the caller instead of swallowing it
            Thread.currentThread().interrupt();
            logger.error(e.getMessage(), e);
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
        }
        // Return the instances collected above so that changes made to them are kept
        return Observable.from(answerList);
    }

    /**
     * Gets the rankers associated with the given credentials
     *
     * @return a list of ranker_ids read from the service response; empty if the response could not
     *         be parsed
     * @throws ClientProtocolException
     * @throws IOException
     */
    public List<String> getRankers() throws ClientProtocolException, IOException {

        List<String> rankerIds = new ArrayList<String>();

        // Create authorized HttpClient; both resources are closed when the block exits
        try (CloseableHttpClient httpClient = RankerCreationUtil.createHttpClient(AuthScope.ANY, creds);
                CloseableHttpResponse response = httpClient.execute(new HttpGet(ranker_url))) {
            try {
                String result = EntityUtils.toString(response.getEntity());
                JSONObject jobject = (JSONObject) JSON.parse(result);
                JSONArray rankers = jobject.getJSONArray("rankers");

                for (int i = 0; i < rankers.size(); i++) {
                    rankerIds.add((String) ((JSONObject) rankers.get(i)).get("ranker_id"));
                }
            } catch (NullPointerException | JSONException e) {
                logger.error(e.getMessage(), e);
            }
        }

        return rankerIds;
    }

    /**
     * Returns the status of the specified ranker
     *
     * @param ranker_id of the ranker
     * @return status. e.g. "Available", "Training", "Failed"; {@code null} if no status could be
     *         read from the service response
     * @throws IOException
     */
    public String getRankerStatus(String ranker_id) throws IOException {

        String status = null;

        // Create authorized HttpClient; both resources are closed when the block exits
        try (CloseableHttpClient httpClient = RankerCreationUtil.createHttpClient(AuthScope.ANY, creds);
                CloseableHttpResponse response = httpClient.execute(new HttpGet(ranker_url + "/" + ranker_id))) {
            try {
                String result = EntityUtils.toString(response.getEntity());
                JSONObject res = (JSONObject) JSON.parse(result);
                status = (String) res.get("status");
            } catch (NullPointerException | JSONException e) {
                logger.error(e.getMessage(), e);
            }
        }

        return status;
    }

    /**
     * Deletes the specified ranker. An empty JSON response is logged as a successful deletion, a
     * non-empty one as a failure.
     *
     * <p>
     * Note that {@code client} is always closed by this method and cannot be reused afterwards.
     *
     * @param client authorized HTTP client used for the request; closed before returning
     * @param ranker_id of the ranker to be deleted
     * @throws ClientProtocolException
     * @throws IOException
     */
    public static void deleteRanker(CloseableHttpClient client, String ranker_id)
            throws ClientProtocolException, IOException {

        try {
            HttpDelete httpdelete = new HttpDelete(ranker_url + "/" + ranker_id);
            httpdelete.setHeader("Content-Type", "application/json");

            try (CloseableHttpResponse response = client.execute(httpdelete)) {
                try {
                    String result = EntityUtils.toString(response.getEntity());
                    JSONObject res = (JSONObject) JSON.parse(result);
                    if (res.isEmpty()) {
                        logger.info(MessageFormat.format(Messages.getString("RetrieveAndRank.RANKER_DELETE"), //$NON-NLS-1$
                                ranker_id));
                    } else {
                        logger.info(MessageFormat.format(Messages.getString("RetrieveAndRank.RANKER_DELETE_FAIL"), //$NON-NLS-1$
                                ranker_id));
                    }
                } catch (NullPointerException | JSONException e) {
                    logger.error(e.getMessage(), e);
                }
            }
        } finally {
            client.close();
        }
    }

    /**
     * This method writes the header required by RaaS training scripts, Retrieve & Rank service, etc.
     * to the training data.
     *
     * @param headers column names, written in the given order and separated by commas
     * @param addGroundTruth whether a trailing {@code ground_truth} column is added
     */
    private void writeHeader(List<String> headers, boolean addGroundTruth) {
        StringBuilder line = new StringBuilder();
        for (String header : headers) {
            line.append(header).append(',');
        }
        if (addGroundTruth) {
            line.append("ground_truth");
        } else if (line.length() > 0) {
            // Drop the trailing comma; only this line is touched, never earlier data
            line.setLength(line.length() - 1);
        }
        line.append('\n');
        trainingData.append(line);
    }

    /**
     * This method writes one row per candidate answer to the training data. A row consists of the
     * question id, the value of every feature ({@code 0.0} where the answer has no value for it) and
     * the ground truth ({@code 1} for a correct answer, {@code 0} otherwise). All rows of one call
     * share a single, newly generated question id.
     *
     * @param answers candidate answers of one question
     * @param correctAnswers correct answers used to determine the ground truth
     * @param question the question the answers belong to
     * @param features names of the features to write, in column order
     */
    private void writeAnswers(List<CandidateAnswer> answers, Collection<CorrectAnswer> correctAnswers,
            Question question, String[] features) {

        int qid = getQID();
        // Rows are assembled locally and appended with a single call so that the
        // rows of one question stay together in the shared buffer.
        StringBuilder questionRows = new StringBuilder();
        for (CandidateAnswer answer : answers) {

            questionRows.append(qid);

            for (String feature : features) {
                Double featureValue = answer.getFeatureValue(feature);
                if (featureValue == null) {
                    featureValue = 0d;
                }
                questionRows.append(',').append(featureValue);
            }
            questionRows.append(',').append(CorrectAnswer.isCorrect(answer, correctAnswers) ? 1 : 0);
            questionRows.append('\n');
        }
        trainingData.append(questionRows);
    }

    /**
     * Helper method that assigns QIDs to each thread (per question) in serial order (as required by
     * RaaS). The ids come from a generator shared by all instances of this class.
     *
     * @return the QID to be used for feature-vectors of the candidate answers of the current question
     */
    public int getQID() {
        // The atomic increment needs no additional locking
        return QIDGenerator.incrementAndGet();
    }

}