Java tutorial
/* * ****************************************************************************** * Copyright 2016 * Copyright (c) 2016 Technische Universitt Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * **************************************************************************** */ package tudarmstadt.lt.ABSentiment.training.util; import de.bwaldvogel.liblinear.Feature; import de.bwaldvogel.liblinear.Linear; import de.bwaldvogel.liblinear.Model; import de.bwaldvogel.liblinear.Problem; import org.apache.commons.lang3.StringUtils; import org.apache.uima.jcas.JCas; import org.datavec.api.records.reader.RecordReader; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import tudarmstadt.lt.ABSentiment.featureExtractor.*; import tudarmstadt.lt.ABSentiment.featureExtractor.util.ConfusionMatrix; import tudarmstadt.lt.ABSentiment.reader.*; import tudarmstadt.lt.ABSentiment.type.Document; import tudarmstadt.lt.ABSentiment.type.Sentence; import tudarmstadt.lt.ABSentiment.uimahelper.Preprocessor; import java.io.*; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Vector; /** * Generates a common training/testing instance as a Document/Feature matrix for training and testing. * Created by abhishek on 19/5/17. */ public class ProblemBuilder { protected static InputReader fr; protected static Preprocessor preprocessor = new Preprocessor(true); protected static String configurationfile; private static Integer maxLabelId = -1; private static int featureCount = 0; protected static boolean useCoarseLabels = false; protected static String language; private static String format; protected static boolean semeval16 = false; protected static String trainFile; protected static String testFile; protected static String predictionFile; protected static String labelMappingsFileSentiment; protected static String labelMappingsFileRelevance; protected static String labelMappingsFileAspect; protected static String labelMappingsFileAspectCoarse; protected static String featureOutputFile; protected static String featureStatisticsFile; protected static String idfGazeteerFile; protected static String idfFile; protected static String relevanceModel; protected static String aspectModel; protected static String aspectCoarseModel; protected static String sentimentModel; protected static String crfModelFolder; protected static String corpusFile; protected static String maxLengthFile; protected static String relevanceIdfFile; protected static String sentimentIdfFile; protected static String aspectIdfFile; protected static String aspectCoarseIdfFile; protected static String positiveGazeteerFile; protected static String negativeGazeteerFile; protected static String polarityLexiconFile; protected static String aggregateGazeteerFile; protected static String DTConfigurationFile; protected static String missingWordsFile; protected static String DTExpansionFile; protected static String DTfile; protected static String gloveFile; protected static String w2vFile; protected static String weightedIdfFile; protected static String weightedW2vFile; protected static String weightedGloveFile; protected static HashMap<String, Integer> labelMappings = new HashMap<>(); protected static HashMap<Integer, String> labelLookup = new HashMap<>(); protected static ConfusionMatrix confusionMatrix; protected static ArrayList<String> allLabels = new ArrayList<>(); /** * Loads a file and initializes all the variables present in the configuration file. * @param configurationFile path to a file containing the variable name and their initialization */ protected static void initialise(String configurationFile) { language = null; format = null; idfFile = null; positiveGazeteerFile = null; negativeGazeteerFile = null; gloveFile = null; w2vFile = null; trainFile = null; testFile = null; featureOutputFile = null; predictionFile = null; labelMappingsFileSentiment = null; labelMappingsFileRelevance = null; labelMappingsFileAspect = null; labelMappingsFileAspectCoarse = null; relevanceModel = null; aspectModel = null; aspectCoarseModel = null; sentimentModel = null; crfModelFolder = null; missingWordsFile = null; DTExpansionFile = null; weightedW2vFile = null; weightedGloveFile = null; weightedIdfFile = null; polarityLexiconFile = null; aggregateGazeteerFile = null; DTConfigurationFile = null; DTfile = null; corpusFile = null; maxLengthFile = null; relevanceIdfFile = null; sentimentIdfFile = null; aspectIdfFile = null; aspectCoarseIdfFile = null; Configuration config = new Configuration(); HashMap<String, String> fileLocation; fileLocation = config.readConfigurationFile(configurationFile); for (HashMap.Entry<String, String> entry : fileLocation.entrySet()) { if (entry.getKey().equals("language")) { language = entry.getValue(); } else if (entry.getKey().equals("format")) { format = entry.getValue(); if (format.compareTo("semeval16") == 0) { semeval16 = true; } } else if (entry.getKey().equals("idfFile")) { idfFile = entry.getValue(); } else if (entry.getKey().equals("idfGazeteerFile")) { idfGazeteerFile = entry.getValue(); } else if (entry.getKey().equals("positiveGazeteerFile")) { positiveGazeteerFile = entry.getValue(); } else if (entry.getKey().equals("negativeGazeteerFile")) { negativeGazeteerFile = entry.getValue(); } else if (entry.getKey().equals("gloveFile")) { gloveFile = entry.getValue(); } else if (entry.getKey().equals("w2vFile")) { w2vFile = entry.getValue(); } else if (entry.getKey().equals("trainFile")) { trainFile = entry.getValue(); } else if (entry.getKey().equals("testFile")) { testFile = entry.getValue(); } else if (entry.getKey().equals("featureOutputFile")) { featureOutputFile = entry.getValue(); } else if (entry.getKey().equals("predictionFile")) { predictionFile = entry.getValue(); } else if (entry.getKey().equals("relevanceModel")) { relevanceModel = entry.getValue(); labelMappingsFileRelevance = entry.getValue() + "_label_mappings.tsv"; } else if (entry.getKey().equals("aspectModel")) { aspectModel = entry.getValue(); labelMappingsFileAspect = entry.getValue() + "_label_mappings.tsv"; } else if (entry.getKey().equals("aspectCoarseModel")) { aspectCoarseModel = entry.getValue(); labelMappingsFileAspectCoarse = entry.getValue() + "_label_mappings.tsv"; } else if (entry.getKey().equals("sentimentModel")) { sentimentModel = entry.getValue(); labelMappingsFileSentiment = entry.getValue() + "_label_mappings.tsv"; } else if (entry.getKey().equals("crfModelFolder")) { crfModelFolder = entry.getValue(); if (!crfModelFolder.endsWith("/")) { crfModelFolder.concat("/"); } } else if (entry.getKey().equals("missingWordsFile")) { missingWordsFile = entry.getValue(); } else if (entry.getKey().equals("DTExpansionFile")) { DTExpansionFile = entry.getValue(); } else if (entry.getKey().equals("weightedW2vFile")) { weightedW2vFile = entry.getValue(); } else if (entry.getKey().equals("weightedGloveFile")) { weightedGloveFile = entry.getValue(); } else if (entry.getKey().equals("weightedIdfFile")) { weightedIdfFile = entry.getValue(); } else if (entry.getKey().equals("polarityLexiconFile")) { polarityLexiconFile = entry.getValue(); } else if (entry.getKey().equals("aggregateGazeteerFile")) { aggregateGazeteerFile = entry.getValue(); } else if (entry.getKey().equals("DTConfigurationFile")) { DTConfigurationFile = entry.getValue(); } else if (entry.getKey().equals(("DTfile"))) { DTfile = entry.getValue(); } else if (entry.getKey().equals(("corpus"))) { corpusFile = entry.getValue(); } else if (entry.getKey().equals(("maxLengthFile"))) { maxLengthFile = entry.getValue(); } else if (entry.getKey().equals(("relIdfTerms"))) { relevanceIdfFile = entry.getValue(); } else if (entry.getKey().equals(("sentIdfTerms"))) { sentimentIdfFile = entry.getValue(); } else if (entry.getKey().equals(("aspectIdfTerms"))) { aspectIdfFile = entry.getValue(); } else if (entry.getKey().equals(("aspectCoarseIdfTerms"))) { aspectCoarseIdfFile = entry.getValue(); } } } /** * Computes a feature vector out of all the feature name specified in the configuration file * @return a Vector containing all the specified feature */ protected static Vector<FeatureExtractor> loadFeatureExtractors(String type) { int offset = 1; Vector<FeatureExtractor> features = new Vector<>(); if (idfFile != null) { FeatureExtractor tfidf = new TfIdfFeature(idfFile, offset); offset += tfidf.getFeatureCount(); features.add(tfidf); } if (type.compareTo("relevance") == 0 && relevanceIdfFile != null) { FeatureExtractor gazeteerIdf = new GazetteerFeature(relevanceIdfFile, offset); offset += gazeteerIdf.getFeatureCount(); features.add(gazeteerIdf); } else if (type.compareTo("sentiment") == 0 && sentimentIdfFile != null) { FeatureExtractor gazeteerIdf = new GazetteerFeature(sentimentIdfFile, offset); offset += gazeteerIdf.getFeatureCount(); features.add(gazeteerIdf); } else if (type.compareTo("aspect") == 0 && aspectIdfFile != null) { FeatureExtractor gazeteerIdf = new GazetteerFeature(aspectIdfFile, offset); offset += gazeteerIdf.getFeatureCount(); features.add(gazeteerIdf); } if (positiveGazeteerFile != null) { FeatureExtractor posDict = new AggregatedGazetteerFeature(positiveGazeteerFile, offset); offset += posDict.getFeatureCount(); features.add(posDict); } if (negativeGazeteerFile != null) { FeatureExtractor negDict = new AggregatedGazetteerFeature(negativeGazeteerFile, offset); offset += negDict.getFeatureCount(); features.add(negDict); } if (polarityLexiconFile != null) { FeatureExtractor polarityLexicon = new PolarityLexiconFeature(polarityLexiconFile, offset); offset += polarityLexicon.getFeatureCount(); features.add(polarityLexicon); } if (aggregateGazeteerFile != null) { FeatureExtractor aggregatedGazeteerFeature = new AggregatedGazetteerFeature(aggregateGazeteerFile, offset); offset += aggregatedGazeteerFeature.getFeatureCount(); features.add(aggregatedGazeteerFeature); } if (gloveFile != null) { FeatureExtractor glove = new WordEmbeddingFeature(gloveFile, null, 1, DTExpansionFile, offset); offset += glove.getFeatureCount(); features.add(glove); } if (w2vFile != null) { FeatureExtractor word2vec = new WordEmbeddingFeature(w2vFile, null, 2, DTExpansionFile, offset); offset += word2vec.getFeatureCount(); features.add(word2vec); } if (weightedGloveFile != null && weightedIdfFile != null) { FeatureExtractor word2vec = new WordEmbeddingFeature(weightedGloveFile, weightedIdfFile, 1, DTExpansionFile, offset); offset += word2vec.getFeatureCount(); features.add(word2vec); } if (weightedW2vFile != null && weightedIdfFile != null) { FeatureExtractor word2vec = new WordEmbeddingFeature(weightedW2vFile, weightedIdfFile, 2, DTExpansionFile, offset); offset += word2vec.getFeatureCount(); features.add(word2vec); } return features; } /** * Builds a problem - the input feature matrix, output labels, total number of feature instances and the feature count * @param trainingFile path to the training file * @param features feature vector of all the features specified * @param type * @param ifTraining specifies if this method is used for training or testing */ protected static Problem buildProblem(String trainingFile, Vector<FeatureExtractor> features, String type, Boolean ifTraining) { if (ifTraining) { resetLabelMappings(); } printFeatureStatistics(features); if (trainingFile.endsWith("xml")) { if (semeval16) { fr = new XMLReaderSemEval(trainingFile); } else { fr = new XMLReader(trainingFile); } } else { fr = new TsvReader(trainingFile); } int documentCount = 0; Vector<Double> labels = new Vector<>(); Vector<Feature[]> featureVector = new Vector<>(); Vector<Feature[]> instanceFeatures = null; String[] stringLabel = null; for (Document doc : fr) { for (Sentence sentence : doc.getSentences()) { preprocessor.processText(sentence.getText()); instanceFeatures = applyFeatures(preprocessor.getCas(), features); if (type == null) { stringLabel = sentence.getAspectCategories(); } else if (type.compareTo("relevance") == 0) { stringLabel = sentence.getRelevance(); } else if (type.compareTo("sentiment") == 0) { try { stringLabel = sentence.getSentiment(); } catch (NoSuchFieldException e) { // COMMENT HERE continue; } } else if (type.compareTo("aspect") == 0) { if (useCoarseLabels) { stringLabel = sentence.getAspectCategoriesCoarse(); } else { stringLabel = sentence.getAspectCategories(); } } for (String l : stringLabel) { if (l == null || l.isEmpty()) { continue; } Double label = getLabelId(l); labels.add(label); featureVector.add(combineInstanceFeatures(instanceFeatures)); documentCount++; } } } if (featureOutputFile != null) { saveFeatureVectors(featureOutputFile, featureVector, labels); } Problem problem = new Problem(); problem.l = documentCount; problem.n = featureCount; problem.x = new Feature[documentCount][]; problem.y = new double[documentCount]; for (int i = 0; i < labels.size(); i++) { problem.y[i] = labels.get(i); problem.x[i] = featureVector.get(i); } return problem; } /** * Builds a problem - the input feature matrix, output labels, total number of feature instances and the feature count * @param trainingFile path to the training file * @param features feature vector of all the features specified * @param ifTraining specifies if this method is used for training or testing */ protected static Problem buildProblem(String trainingFile, Vector<FeatureExtractor> features, Boolean ifTraining) { resetLabelMappings(); printFeatureStatistics(features); return buildProblem(trainingFile, features, null, ifTraining); } protected static Vector<Feature[]> applyFeatures(JCas cas, Vector<FeatureExtractor> features) { Vector<Feature[]> instanceFeatures = new Vector<>(); for (FeatureExtractor feature : features) { instanceFeatures.add(feature.extractFeature(cas)); // update the featureCount, the maximal Feature id featureCount = feature.getFeatureCount() + feature.getOffset(); } return instanceFeatures; } protected static Feature[] combineInstanceFeatures(Vector<Feature[]> instanceFeatures) { int length = 0; for (Feature[] f : instanceFeatures) { length += f.length; } Feature[] instance = new Feature[length]; int i = 0; for (Feature[] fa : instanceFeatures) { for (Feature value : fa) { instance[i++] = value; } } return instance; } protected static void saveLabelMappings(String mappingFile) { try { Writer out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(mappingFile), "UTF-8")); for (String label : labelMappings.keySet()) { out.write(labelMappings.get(label) + "\t" + label + "\n"); } out.close(); } catch (Exception e) {//Catch exception if any System.err.println("Error: " + e.getMessage()); } } protected static Double getLabelId(String label) { if (labelMappings.containsKey(label)) { return labelMappings.get(label).doubleValue(); } else { labelMappings.put(label, ++maxLabelId); labelLookup.put(maxLabelId, label); return maxLabelId.doubleValue(); } } protected static String getLabelString(Double labelId) { return labelLookup.get(labelId.intValue()); } protected static void saveFeatureVectors(String featureVectorFile, Vector<Feature[]> featureVector, Vector<Double> labels) { if (featureVectorFile == null) { return; } try { Writer featureOut = new BufferedWriter( new OutputStreamWriter(new FileOutputStream(featureVectorFile), "UTF-8")); for (int i = 0; i < labels.size(); i++) { featureOut.write(labels.get(i).toString()); Feature[] features = featureVector.get(i); for (Feature f : features) { featureOut.write(" " + f.getIndex() + ":" + f.getValue()); } featureOut.write("\n"); } featureOut.close(); } catch (UnsupportedEncodingException | FileNotFoundException e) { e.printStackTrace(); System.exit(1); } catch (IOException e) { e.printStackTrace(); } } protected static void printFeatureStatistics(Vector<FeatureExtractor> features) { if (featureStatisticsFile != null) { try { Writer statisticsOut = new BufferedWriter( new OutputStreamWriter(new FileOutputStream(featureStatisticsFile), "UTF-8")); statisticsOut.write("training set: " + trainFile + "\n"); if (featureStatisticsFile != null) { int start; int end; for (FeatureExtractor feature : features) { start = feature.getOffset(); end = feature.getOffset() + feature.getFeatureCount(); statisticsOut .append(feature.getClass().getCanonicalName() + "\t" + start + "\t" + end + "\n"); } } statisticsOut.close(); } catch (UnsupportedEncodingException | FileNotFoundException e) { e.printStackTrace(); System.exit(1); } catch (IOException e) { e.printStackTrace(); } } } protected static void resetLabelMappings() { labelMappings = new HashMap<>(); labelLookup = new HashMap<>(); maxLabelId = -1; } protected static void loadLabelMappings(String fileName) { try { BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(fileName), "UTF-8")); String line; while ((line = br.readLine()) != null) { String[] catLine = line.split("\\t"); Integer labelId = Integer.parseInt(catLine[0]); labelLookup.put(labelId, catLine[1]); labelMappings.put(catLine[1], labelId); } br.close(); } catch (IOException e) { e.printStackTrace(); System.exit(1); } } protected static INDArray classifyTestSet(String inputFile, Model model, Vector<FeatureExtractor> features, String predictionFile, String type, boolean printResult) { InputReader fr; if (inputFile.endsWith("xml")) { fr = new XMLReader(inputFile); } else { fr = new TsvReader(inputFile); } Writer out = null; Writer featureOut = null; try { OutputStream predStream = new FileOutputStream(predictionFile); out = new OutputStreamWriter(predStream, "UTF-8"); if (featureOutputFile != null) { OutputStream vectorStream = new FileOutputStream(featureOutputFile); featureOut = new OutputStreamWriter(vectorStream, "UTF-8"); } } catch (FileNotFoundException | UnsupportedEncodingException e1) { e1.printStackTrace(); System.exit(1); } Feature[] instance; Vector<Feature[]> instanceFeatures; confusionMatrix = new ConfusionMatrix(); String item; for (int j = 0; j < model.getNrClass(); j++) { item = labelLookup.get(Integer.parseInt(model.getLabels()[j] + "")); confusionMatrix.addLabel(item); allLabels.add(item); } ArrayList<double[]> probability = new ArrayList<>(); confusionMatrix.createMatrix(); for (Document doc : fr) { for (Sentence sentence : doc.getSentences()) { int i = 0; preprocessor.processText(sentence.getText()); instanceFeatures = applyFeatures(preprocessor.getCas(), features); Double prediction; instance = combineInstanceFeatures(instanceFeatures); double[] prob_estimates = new double[model.getNrClass()]; prediction = Linear.predictProbability(model, instance, prob_estimates); probability.add(prob_estimates); try { out.write(sentence.getId() + "\t" + sentence.getText() + "\t"); String goldLabel = null; String predictedLabel = labelLookup.get(prediction.intValue()); if (type.compareTo("relevance") == 0) { goldLabel = sentence.getRelevance()[0]; confusionMatrix.updateMatrix(predictedLabel, goldLabel); } else if (type.compareTo("sentiment") == 0) { try { while (i < sentence.getSentiment().length) { goldLabel = sentence.getSentiment()[i++]; confusionMatrix.updateMatrix(predictedLabel, goldLabel); } } catch (NoSuchFieldException e) { } } else if (useCoarseLabels) { out.append(StringUtils.join(sentence.getAspectCategoriesCoarse(), " ")); goldLabel = StringUtils.join(sentence.getAspectCategoriesCoarse(), " "); confusionMatrix.updateMatrix(predictedLabel, goldLabel); } else { out.append(StringUtils.join(sentence.getAspectCategories(), " ")); goldLabel = StringUtils.join(sentence.getAspectCategories(), " "); confusionMatrix.updateMatrix(predictedLabel, goldLabel); } out.append("\t").append(labelLookup.get(prediction.intValue())).append("\n"); } catch (IOException e) { e.printStackTrace(); } if (featureOutputFile != null) { String[] labels = sentence.getAspectCategories(); if (useCoarseLabels) { labels = sentence.getAspectCategoriesCoarse(); } for (String label : labels) { try { assert featureOut != null; for (Feature f : instance) { featureOut.write(" " + f.getIndex() + ":" + f.getValue()); } featureOut.write("\n"); } catch (IOException e) { e.printStackTrace(); } } } } } INDArray classificationProbability = Nd4j.zeros(probability.size(), model.getNrClass()); int j = -1; for (double prob_estimates[] : probability) { classificationProbability.putRow(++j, Nd4j.create(prob_estimates)); } try { out.close(); } catch (IOException e) { e.printStackTrace(); } HashMap<String, Float> recall; HashMap<String, Float> precision; HashMap<String, Float> fMeasure; recall = getRecallForAll(); precision = getPrecisionForAll(); fMeasure = getFMeasureForAll(); if (printResult) { System.out.println("Label" + "\t" + "Recall" + "\t" + "Precision" + "\t" + "F Score"); for (String itemLabel : allLabels) { System.out.println(itemLabel + "\t" + recall.get(itemLabel) + "\t" + precision.get(itemLabel) + "\t" + fMeasure.get(itemLabel)); } printFeatureStatistics(features); printConfusionMatrix(); System.out.println("\n"); System.out.println("True positive : " + getTruePositive()); System.out.println("Accuracy : " + getOverallAccuracy()); System.out.println("Overall Precision : " + getOverallPrecision()); System.out.println("Overall Recall : " + getOverallRecall()); System.out.println("Overall FMeasure : " + getOverallFMeasure()); } return classificationProbability; } protected static INDArray classifyTestSet(MultiLayerNetwork model, Problem problem, boolean printResult) { int batchSize = 200; int labelIndex = 0; int numClasses = labelMappings.size(); List<List<Double>> inputFeature = new ArrayList<>(); for (int i = 0; i < problem.l; i++) { Feature[] array = problem.x[i]; Double y = problem.y[i]; ArrayList<Double> newArray = new ArrayList<>(); newArray.add(y); int k = 0; for (int j = 0; j < problem.n; j++) { if (k < array.length) { if (array[k].getIndex() == j) { newArray.add(array[k++].getValue()); } else { newArray.add(0.0); } } } inputFeature.add(newArray); } INDArray classificationProbability = Nd4j.zeros(inputFeature.size(), numClasses); RecordReader recordReader = new ListDoubleRecordReader(); try { recordReader.initialize(new ListDoubleSplit(inputFeature)); } catch (IOException e) { e.printStackTrace(); } catch (InterruptedException e) { e.printStackTrace(); } DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses); Evaluation eval = new Evaluation(numClasses); DataSet ds; int j = -1; while (iterator.hasNext()) { ds = iterator.next(); INDArray output = model.output(ds.getFeatureMatrix()); for (int i = 0; i < output.size(0); i++) { classificationProbability.putRow(++j, output.getRow(i)); } eval.eval(ds.getLabels(), output); } if (printResult) { System.out.println(eval.stats()); } return classificationProbability; } protected static void printConfusionMatrix() { confusionMatrix.printConfusionMatrix(); } protected static double getRecallForLabel(String label) { return confusionMatrix.getRecallForLabel(label); } protected static double getPrecisionForLabel(String label) { return confusionMatrix.getPrecisionForLabel(label); } protected static HashMap<String, Float> getRecallForAll() { return confusionMatrix.getRecallForAllLabels(); } protected static HashMap<String, Float> getPrecisionForAll() { return confusionMatrix.getPrecisionForAllLabels(); } protected static HashMap<String, Float> getFMeasureForAll() { return confusionMatrix.getFMeasureForAllLabels(); } protected static int getTruePositive() { return confusionMatrix.getTruePositive(); } protected static float getOverallAccuracy() { return confusionMatrix.getOverallAccuracy(); } protected static float getOverallRecall() { return confusionMatrix.getOverallRecall(); } protected static float getOverallPrecision() { return confusionMatrix.getOverallPrecision(); } protected static float getOverallFMeasure() { return confusionMatrix.getOverallFMeasure(); } }