structuredPredictionNLG.DatasetParser.java Source code

Introduction

Here is the source code for structuredPredictionNLG.DatasetParser.java

Source

/* 
 * Copyright (C) 2016 Gerasimos Lampouras
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package structuredPredictionNLG;

import gnu.trove.map.hash.TObjectDoubleHashMap;
import imitationLearning.JLOLS;
import imitationLearning.LossFunction;
import jarow.Instance;
import jarow.JAROW;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import simpleLM.SimpleLM;

/**
 * An abstract specification of a dataset parser. Concrete subclasses parse a specific dataset,
 * construct the training data, and drive the imitation learning engine.
 *
 * @author Gerasimos Lampouras
 * @organization University of Sheffield
 */
public abstract class DatasetParser {

    static final int SEED = 13;
    static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();

    // A random number generator, seeded with a fixed value (SEED) for reproducibility
    private static Random randomGen;

    public static void resetRandomGen() {
        randomGen = new Random(SEED);
    }

    public static Random getRandomGen() {
        return randomGen;
    }

    // The identifier of the dataset
    private String dataset;
    // Training, validation, and testing subsets of the dataset    
    private ArrayList<DatasetInstance> trainingData;
    private ArrayList<DatasetInstance> validationData;
    private ArrayList<DatasetInstance> testingData;
    // Content and word actions, as observed from the dataset   
    private HashMap<String, HashSet<String>> availableContentActions;
    private HashMap<String, HashMap<String, HashSet<Action>>> availableWordActions;
    // A flag determining whether stored caches should be reset
    private boolean resetStoredCaches;
    // A flag determining whether detailed results per predicate should be calculated
    private boolean calculateResultsPerPredicate;
    // A flag determining whether to perform evaluation on the training, validation, or testing data
    private String performEvaluationOn;
    // A flag determining whether to use random or naive alignments
    private String useAlignments;
    // Maximum lengths for content and word sequences, as observed in the training data
    private int maxContentSequenceLength;
    private int maxWordSequenceLength;
    // AROW parameters
    private boolean averaging;
    private boolean shuffling;
    private int rounds;
    private Double initialTrainingParam;
    private Double additionalTrainingParam;
    private boolean adapt;
    /*
     * The sequences in which the attribute/value pairs of the meaning representations were mentioned in the references.
     * Primarily used to determine the order in which values are picked during generation.
     */
    private ArrayList<ArrayList<String>> observedAttrValueSequences;

    // Lists of all observed predicates, attributes, attribute/value pairs, and overall instances in the dataset
    private ArrayList<String> predicates = new ArrayList<>();
    private HashMap<String, HashSet<String>> attributes = new HashMap<>();
    private HashMap<String, HashSet<String>> attributeValuePairs = new HashMap<>();
    private HashMap<String, ArrayList<DatasetInstance>> datasetInstances = new HashMap<>();

    // Map of alignments between attribute values and subphrases of the references; naively calculated (e.g. through edit distance)
    private HashMap<String, HashMap<ArrayList<String>, Double>> valueAlignments = new HashMap<>();
    // Patterns describing the observed context of punctuation in the dataset references
    private HashMap<String, HashMap<ArrayList<Action>, Action>> punctuationPatterns = new HashMap<>();

    // The generated training data for content and word actions
    private HashMap<String, ArrayList<Instance>> predicateContentTrainingData;
    private HashMap<String, HashMap<String, ArrayList<Instance>>> predicateWordTrainingData;

    // Language models estimated on the word and content actions; one for each predicate
    private HashMap<String, SimpleLM> contentLMsPerPredicate = new HashMap<>();
    private HashMap<String, SimpleLM> wordLMsPerPredicate = new HashMap<>();

    // Map of composite words observed in the data
    private HashMap<String, String> compositeWordsInData = new HashMap<>();

    /**
     * Main constructor
     * @param args Console arguments
     */
    public DatasetParser(String[] args) {
        resetRandomGen();

        Option helpOption = new Option("h", "help", false, "shows this message");
        Option datasetPathOption = new Option("d", "dataset", true, "dataset identifier");
        Option pOption = new Option("p", "p", true, "decaying factor for each epoch of LOLS: {0.0 ... 1.0}");
        Option sentenceCorrectionStepsOption = new Option("s", "scsteps", true,
                "number of steps to take after sentence correction");
        Option lossFunctionOption = new Option("l", "loss", true, "loss function to use: {B, R, BR, BC, RC, BRC}");
        Option resetStoredCachesOption = new Option("rst", "reset", false, "whether to reset any stored caches");
        Option testingOption = new Option("e", "evaluate", true,
                "whether to evaluate on training, validation or testing data: {train, valid, test}");
        Option alignmentOption = new Option("a", "alignment", true,
                "whether to employ random or naive alignments: {random, naive}");
        Option detailsOption = new Option("perPred", "resultsPerPredicate", false,
                "whether to also perform evaluation on each predicate separately");
        Option averagingOption = new Option("avg", "averaging", false,
                "AROW parameter: whether to average over all intermediate weight vectors, instead of using the final weight vectors");
        Option shufflingOption = new Option("sh", "shuffling", false,
                "AROW parameter: whether to shuffle the training instances");
        Option roundsOption = new Option("r", "rounds", true,
                "AROW parameter: for how many rounds to repeat the training");
        Option initialTrainingParamOption = new Option("init", "initialTrainingParam", true,
                "AROW parameter: training parameter for the initial policies");
        Option additionalTrainingParamOption = new Option("add", "additionalTrainingParam", true,
                "AROW parameter: training parameter for the additional policies");
        Option adaptOption = new Option("ad", "adapt", false,
                "AROW parameter: whether to use the AROW algorithm or passive-aggressive II with prediction-based updates");

        Options options = new Options();
        options.addOption(helpOption);
        options.addOption(datasetPathOption);
        options.addOption(pOption);
        options.addOption(sentenceCorrectionStepsOption);
        options.addOption(lossFunctionOption);
        options.addOption(resetStoredCachesOption);
        options.addOption(testingOption);
        options.addOption(alignmentOption);
        options.addOption(detailsOption);
        options.addOption(averagingOption);
        options.addOption(shufflingOption);
        options.addOption(roundsOption);
        options.addOption(initialTrainingParamOption);
        options.addOption(additionalTrainingParamOption);
        options.addOption(adaptOption);

        CommandLineParser parser = new BasicParser();
        try {
            CommandLine cmd = parser.parse(options, args);
            if (cmd.hasOption("help")) {
                HelpFormatter formatter = new HelpFormatter();
                formatter.printHelp("DatasetParser", options);
                System.exit(1);
            } else {
                if (cmd.hasOption("dataset")) {
                    setDataset(cmd.getOptionValue("dataset"));
                } else {
                    setDataset("hotel");
                }
                if (cmd.hasOption("p")) {
                    JLOLS.p = Double.parseDouble(cmd.getOptionValue("p"));
                }
                if (cmd.hasOption("scsteps")) {
                    JLOLS.sentenceCorrectionFurtherSteps = Integer.parseInt(cmd.getOptionValue("scsteps"));
                }
                LossFunction.metric = "BRC";
                if (cmd.hasOption("loss")) {
                    String opt = cmd.getOptionValue("loss");
                    if (!opt.isEmpty() && (opt.equals("B") || opt.equals("R") || opt.equals("BC")
                            || opt.equals("RC") || opt.equals("BRC") || opt.equals("BR"))) {
                        LossFunction.metric = opt;
                    }
                }
                if (cmd.hasOption("reset")) {
                    setResetStoredCaches(true);
                } else {
                    setResetStoredCaches(false);
                }
                setPerformEvaluationOn("test");
                if (cmd.hasOption("test")) {
                    String opt = cmd.getOptionValue("test");
                    if (!opt.isEmpty() && (opt.equals("test") || opt.equals("valid") || opt.equals("train"))) {
                        setPerformEvaluationOn(opt);
                    }
                }
                setUseAlignments("naive");
                if (cmd.hasOption("alignment")) {
                    String opt = cmd.getOptionValue("alignment");
                    if (!opt.isEmpty() && (opt.equals("naive") || opt.equals("random"))) {
                        setUseAlignments(opt);
                    }
                }
                if (cmd.hasOption("resultsPerPredicate")) {
                    setCalculateResultsPerPredicate(true);
                } else {
                    setCalculateResultsPerPredicate(false);
                }
                if (cmd.hasOption("averaging")) {
                    setAveraging(true);
                } else {
                    setAveraging(false);
                }
                if (cmd.hasOption("shuffling")) {
                    setShuffling(true);
                } else {
                    setShuffling(false);
                }
                if (cmd.hasOption("rounds")) {
                    setRounds(Integer.parseInt(cmd.getOptionValue("rounds")));
                } else {
                    setRounds(10);
                }
                if (cmd.hasOption("initialTrainingParam")) {
                    setInitialTrainingParam(Double.parseDouble(cmd.getOptionValue("initialTrainingParam")));
                } else {
                    setInitialTrainingParam(100.0);
                }
                if (cmd.hasOption("additionalTrainingParam")) {
                    setAdditionalTrainingParam(Double.parseDouble(cmd.getOptionValue("additionalTrainingParam")));
                } else {
                    setAdditionalTrainingParam(100.0);
                }
                if (cmd.hasOption("adapt")) {
                    setAdapt(true);
                } else {
                    setAdapt(false);
                }
            }
        } catch (ParseException ex) {
            System.out.println(ex.getMessage());
            System.out.println("Try \"--help\" option for details.");
        }
        trainingData = new ArrayList<>();
        validationData = new ArrayList<>();
        testingData = new ArrayList<>();

        observedAttrValueSequences = new ArrayList<>();
    }

    /**
     * Parses the dataset and populates the predicate, attribute, attribute/value, and value alignment collections.
     * The data should also be split into training, validation, and testing subsets here.
     */
    abstract public void parseDataset();

    /**
     * Calculates the alignments (naive or random), the language models, the available content and word actions, and finally the feature vectors.
     */
    abstract public void createTrainingData();

    /**
     * Creates and calls the imitation learning engine on the data subset selected for evaluation.
     */
    public void performImitationLearning() {
        JLOLS ILEngine = new JLOLS(this);
        if (getPerformEvaluationOn().equals("train")) {
            ILEngine.runLOLS(getTrainingData());
        } else if (getPerformEvaluationOn().equals("valid")) {
            ILEngine.runLOLS(getValidationData());
        } else if (getPerformEvaluationOn().equals("test")) {
            ILEngine.runLOLS(getTestingData());
        }
    }

    /**
     * Evaluates the generation output of the given content and word classifiers on a subset of the data.
     *
     * @param classifierAttrs the trained content action classifiers, one per predicate
     * @param classifierWords the trained word action classifiers, per predicate and attribute
     * @param testingData the instances to evaluate on
     * @param epoch the current training epoch
     * @return the evaluation score of the generated output
     */
    abstract public Double evaluateGeneration(HashMap<String, JAROW> classifierAttrs,
            HashMap<String, HashMap<String, JAROW>> classifierWords, ArrayList<DatasetInstance> testingData,
            int epoch);

    /**
     * Calculates naive alignments (e.g. through edit distance) between the attribute values and subphrases of the references.
     *
     * @param trainingData the training instances to align
     */
    abstract public void createNaiveAlignments(ArrayList<DatasetInstance> trainingData);

    /**
     * Creates a training instance for the content action classifiers.
     *
     * @param predicate the predicate of the meaning representation
     * @param bestAction the reference content action
     * @param previousGeneratedAttrs the attributes generated so far
     * @param attrValuesAlreadyMentioned the attribute/value pairs already mentioned
     * @param attrValuesToBeMentioned the attribute/value pairs still to be mentioned
     * @param MR the meaning representation
     * @param availableAttributeActions the available content actions
     * @return the created training instance
     */
    abstract public Instance createContentInstance(String predicate, String bestAction,
            ArrayList<String> previousGeneratedAttrs, HashSet<String> attrValuesAlreadyMentioned,
            HashSet<String> attrValuesToBeMentioned, MeaningRepresentation MR,
            HashMap<String, HashSet<String>> availableAttributeActions);

    /**
     * Creates a training instance for the content action classifiers, with explicit costs for the candidate actions.
     *
     * @param predicate the predicate of the meaning representation
     * @param costs the estimated cost of each candidate content action
     * @param previousGeneratedAttrs the attributes generated so far
     * @param attrValuesAlreadyMentioned the attribute/value pairs already mentioned
     * @param attrValuesToBeMentioned the attribute/value pairs still to be mentioned
     * @param availableAttributeActions the available content actions
     * @param MR the meaning representation
     * @return the created training instance
     */
    abstract public Instance createContentInstanceWithCosts(String predicate, TObjectDoubleHashMap<String> costs,
            ArrayList<String> previousGeneratedAttrs, HashSet<String> attrValuesAlreadyMentioned,
            HashSet<String> attrValuesToBeMentioned, HashMap<String, HashSet<String>> availableAttributeActions,
            MeaningRepresentation MR);

    /**
     * Creates a training instance for the word action classifiers.
     *
     * @param predicate the predicate of the meaning representation
     * @param bestAction the reference word action
     * @param previousGeneratedAttributes the attributes generated so far
     * @param previousGeneratedWords the word actions generated so far
     * @param nextGeneratedAttributes the attributes still to be generated
     * @param attrValuesAlreadyMentioned the attribute/value pairs already mentioned
     * @param attrValuesThatFollow the attribute/value pairs that follow the current one
     * @param wasValueMentioned whether the current attribute value has already been mentioned
     * @param availableWordActions the available word actions, per attribute
     * @return the created training instance
     */
    abstract public Instance createWordInstance(String predicate, Action bestAction,
            ArrayList<String> previousGeneratedAttributes, ArrayList<Action> previousGeneratedWords,
            ArrayList<String> nextGeneratedAttributes, HashSet<String> attrValuesAlreadyMentioned,
            HashSet<String> attrValuesThatFollow, boolean wasValueMentioned,
            HashMap<String, HashSet<Action>> availableWordActions);

    /**
     * Creates a training instance for the word action classifiers, with explicit costs for the candidate actions.
     *
     * @param predicate the predicate of the meaning representation
     * @param currentAttrValue the attribute/value pair currently being expressed
     * @param costs the estimated cost of each candidate word action
     * @param generatedAttributes the attributes generated so far
     * @param previousGeneratedWords the word actions generated so far
     * @param nextGeneratedAttributes the attributes still to be generated
     * @param attrValuesAlreadyMentioned the attribute/value pairs already mentioned
     * @param attrValuesThatFollow the attribute/value pairs that follow the current one
     * @param wasValueMentioned whether the current attribute value has already been mentioned
     * @param availableWordActions the available word actions, per attribute
     * @return the created training instance
     */
    abstract public Instance createWordInstanceWithCosts(String predicate, String currentAttrValue,
            TObjectDoubleHashMap<String> costs, ArrayList<String> generatedAttributes,
            ArrayList<Action> previousGeneratedWords, ArrayList<String> nextGeneratedAttributes,
            HashSet<String> attrValuesAlreadyMentioned, HashSet<String> attrValuesThatFollow,
            boolean wasValueMentioned, HashMap<String, HashSet<Action>> availableWordActions);

    /**
     * Checks whether a phrase ends with a given subphrase, comparing tokens from the end.
     *
     * @param phrase the token sequence to check
     * @param subPhrase the candidate suffix
     * @return true if the phrase ends with the subphrase
     */
    public boolean endsWith(ArrayList<String> phrase, ArrayList<String> subPhrase) {
        if (subPhrase.size() > phrase.size()) {
            return false;
        }
        for (int i = 0; i < subPhrase.size(); i++) {
            if (!subPhrase.get(subPhrase.size() - 1 - i).equals(phrase.get(phrase.size() - 1 - i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Chooses which value of the given attribute should be mentioned next, among the values still to be mentioned.
     *
     * @param attribute the attribute whose value is to be generated
     * @param attrValuesToBeMentioned the "attribute=value" pairs still to be mentioned
     * @return the chosen value, or an empty string if the attribute has no pending values
     */
    public String chooseNextValue(String attribute, HashSet<String> attrValuesToBeMentioned) {
        HashMap<String, Integer> relevantValues = new HashMap<>();
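        // Collect the values of the requested attribute that are still to be mentioned;
        // entries in attrValuesToBeMentioned have the form "attribute=value".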
        attrValuesToBeMentioned.forEach(attrValue -> {
            String attr = attrValue.substring(0, attrValue.indexOf('='));
            String value = attrValue.substring(attrValue.indexOf('=') + 1);
            if (attr.equals(attribute)) {
                relevantValues.put(value, 0);
            }
        });
        if (!relevantValues.isEmpty()) {
            if (relevantValues.size() == 1) {
                return relevantValues.keySet().iterator().next();
            } else {
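                // Prefer placeholder values of the form "x<number>" (presumably delexicalized
                // variables), choosing the one with the lowest index first.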
                String bestValue = "";
                int minIndex = Integer.MAX_VALUE;
                for (String value : relevantValues.keySet()) {
                    if (value.startsWith("x")) {
                        int vI = Integer.parseInt(value.substring(1));
                        if (vI < minIndex) {
                            minIndex = vI;
                            bestValue = value;
                        }
                    }
                }
                if (!bestValue.isEmpty()) {
                    return bestValue;
                }
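                // Otherwise, vote over the attribute/value orderings observed in the references:
                // whenever an observed sequence contains all candidate values, credit the value
                // it mentions earliest.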
                for (ArrayList<String> mentionedValueSeq : observedAttrValueSequences) {
                    boolean doesSeqContainValues = true;
                    minIndex = Integer.MAX_VALUE;
                    for (String value : relevantValues.keySet()) {
                        int index = mentionedValueSeq.indexOf(attribute + "=" + value);
                        if (index != -1 && index < minIndex) {
                            minIndex = index;
                            bestValue = value;
                        } else if (index == -1) {
                            doesSeqContainValues = false;
                        }
                    }
                    if (doesSeqContainValues) {
                        relevantValues.put(bestValue, relevantValues.get(bestValue) + 1);
                    }
                }
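                // Return the candidate value credited by the most observed sequences.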
                int max = -1;
                for (String value : relevantValues.keySet()) {
                    if (relevantValues.get(value) > max) {
                        max = relevantValues.get(value);
                        bestValue = value;
                    }
                }
                return bestValue;
            }
        }
        return "";
    }

    /**
     * Post-processes a generated word action sequence into a final surface string.
     *
     * @param di the dataset instance being generated for
     * @param wordSequence the generated word actions
     * @return the post-processed realization
     */
    abstract public String postProcessWordSequence(DatasetInstance di, ArrayList<Action> wordSequence);

    /**
     * Post-processes a reference action sequence into a surface string.
     *
     * @param mr the meaning representation of the reference
     * @param refSeq the reference word action sequence
     * @return the post-processed reference
     */
    abstract public String postProcessRef(MeaningRepresentation mr, ArrayList<Action> refSeq);

    /**
     * Loads previously stored initial classifiers, if available.
     *
     * @param dataSize the size of the training data the classifiers were trained on
     * @param trainedAttrClassifiers_0 map to be populated with the initial content classifiers
     * @param trainedWordClassifiers_0 map to be populated with the initial word classifiers
     * @return whether the classifiers were loaded successfully
     */
    abstract public boolean loadInitClassifiers(int dataSize, HashMap<String, JAROW> trainedAttrClassifiers_0,
            HashMap<String, HashMap<String, JAROW>> trainedWordClassifiers_0);

    /**
     * Stores the initial classifiers so they can be reused in later runs.
     *
     * @param dataSize the size of the training data the classifiers were trained on
     * @param trainedAttrClassifiers_0 the initial content classifiers to store
     * @param trainedWordClassifiers_0 the initial word classifiers to store
     */
    abstract public void writeInitClassifiers(int dataSize, HashMap<String, JAROW> trainedAttrClassifiers_0,
            HashMap<String, HashMap<String, JAROW>> trainedWordClassifiers_0);

    public ArrayList<DatasetInstance> getTrainingData() {
        return trainingData;
    }

    public ArrayList<DatasetInstance> getValidationData() {
        return validationData;
    }

    public ArrayList<DatasetInstance> getTestingData() {
        return testingData;
    }

    public boolean isResetStoredCaches() {
        return resetStoredCaches;
    }

    public int getMaxWordSequenceLength() {
        return maxWordSequenceLength;
    }

    public HashMap<String, HashMap<ArrayList<String>, Double>> getValueAlignments() {
        return valueAlignments;
    }

    public HashMap<String, HashMap<ArrayList<Action>, Action>> getPunctuationPatterns() {
        return punctuationPatterns;
    }

    public HashMap<String, SimpleLM> getWordLMsPerPredicate() {
        return wordLMsPerPredicate;
    }

    public String getDataset() {
        return dataset;
    }

    public final void setDataset(String dataset) {
        this.dataset = dataset;
    }

    public int getMaxContentSequenceLength() {
        return maxContentSequenceLength;
    }

    public void setMaxContentSequenceLength(int maxContentSequenceLength) {
        this.maxContentSequenceLength = maxContentSequenceLength;
    }

    public void setMaxWordSequenceLength(int maxWordSequenceLength) {
        this.maxWordSequenceLength = maxWordSequenceLength;
    }

    public boolean getAveraging() {
        return averaging;
    }

    public final void setAveraging(boolean averaging) {
        this.averaging = averaging;
    }

    public boolean getShuffling() {
        return shuffling;
    }

    public final void setShuffling(boolean shuffling) {
        this.shuffling = shuffling;
    }

    public int getRounds() {
        return rounds;
    }

    public final void setRounds(int rounds) {
        this.rounds = rounds;
    }

    public Double getInitialTrainingParam() {
        return initialTrainingParam;
    }

    public final void setInitialTrainingParam(Double initialTrainingParam) {
        this.initialTrainingParam = initialTrainingParam;
    }

    public Double getAdditionalTrainingParam() {
        return additionalTrainingParam;
    }

    public final void setAdditionalTrainingParam(Double additionalTrainingParam) {
        this.additionalTrainingParam = additionalTrainingParam;
    }

    public boolean isAdapt() {
        return adapt;
    }

    public final void setAdapt(boolean adapt) {
        this.adapt = adapt;
    }

    public final void setResetStoredCaches(boolean resetStoredCaches) {
        this.resetStoredCaches = resetStoredCaches;
    }

    public String getPerformEvaluationOn() {
        return performEvaluationOn;
    }

    public final void setPerformEvaluationOn(String performEvaluationOn) {
        this.performEvaluationOn = performEvaluationOn;
    }

    public String getUseAlignments() {
        return useAlignments;
    }

    public final void setUseAlignments(String useAlignments) {
        this.useAlignments = useAlignments;
    }

    public boolean isCalculateResultsPerPredicate() {
        return calculateResultsPerPredicate;
    }

    public final void setCalculateResultsPerPredicate(boolean calculateResultsPerPredicate) {
        this.calculateResultsPerPredicate = calculateResultsPerPredicate;
    }

    public ArrayList<ArrayList<String>> getObservedAttrValueSequences() {
        return observedAttrValueSequences;
    }

    public ArrayList<String> getPredicates() {
        return predicates;
    }

    public HashMap<String, HashSet<String>> getAttributes() {
        return attributes;
    }

    public HashMap<String, HashSet<String>> getAttributeValuePairs() {
        return attributeValuePairs;
    }

    public HashMap<String, ArrayList<DatasetInstance>> getDatasetInstances() {
        return datasetInstances;
    }

    public HashMap<String, ArrayList<Instance>> getPredicateContentTrainingData() {
        return predicateContentTrainingData;
    }

    public HashMap<String, HashMap<String, ArrayList<Instance>>> getPredicateWordTrainingData() {
        return predicateWordTrainingData;
    }

    public HashMap<String, SimpleLM> getContentLMsPerPredicate() {
        return contentLMsPerPredicate;
    }

    public HashMap<String, String> getCompositeSuffixesInData() {
        return compositeWordsInData;
    }

    public void setTrainingData(ArrayList<DatasetInstance> trainingData) {
        this.trainingData = trainingData;
    }

    public void setValidationData(ArrayList<DatasetInstance> validationData) {
        this.validationData = validationData;
    }

    public void setTestingData(ArrayList<DatasetInstance> testingData) {
        this.testingData = testingData;
    }

    public void setContentLMsPerPredicate(HashMap<String, SimpleLM> contentLMsPerPredicate) {
        this.contentLMsPerPredicate = contentLMsPerPredicate;
    }

    public void setWordLMsPerPredicate(HashMap<String, SimpleLM> wordLMsPerPredicate) {
        this.wordLMsPerPredicate = wordLMsPerPredicate;
    }

    public void setValueAlignments(HashMap<String, HashMap<ArrayList<String>, Double>> valueAlignments) {
        this.valueAlignments = valueAlignments;
    }

    public void setObservedAttrValueSequences(ArrayList<ArrayList<String>> observedAttrValueSequences) {
        this.observedAttrValueSequences = observedAttrValueSequences;
    }

    public void setPredicates(ArrayList<String> predicates) {
        this.predicates = predicates;
    }

    public void setAttributes(HashMap<String, HashSet<String>> attributes) {
        this.attributes = attributes;
    }

    public void setAttributeValuePairs(HashMap<String, HashSet<String>> attributeValuePairs) {
        this.attributeValuePairs = attributeValuePairs;
    }

    public void setDatasetInstances(HashMap<String, ArrayList<DatasetInstance>> datasetInstances) {
        this.datasetInstances = datasetInstances;
    }

    public void setPunctuationPatterns(HashMap<String, HashMap<ArrayList<Action>, Action>> punctuationPatterns) {
        this.punctuationPatterns = punctuationPatterns;
    }

    public void setPredicateContentTrainingData(HashMap<String, ArrayList<Instance>> predicateContentTrainingData) {
        this.predicateContentTrainingData = predicateContentTrainingData;
    }

    public void setPredicateWordTrainingData(
            HashMap<String, HashMap<String, ArrayList<Instance>>> predicateWordTrainingData) {
        this.predicateWordTrainingData = predicateWordTrainingData;
    }

    public void setCompositeWordsInData(HashMap<String, String> compositeWordsInData) {
        this.compositeWordsInData = compositeWordsInData;
    }

    public HashMap<String, HashSet<String>> getAvailableContentActions() {
        return availableContentActions;
    }

    public void setAvailableContentActions(HashMap<String, HashSet<String>> availableContentActions) {
        this.availableContentActions = availableContentActions;
    }

    public HashMap<String, HashMap<String, HashSet<Action>>> getAvailableWordActions() {
        return availableWordActions;
    }

    public void setAvailableWordActions(HashMap<String, HashMap<String, HashSet<Action>>> availableWordActions) {
        this.availableWordActions = availableWordActions;
    }
}
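
Example usage

DatasetParser is abstract; a concrete subclass must implement its parsing and feature-extraction methods for a specific dataset. The sketch below is hypothetical: the subclass MyDatasetParser and its placeholder method bodies are not part of the original code, and are shown only to illustrate the intended call sequence (construct with the command-line arguments, parse the dataset, build the training data, then run imitation learning on the subset selected with the -e option).

package structuredPredictionNLG;

import gnu.trove.map.hash.TObjectDoubleHashMap;
import jarow.Instance;
import jarow.JAROW;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;

// Hypothetical minimal subclass; every body below is a placeholder, not a real implementation.
class MyDatasetParser extends DatasetParser {

    public MyDatasetParser(String[] args) {
        super(args);
    }

    @Override public void parseDataset() { /* parse the raw data, populate and split the subsets */ }
    @Override public void createTrainingData() { /* alignments, language models, actions, features */ }
    @Override public void createNaiveAlignments(ArrayList<DatasetInstance> trainingData) { }
    @Override public Double evaluateGeneration(HashMap<String, JAROW> classifierAttrs,
            HashMap<String, HashMap<String, JAROW>> classifierWords,
            ArrayList<DatasetInstance> testingData, int epoch) { return 0.0; }
    @Override public Instance createContentInstance(String predicate, String bestAction,
            ArrayList<String> previousGeneratedAttrs, HashSet<String> attrValuesAlreadyMentioned,
            HashSet<String> attrValuesToBeMentioned, MeaningRepresentation MR,
            HashMap<String, HashSet<String>> availableAttributeActions) { return null; }
    @Override public Instance createContentInstanceWithCosts(String predicate,
            TObjectDoubleHashMap<String> costs, ArrayList<String> previousGeneratedAttrs,
            HashSet<String> attrValuesAlreadyMentioned, HashSet<String> attrValuesToBeMentioned,
            HashMap<String, HashSet<String>> availableAttributeActions,
            MeaningRepresentation MR) { return null; }
    @Override public Instance createWordInstance(String predicate, Action bestAction,
            ArrayList<String> previousGeneratedAttributes, ArrayList<Action> previousGeneratedWords,
            ArrayList<String> nextGeneratedAttributes, HashSet<String> attrValuesAlreadyMentioned,
            HashSet<String> attrValuesThatFollow, boolean wasValueMentioned,
            HashMap<String, HashSet<Action>> availableWordActions) { return null; }
    @Override public Instance createWordInstanceWithCosts(String predicate, String currentAttrValue,
            TObjectDoubleHashMap<String> costs, ArrayList<String> generatedAttributes,
            ArrayList<Action> previousGeneratedWords, ArrayList<String> nextGeneratedAttributes,
            HashSet<String> attrValuesAlreadyMentioned, HashSet<String> attrValuesThatFollow,
            boolean wasValueMentioned, HashMap<String, HashSet<Action>> availableWordActions) { return null; }
    @Override public String postProcessWordSequence(DatasetInstance di, ArrayList<Action> wordSequence) { return ""; }
    @Override public String postProcessRef(MeaningRepresentation mr, ArrayList<Action> refSeq) { return ""; }
    @Override public boolean loadInitClassifiers(int dataSize, HashMap<String, JAROW> trainedAttrClassifiers_0,
            HashMap<String, HashMap<String, JAROW>> trainedWordClassifiers_0) { return false; }
    @Override public void writeInitClassifiers(int dataSize, HashMap<String, JAROW> trainedAttrClassifiers_0,
            HashMap<String, HashMap<String, JAROW>> trainedWordClassifiers_0) { }

    public static void main(String[] args) {
        // e.g. args = {"-d", "hotel", "-l", "BRC", "-e", "test", "-a", "naive", "-r", "10"}
        MyDatasetParser parser = new MyDatasetParser(args);
        parser.parseDataset();             // populate the training/validation/testing subsets
        parser.createTrainingData();       // alignments, language models, available actions, features
        parser.performImitationLearning(); // run the JLOLS engine on the subset chosen via -e
    }
}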