Java tutorial
/* * Copyright 2016 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universitt Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.dkpro.tc.ml.svmhmm.util; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.SortedMap; import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; import org.apache.commons.collections.BidiMap; import org.apache.commons.collections.bidimap.DualTreeBidiMap; import org.apache.commons.csv.CSVFormat; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.dkpro.tc.ml.svmhmm.writer.SVMHMMDataWriter; public final class SVMHMMUtils { /** * File name of serialized mapping from String labels to numbers */ public static final String LABELS_TO_INTEGERS_MAPPING_FILE_NAME = "labelsToIntegersMapping_DualTreeBidiMap.bin"; /** * CSV file comment */ public static final String CSV_COMMENT = "Columns: gold, predicted, token, seqID"; /** * Format of CSV files */ public static final CSVFormat CSV_FORMAT = CSVFormat.DEFAULT.withCommentMarker('#'); /** * Where the gold outcomes, predicted outcomes, and tokens are stored */ public static final String GOLD_PREDICTED_OUTCOMES_CSV = "outcomesGoldPredicted.csv"; private SVMHMMUtils() { // empty } /** * Extract all outcomes from featureVectorsFiles (training, test) that are in LIBSVM format - * each line is a feature vector and the first token is the outcome label * * @param files * files in LIBSVM format * @return set of all unique outcomes * @throws java.io.IOException */ public static SortedSet<String> extractOutcomeLabelsFromFeatureVectorFiles(File... files) throws IOException { SortedSet<String> result = new TreeSet<>(); for (File file : files) { result.addAll(extractOutcomeLabels(file)); } return result; } /** * Maps names to numbers (numbers are required by SVMLight format) * * @param names * names (e.g., features, outcomes) * @return bidirectional map of name:number */ public static BidiMap mapVocabularyToIntegers(SortedSet<String> names) { BidiMap result = new DualTreeBidiMap(); // start numbering from 1 int index = 1; for (String featureName : names) { result.put(featureName, index); index++; } return result; } /** * Creates a new file in the same directory as {@code featureVectorsFile} and replaces the first * token (outcome label) by its corresponding integer number from the bi-di map * * @param featureVectorsFile * file * @param labelsToIntegers * mapping * @return new file */ public static File replaceLabelsWithIntegers(File featureVectorsFile, BidiMap labelsToIntegers) throws IOException { File result = new File(featureVectorsFile.getParent(), "mappedLabelsToInt_" + featureVectorsFile.getName()); PrintWriter pw = new PrintWriter(new FileOutputStream(result)); BufferedReader br = new BufferedReader(new FileReader(featureVectorsFile)); String line = null; while ((line = br.readLine()) != null) { // split on the first whitespaces, keep the rest String[] split = line.split("\\s", 2); String label = split[0]; String remainingContent = split[1]; // find the integer Integer intOutput = (Integer) labelsToIntegers.get(label); // print to the output stream pw.printf("%d %s%n", intOutput, remainingContent); } IOUtils.closeQuietly(pw); IOUtils.closeQuietly(br); return result; } /** * Saves label-integer mapping to a file * * @param mapping * mapping * @param outputFile * file * @throws IOException */ public static void saveMapping(BidiMap mapping, File outputFile) throws IOException { ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(outputFile)); objectOutputStream.writeObject(mapping); IOUtils.closeQuietly(objectOutputStream); } /** * Saves the feature mapping to readable format, each line is a feature id and feature name, * sorted by feature id * * @param mapping * mapping (name:id) * @param outputFile * output file * @throws IOException */ public static void saveMappingTextFormat(BidiMap mapping, File outputFile) throws IOException { PrintWriter pw = new PrintWriter(new FileOutputStream(outputFile)); // sort values (feature indexes) @SuppressWarnings("unchecked") SortedSet<Object> featureIndexes = new TreeSet<Object>(mapping.values()); for (Object featureIndex : featureIndexes) { pw.printf(Locale.ENGLISH, "%5d %s%n", (int) featureIndex, mapping.getKey(featureIndex).toString()); } IOUtils.closeQuietly(pw); } /** * Loads a serialized BidiMap from file * * @param inputFile * input file * @return BidiMap * @throws IOException */ public static BidiMap loadMapping(File inputFile) throws IOException { ObjectInputStream inputStream = new ObjectInputStream(new FileInputStream(inputFile)); try { return (BidiMap) inputStream.readObject(); } catch (ClassNotFoundException e) { throw new IOException(e); } finally { IOUtils.closeQuietly(inputStream); } } /** * Extracts the outcome labels from the file; it corresponds to the first token on each line. * * @param featureVectorsFile * featureVectors file * @return list of outcome labels * @throws IOException */ public static List<String> extractOutcomeLabels(File featureVectorsFile) throws IOException { List<String> result = new ArrayList<>(); List<String> lines = FileUtils.readLines(featureVectorsFile); for (String line : lines) { String label = line.split("\\s")[0]; result.add(label); } return result; } /** * Reads the featureVectorsFile and splits comment on each line into a list of strings, i.e. * "TAG qid:4 1:1 2:1 4:2 # token TAG 4" produces "token", "TAG", "4" * * @param featureVectorsFileStream * featureVectors file stream * @return list (for each line) of list of comment parts * @throws IOException */ protected static Iterator<List<String>> extractComments(final File featureVectorsFileStream) throws IOException { return new CommentsIterator(featureVectorsFileStream); } /** * Extracts original tokens that are stored in the comment part of the featureVectorsFile * * @param featureVectorsFile * featureVectors file * @return list of original tokens * @throws IOException */ public static List<String> extractOriginalTokens(File featureVectorsFile) throws IOException { List<String> result = new ArrayList<>(); Iterator<List<String>> comments = extractComments(featureVectorsFile); try { while (comments.hasNext()) { List<String> comment = comments.next(); // original token is the first one in comments result.add(comment.get(2)); } } catch (IndexOutOfBoundsException e) { throw new IllegalArgumentException( "SvmHmm requires the feature [" + OriginalTextHolderFeatureExtractor.class.getSimpleName() + "] to be in the feature space otherwise this exception might be thrown.", e); } return result; } /** * Reads the prediction file (each line is a integer) and converts them into original outcome * labels using the mapping provided by the bi-directional map * * @param predictionsFile * predictions from classifier * @param labelsToIntegersMapping * mapping outcomeLabel:integer * @return list of outcome labels * @throws IOException */ public static List<String> extractOutcomeLabelsFromPredictions(File predictionsFile, BidiMap labelsToIntegersMapping) throws IOException { List<String> result = new ArrayList<>(); for (String line : FileUtils.readLines(predictionsFile)) { Integer intLabel = Integer.valueOf(line); String outcomeLabel = (String) labelsToIntegersMapping.getKey(intLabel); result.add(outcomeLabel); } return result; } /** * Returns a list of original sequence IDs extracted from comments * * @param featureVectorsFile * featureVectors file * @return list of integers * @throws IOException */ public static List<Integer> extractOriginalSequenceIDs(File featureVectorsFile) throws IOException { List<Integer> result = new ArrayList<>(); Iterator<List<String>> comments = extractComments(featureVectorsFile); while (comments.hasNext()) { List<String> comment = comments.next(); // sequence number is the third token in the comment token result.add(Integer.valueOf(comment.get(1))); } return result; } public static List<SortedMap<String, String>> extractMetaDataFeatures(File featureVectorsFile) throws IOException { InputStream inputStream = new FileInputStream(featureVectorsFile); List<SortedMap<String, String>> result = new ArrayList<>(); Iterator<List<String>> allComments = extractComments(featureVectorsFile); while (allComments.hasNext()) { List<String> instanceComments = allComments.next(); SortedMap<String, String> instanceResult = new TreeMap<>(); for (String comment : instanceComments) { if (comment.startsWith(SVMHMMDataWriter.META_DATA_FEATURE_PREFIX)) { String[] split = comment.split(":"); String key = split[0]; String value = split[1]; instanceResult.put(key, value); } } result.add(instanceResult); } IOUtils.closeQuietly(inputStream); return result; } public static double getParameterC(List<String> classificationArguments) { double p = getDoubleParameter(classificationArguments, "-c"); if (p == -1) { return 1.0; } if (p < 0) { throw new IllegalArgumentException("Parameter C is < 0"); } return p; } public static double getParameterEpsilon(List<String> classificationArguments) { double p = getDoubleParameter(classificationArguments, "-e"); if (p == -1) { return 0.5; } if (p < 0) { throw new IllegalArgumentException("Epsilon is < 0"); } return p; } public static int getParameterOrderT_dependencyOfTransitions(List<String> classificationArguments) { int p = getIntegerParameter(classificationArguments, "-t"); if (p == -1) { return 1; } if (p < 0) { throw new IllegalArgumentException("Parameter order-t is < 0"); } return p; } public static int getParameterOrderE_dependencyOfEmissions(List<String> classificationArguments) { int p = getIntegerParameter(classificationArguments, "-m"); if (p == -1) { return 0; } if (p != 0 && p != 1) { throw new IllegalArgumentException( "Dependency of emission parameter [-e] has to be either zero (default) or one"); } if (p < 0) { throw new IllegalArgumentException("Parameter order-e is < 0"); } return p; } public static int getParameterBeamWidth(List<String> classificationArguments) { int p = getIntegerParameter(classificationArguments, "-b"); if (p == -1) { return 0; } if (p < 0) { throw new IllegalArgumentException("Parameter beam width is < 0"); } return p; } static int getIntegerParameter(List<String> classificationArguments, String sw) { for (int i = 0; i < classificationArguments.size(); i++) { String e = classificationArguments.get(i); if (e.equals(sw)) { if (i + 1 >= classificationArguments.size()) { throw new IllegalArgumentException("Found switch [" + sw + "] but no value was specified"); } Integer val = 0; String stringVal = classificationArguments.get(i + 1); try { val = Integer.valueOf(stringVal); } catch (Exception ex) { throw new IllegalArgumentException( "The provided parameter for [" + sw + "] is not an integer value [" + stringVal + "]"); } return val; } } return -1; } static double getDoubleParameter(List<String> classificationArguments, String sw) { for (int i = 0; i < classificationArguments.size(); i++) { String e = classificationArguments.get(i); if (e.equals(sw)) { if (i + 1 >= classificationArguments.size()) { throw new IllegalArgumentException("Found switch [" + sw + "] but no value was specified"); } Double val = 0.0; String stringVal = classificationArguments.get(i + 1); try { val = Double.valueOf(stringVal); } catch (Exception ex) { throw new IllegalArgumentException( "The provided parameter for [" + sw + "] is not a double value [" + stringVal + "]"); } return val; } } return -1; } }