Java tutorial
/* Copyright IBM Corp. 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ibm.watson.app.qaclassifier.tools;

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.List;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.ibm.watson.app.qaclassifier.rest.model.ManagedAnswer;
import com.ibm.watson.app.common.services.nlclassifier.model.NLClassifierTrainingData;
import com.ibm.watson.app.qaclassifier.util.rest.MessageKey;

/**
 * This is a tool/utility class that reads in both an answers.csv file and a questions.csv file and generates
 * the classifier training json file and a json file for populating the answer store. The two are generated
 * together to make sure that they are in sync and that every class in the training file has an associated answer
 * in the answer store.
 *
 * The CSV formats expected are:
 *   questions.csv: QuestionText, LabelId
 *   answers.csv:   LabelId, CanonicalQuestion
 *
 * The LabelId values should match between the two files; any LabelIds in the questions.csv that do not appear in the
 * answers.csv will NOT be output in the training file. Any answers in the answers.csv that do not have an
 * AnswerValue (blank) will not be populated into the database and any questions with that LabelId will also be
 * excluded.
 *
 * This is meant to be run on the command line at development time; it's not a run time utility.
 *
 * @author Stephan J Roorda
 *
 */
public class GenerateTrainingAndPopulationData {

    // options for command line parameters (short name, long name)
    private static final String QUESTION_INPUT = "qin", QUESTION_INPUT_LONG = "questionInput";
    private static final String QUESTION_OUTPUT = "qout", QUESTION_OUTPUT_LONG = "questionOutput";
    private static final String ANSWER_INPUT = "ain", ANSWER_INPUT_LONG = "answerInput";
    private static final String ANSWER_TEXT_DIR = "adir", ANSWER_TEXT_DIR_LONG = "answerTextDirectory";
    private static final String ANSWER_OUTPUT = "aout", ANSWER_OUTPUT_LONG = "answerOutput";

    // Process exit status used when the input data cannot be loaded.
    private static final int EXIT_FAILURE = 1;

    // the input and output files that we need; populated by readCommandLineParameters
    static File questionInput = null;
    static File questionOutput = null;
    static File answerInput = null;
    static File answerTextDirectory = null;
    static File answerOutput = null;

    /**
     * Command line entry point: parses the arguments, loads the answers and the questions, then writes the
     * answer store population file and the classifier training file.
     *
     * The process exits with a non-zero status when either the answers or the questions cannot be loaded.
     *
     * @param args the command line arguments, see {@link #readCommandLineParameters(String[])}
     * @throws IOException if one of the output files cannot be written
     * @throws IllegalArgumentException if the command line arguments are missing or invalid
     */
    public static void main(String[] args) throws IOException {
        System.out.println(MessageKey.AQWQAC20007I_starting_generate_training_and_populating.getMessage()
                .getFormattedMessage());

        // handle reading the command line parameters and initializing the files
        readCommandLineParameters(args);
        System.out.println(MessageKey.AQWQAC20008I_cmd_line_param_read.getMessage().getFormattedMessage());

        // process the answers input file and create the in-memory store for it
        List<ManagedAnswer> answers =
                PopulateAnswerStore.loadAnswerStore(answerInput.getPath(), answerTextDirectory.getPath());
        if (answers == null || answers.isEmpty()) {
            System.err.println(
                    MessageKey.AQWQAC24010E_answer_store_unable_to_load.getMessage().getFormattedMessage());
            // this is an error path, so report failure to the calling shell
            System.exit(EXIT_FAILURE);
        }
        System.out.println(MessageKey.AQWQAC20004I_answer_input_file_read.getMessage().getFormattedMessage());

        // process the questions input file and create the in-memory store for it
        NLClassifierTrainingData training = readQuestionInput(answers);
        if (training == null || training.getTrainingData() == null || training.getTrainingData().size() == 0) {
            // NOTE(review): this reuses the "answer store unable to load" message for a question input
            // failure; looks like a copy/paste - confirm whether a dedicated message key exists.
            System.err.println(
                    MessageKey.AQWQAC24010E_answer_store_unable_to_load.getMessage().getFormattedMessage());
            System.exit(EXIT_FAILURE);
        }
        System.out.println(MessageKey.AQWQAC24005I_question_input_file_read.getMessage().getFormattedMessage());

        // write the answer store population file
        // create the gson object that is doing all the writing
        Gson gson = new GsonBuilder().setPrettyPrinting().create();
        writeGSON(gson.toJson(answers), answerOutput);
        System.out.println(MessageKey.AQWQAC24006I_answer_output_file_written.getMessage().getFormattedMessage());

        // write the classifier training file
        writeGSON(training.toJson(), questionOutput);
        System.out.println(MessageKey.AQWQAC24007I_training_data_file_written.getMessage().getFormattedMessage());
    }

    /**
     * Reads in the question input file and adds a training instance for each question it finds. If the label
     * associated with the question does not exist in the previously read in answer store then it is skipped
     * and a message is printed.
     *
     * Any exception raised while reading is printed and not rethrown, so the result may be {@code null} (the
     * file could not be read at all) or contain only the records processed before the failure.
     *
     * @param store the answers that were previously loaded; only labels present here are kept
     * @return the training data, or {@code null} if the question file could not be read
     */
    private static NLClassifierTrainingData readQuestionInput(List<ManagedAnswer> store) {
        NLClassifierTrainingData data = null;

        try (FileReader reader = new FileReader(questionInput);
                CSVParser parser = new CSVParser(reader, CSVFormat.EXCEL)) {
            // read in the csv file and get the records
            List<CSVRecord> records = parser.getRecords();

            // now we can create the training data because we have read the records
            data = new NLClassifierTrainingData();
            data.setLanguage("en");
            for (CSVRecord r : records) {
                // order is: QuestionText, LabelId
                // check for existence of label first, if not there, skip
                // we only add the training instance if there is an associated answer
                String text = r.get(0);
                String label = r.get(1);
                if (labelHasAnswer(label, store)) {
                    data.addTrainingData(text, label);
                } else {
                    System.out.println(MessageKey.AQWQAC24009E_label_not_found_in_answer_store_including_2
                            .getMessage(text, label).getFormattedMessage());
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        return data;
    }

    /**
     * Checks whether any of the given answers has a class name equal to the given label.
     *
     * @param label the label to look for
     * @param answers the answers to search
     * @return {@code true} if an answer with that class name exists, {@code false} otherwise
     */
    private static boolean labelHasAnswer(String label, List<ManagedAnswer> answers) {
        for (ManagedAnswer a : answers) {
            // compare from the label side so an answer without a class name cannot cause a NullPointerException
            if (label != null && label.equals(a.getClassName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Writes the given JSON text to the given file.
     *
     * @param src the text to write
     * @param output the file to write to
     * @throws IOException if the file cannot be opened or written
     */
    private static void writeGSON(String src, File output) throws IOException {
        // try-with-resources closes (and thereby flushes) the writer, and copes with a failed open
        try (FileWriter writer = new FileWriter(output)) {
            writer.write(src);
        }
    }

    /**
     * This method handles all of the parsing and validation of the required input parameters. If the arguments
     * cannot be parsed, any of the input files do not exist, or the parent directory of an output file cannot
     * be created, then the method will fail. If we return from this method then we have all the files properly.
     *
     * @param args the command line arguments
     * @throws IllegalArgumentException if the arguments cannot be parsed or do not pass validation
     */
    private static void readCommandLineParameters(String[] args) throws IllegalArgumentException {
        Option questionInputOption = createOption(QUESTION_INPUT, QUESTION_INPUT_LONG, true,
                "input csv file containing questions and labels", true, QUESTION_INPUT_LONG);
        Option questionOutputOption = createOption(QUESTION_OUTPUT, QUESTION_OUTPUT_LONG, true,
                "filename and location for the classifier training data", true, QUESTION_OUTPUT_LONG);
        Option answerInputOption = createOption(ANSWER_INPUT, ANSWER_INPUT_LONG, true,
                "input csv file containing answers data", true, ANSWER_INPUT_LONG);
        Option answerDirectoryOption = createOption(ANSWER_TEXT_DIR, ANSWER_TEXT_DIR_LONG, true,
                "directory containing answer html files", true, ANSWER_TEXT_DIR_LONG);
        Option answerOutputOption = createOption(ANSWER_OUTPUT, ANSWER_OUTPUT_LONG, true,
                "filename and location for the answer store population data", true, ANSWER_OUTPUT_LONG);

        final Options options = buildOptions(questionInputOption, questionOutputOption, answerInputOption,
                answerDirectoryOption, answerOutputOption);

        CommandLine cmd;
        try {
            CommandLineParser parser = new GnuParser();
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            String message = MessageKey.AQWQAC24008E_could_not_parse_cmd_line_args_1.getMessage(e.getMessage())
                    .getFormattedMessage();
            System.err.println(message);

            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(120, "java " + GenerateTrainingAndPopulationData.class.getName(), null, options,
                    null);
            // returning here would leave every file field null, so fail explicitly instead
            throw new IllegalArgumentException(message, e);
        }

        // before we do anything else make sure we can read and write all of the necessary files
        final String questionInputFile = cmd.getOptionValue(QUESTION_INPUT).trim();
        final String questionOutputFile = cmd.getOptionValue(QUESTION_OUTPUT).trim();
        final String answerInputFile = cmd.getOptionValue(ANSWER_INPUT).trim();
        final String answerDirectoryFile = cmd.getOptionValue(ANSWER_TEXT_DIR).trim();
        final String answerOutputFile = cmd.getOptionValue(ANSWER_OUTPUT).trim();

        // make sure we have all 5 parameters
        // NOTE(review): the message key says "4 files" although five values are checked - confirm the text
        if (questionInputFile.isEmpty() || questionOutputFile.isEmpty() || answerInputFile.isEmpty()
                || answerDirectoryFile.isEmpty() || answerOutputFile.isEmpty()) {
            throw new IllegalArgumentException(
                    MessageKey.AQWQAC14200E_must_specify_4_files.getMessage().getFormattedMessage());
        }

        // make sure the question input file exists
        questionInput = new File(questionInputFile);
        if (!questionInput.exists()) {
            throw new IllegalArgumentException(MessageKey.AQWQAC14201E_file_does_not_exist_1
                    .getMessage(questionInput.getAbsolutePath()).getFormattedMessage());
        }

        // make sure the answer input file exists
        answerInput = new File(answerInputFile);
        if (!answerInput.exists()) {
            throw new IllegalArgumentException(MessageKey.AQWQAC14201E_file_does_not_exist_1
                    .getMessage(answerInput.getAbsolutePath()).getFormattedMessage());
        }

        // make sure the answer text directory exists
        answerTextDirectory = new File(answerDirectoryFile);
        if (!answerTextDirectory.exists()) {
            throw new IllegalArgumentException(MessageKey.AQWQAC14201E_file_does_not_exist_1
                    .getMessage(answerTextDirectory.getAbsolutePath()).getFormattedMessage());
        }

        // make sure we can create the question output file
        questionOutput = new File(questionOutputFile);
        if ((null != questionOutput.getParentFile()) && !questionOutput.getParentFile().exists()
                && !questionOutput.getParentFile().mkdirs()) {
            throw new IllegalArgumentException(MessageKey.AQWQAC14202E_unable_create_parent_dir_for_file_1
                    .getMessage(questionOutput.getAbsolutePath()).getFormattedMessage());
        }

        // make sure we can create the answer output file
        answerOutput = new File(answerOutputFile);
        if ((null != answerOutput.getParentFile()) && !answerOutput.getParentFile().exists()
                && !answerOutput.getParentFile().mkdirs()) {
            throw new IllegalArgumentException(MessageKey.AQWQAC14202E_unable_create_parent_dir_for_file_1
                    .getMessage(answerOutput.getAbsolutePath()).getFormattedMessage());
        }
    }

    /**
     * Creates a command line option and sets its required flag and argument name.
     *
     * @param opt the short name of the option
     * @param longOpt the long name of the option
     * @param hasArg whether the option takes an argument
     * @param description the description of the option
     * @param required whether the option is required
     * @param argName the display name of the argument
     * @return the configured option
     */
    private static Option createOption(String opt, String longOpt, boolean hasArg, String description,
            boolean required, String argName) {
        Option option = new Option(opt, longOpt, hasArg, description);
        option.setRequired(required);
        option.setArgName(argName);
        return option;
    }

    /**
     * Collects the given options into a single {@link Options} instance, in the order given.
     *
     * @param option the first option
     * @param additionalOptions any further options
     * @return the options container holding all given options
     */
    private static Options buildOptions(Option option, Option... additionalOptions) {
        final Options options = new Options();
        options.addOption(option);
        for (Option o : additionalOptions) {
            options.addOption(o);
        }
        return options;
    }
}