// Java tutorial
/**
 * TokenizerApplication.java
 *
 * Copyright (c) 2015, JULIE Lab.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Lesser General Public License (LGPL) v3.0
 *
 * Author: tomanek
 *
 * Current version: 2.0 Since version: 1.0
 *
 * Creation date: Aug 01, 2006
 *
 * The user interface (command line version) for the JULIE Token Boundary
 * Detector. Includes training, prediction, file format check, and evaluation.
 *
 * Some info on logging: to control mallet's logging, please use the
 * logging.properties file via -Djava.util.logging.config.file
 **/
package de.julielab.jtbd;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.zip.GZIPInputStream;

import org.apache.commons.io.FileUtils;

import cc.mallet.fst.CRF;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;

public class TokenizerApplication {

    /**
     * Accumulated evaluation counts for one evaluation run over the critical
     * (non-whitespace) split decisions, plus derived P/R/F measures.
     */
    private static class EvalResult {
        double ACC;           // accuracy over all critical decisions
        double fp;            // false positives (predicted split "P", gold "N")
        double fn;            // false negatives (predicted "N", gold split "P")
        double corrDecisions; // correctly labeled critical decisions

        /** Harmonic mean of precision and recall. */
        double getF() {
            // NOTE(review): returns NaN when P + R == 0 (no correct decisions) —
            // kept as-is to preserve the original behavior.
            return (2 * getR() * getP()) / (getR() + getP());
        }

        /** Precision over critical decisions. */
        double getP() {
            return corrDecisions / (corrDecisions + fp);
        }

        /** Recall over critical decisions. */
        double getR() {
            return corrDecisions / (corrDecisions + fn);
        }
    }

    /**
     * 90-10 split evaluation: shuffles both files with the same fixed seed so
     * they stay aligned, trains on the first 90% and evaluates on the last 10%.
     *
     * @param orgSentencesFile file with original (untokenized) sentences
     * @param tokSentencesFile file with gold-tokenized sentences (line-aligned)
     * @param predictions      output parameter: predicted tokenizations are added here
     * @param errors           output parameter: error contexts are added here
     * @return the evaluation result of the single split
     */
    private static EvalResult do9010Evaluation(final File orgSentencesFile, final File tokSentencesFile,
            final ArrayList<String> predictions, final ArrayList<String> errors) {
        final ArrayList<String> orgSentences = readFile(orgSentencesFile);
        final ArrayList<String> tokSentences = readFile(tokSentencesFile);

        // identical seed for both lists keeps org/tok lines aligned after shuffling
        final long seed = 1;
        Collections.shuffle(orgSentences, new Random(seed));
        Collections.shuffle(tokSentences, new Random(seed));

        final int sizeAll = orgSentences.size();
        final int sizeTest = (int) (sizeAll * 0.1);
        final int sizeTrain = sizeAll - sizeTest;

        if (sizeTest == 0) {
            System.err.println("Error: no test files for this split.");
            System.exit(-1);
        }

        System.out.println("all: " + sizeAll + "\ttrain: " + sizeTrain + "\t" + "test: " + sizeTest);

        final ArrayList<String> trainOrgSentences = new ArrayList<String>();
        final ArrayList<String> trainTokSentences = new ArrayList<String>();
        final ArrayList<String> predictOrgSentences = new ArrayList<String>();
        final ArrayList<String> predictTokSentences = new ArrayList<String>();

        for (int i = 0; i < sizeTrain; i++) {
            trainOrgSentences.add(orgSentences.get(i));
            trainTokSentences.add(tokSentences.get(i));
        }
        for (int i = sizeTrain; i < sizeAll; i++) {
            predictOrgSentences.add(orgSentences.get(i));
            predictTokSentences.add(tokSentences.get(i));
        }

        return doEvaluation(trainOrgSentences, trainTokSentences, predictOrgSentences, predictTokSentences, errors,
                predictions);
    }

    /**
     * Checks the file format by running training-data generation over the input
     * and reporting the number of features that result.
     *
     * @param orgSentencesFile file with original sentences
     * @param tokSentencesFile file with gold-tokenized sentences
     */
    private static void doCheck(final File orgSentencesFile, final File tokSentencesFile) {
        final Tokenizer tokenizer = new Tokenizer();
        System.out.println("checking on files: \n * " + orgSentencesFile.toString() + "\n * "
                + tokSentencesFile.toString() + "\n");
        final ArrayList<String> orgSentences = readFile(orgSentencesFile);
        final ArrayList<String> tokSentences = readFile(tokSentencesFile);
        final InstanceList trainData = tokenizer.makeTrainingData(orgSentences, tokSentences);
        final Pipe myPipe = trainData.getPipe();
        System.out.println("\n\n\n# Features resulting from training data: " + myPipe.getDataAlphabet().size());
        System.out.println("(critical sentences were omitted for feature generation)");
        System.out.println("Done.");
    }

    /**
     * Performs n-fold cross validation. Both sentence lists are shuffled with
     * the same fixed seed; the last fold absorbs the remainder when the corpus
     * size is not divisible by n.
     *
     * @param n                number of splits
     * @param orgSentencesFile file with original sentences
     * @param tokSentencesFile file with gold-tokenized sentences
     * @param predictions      output parameter: predicted tokenizations are added here
     * @param errors           output parameter: error contexts are added here
     * @return the average accuracy over all folds
     */
    private static double doCrossEvaluation(final int n, final File orgSentencesFile, final File tokSentencesFile,
            final ArrayList<String> predictions, final ArrayList<String> errors) {
        final ArrayList<String> orgSentences = readFile(orgSentencesFile);
        final ArrayList<String> tokSentences = readFile(tokSentencesFile);

        // identical seed keeps the two lists aligned
        final long seed = 1;
        Collections.shuffle(orgSentences, new Random(seed));
        Collections.shuffle(tokSentences, new Random(seed));

        int pos = 0; // start index of the current test fold
        final int sizeRound = orgSentences.size() / n;
        final int sizeAll = orgSentences.size();
        final int sizeLastRound = sizeRound + (sizeAll % n);
        System.out.println("number of files in directory: " + sizeAll);
        System.out.println("size of each/last round: " + sizeRound + "/" + sizeLastRound);
        System.out.println();

        final EvalResult[] er = new EvalResult[n];
        double avgAcc = 0;
        double avgF = 0;

        for (int i = 0; i < n; i++) { // in each round
            final ArrayList<String> predictOrgSentences = new ArrayList<String>();
            final ArrayList<String> predictTokSentences = new ArrayList<String>();
            final ArrayList<String> trainOrgSentences = new ArrayList<String>();
            final ArrayList<String> trainTokSentences = new ArrayList<String>();

            if (i == (n - 1)) {
                // last round: test fold runs from pos to the end (absorbs remainder)
                for (int j = 0; j < orgSentences.size(); j++)
                    if (j < pos) {
                        trainOrgSentences.add(orgSentences.get(j));
                        trainTokSentences.add(tokSentences.get(j));
                    } else {
                        predictOrgSentences.add(orgSentences.get(j));
                        predictTokSentences.add(tokSentences.get(j));
                    }
            } else {
                // other rounds: test fold is [pos, pos + sizeRound)
                for (int j = 0; j < orgSentences.size(); j++)
                    if ((j < pos) || (j >= (pos + sizeRound))) {
                        trainOrgSentences.add(orgSentences.get(j));
                        trainTokSentences.add(tokSentences.get(j));
                    } else {
                        predictOrgSentences.add(orgSentences.get(j));
                        predictTokSentences.add(tokSentences.get(j));
                    }
                pos += sizeRound;
            }

            // now evaluate for this round
            System.out.println("training size: " + trainOrgSentences.size());
            System.out.println("prediction size: " + predictOrgSentences.size());
            er[i] = doEvaluation(trainOrgSentences, trainTokSentences, predictOrgSentences, predictTokSentences,
                    errors, predictions);
        }

        final DecimalFormat df = new DecimalFormat("0.000");
        for (int i = 0; i < er.length; i++) {
            avgAcc += er[i].ACC;
            avgF += er[i].getF();
            System.out.println("ACC in round " + i + ": " + df.format(er[i].ACC));
        }
        avgAcc = avgAcc / n;
        avgF = avgF / n;
        System.out.println("\n\n------------------------------------");
        System.out.println("avg accuracy: " + df.format(avgAcc));
        System.out.println("avg F-score: " + df.format(avgF));
        System.out.println("------------------------------------");
        return avgAcc;
    }

    /**
     * General evaluation function: trains a model on the training sentences and
     * evaluates it on the prediction sentences. Called from doCrossEvaluation
     * and do9010Evaluation.
     *
     * @param trainOrgSentences   original sentences for training
     * @param trainTokSentences   tokenized sentences for training
     * @param predictOrgSentences original sentences for evaluation
     * @param predictTokSentences tokenized (gold) sentences for evaluation
     * @param errors              output parameter: error contexts are added here
     * @param predictions         output parameter: predicted tokenizations are added here
     * @return the evaluation result
     */
    public static EvalResult doEvaluation(final ArrayList<String> trainOrgSentences,
            final ArrayList<String> trainTokSentences, final ArrayList<String> predictOrgSentences,
            final ArrayList<String> predictTokSentences, final ArrayList<String> errors,
            final ArrayList<String> predictions) {
        final Tokenizer tokenizer = new Tokenizer();

        // 1. training
        final InstanceList trainData = tokenizer.makeTrainingData(trainOrgSentences, trainTokSentences);
        final Pipe myPipe = trainData.getPipe();
        System.out.println("training model...");
        tokenizer.train(trainData, myPipe);

        return doEvaluation(tokenizer.getModel(), predictOrgSentences, predictTokSentences, errors, predictions);
    }

    /**
     * General evaluation function: runs the given model over the prediction
     * sentences and scores every critical (non-whitespace) split decision
     * against the gold labels. Called from doEvaluation and from
     * startCompareValidationMode.
     *
     * @param crf                 the crf model
     * @param predictOrgSentences original sentences to predict on
     * @param predictTokSentences tokenized (gold) sentences
     * @param errors              output parameter: error contexts are added here
     * @param predictions         output parameter: predicted tokenizations are added here
     * @return the evaluation result
     */
    @SuppressWarnings("unchecked")
    private static EvalResult doEvaluation(final CRF crf, final ArrayList<String> predictOrgSentences,
            final ArrayList<String> predictTokSentences, final ArrayList<String> errors,
            final ArrayList<String> predictions) {
        final Tokenizer tokenizer = new Tokenizer();
        tokenizer.setModel(crf);

        // 2. prediction
        final InstanceList predData = tokenizer.makePredictionData(predictOrgSentences, predictTokSentences);

        int nrDecisions = 0;
        int corrDecisions = 0;
        int fp = 0;
        int fn = 0;

        for (int i = 0; i < predData.size(); i++) {
            final String orgSentence = predictOrgSentences.get(i);
            final String tokSentence = predictTokSentences.get(i);
            String sentenceBoundary = orgSentence.substring(orgSentence.length() - 1, orgSentence.length());
            final Instance inst = predData.get(i);
            final ArrayList<Unit> units = tokenizer.predict(inst);

            // 3. evaluation
            final ArrayList<String> orgLabels = tokenizer
                    .getLabelsFromLabelSequence((LabelSequence) inst.getTarget());
            final ArrayList<String> wSpaces = (ArrayList<String>) inst.getSource();

            final StringBuilder sentenceBuilder = new StringBuilder();
            int localDec = 0;
            int localCorr = 0;
            boolean hasError = false;
            for (int j = 0; j < units.size(); j++) {
                // label "P" means a token split (emit a space), "N" means none
                final String sp = (units.get(j).label.equals("P")) ? " " : "";
                sentenceBuilder.append(units.get(j).rep).append(sp);
                if (!wSpaces.get(j).equals("WS") && (j < (units.size() - 1))) {
                    // this is a critical split decision... count!
                    // do not count last label (i.e. sentence boundary is not critical)
                    localDec++;
                    // compare labels here
                    if (orgLabels.get(j).equals(units.get(j).label))
                        localCorr++;
                    else {
                        hasError = true;
                        if (orgLabels.get(j).equals("P") && units.get(j).label.equals("N"))
                            fn++;
                        if (orgLabels.get(j).equals("N") && units.get(j).label.equals("P"))
                            fp++;
                        errors.add("@" + orgLabels.get(j) + "->" + units.get(j).label);
                        errors.add(tokenizer.showErrorContext(j, units, orgLabels));
                    }
                }
            }
            final String sentence = sentenceBuilder.toString();
            nrDecisions += localDec;
            corrDecisions += localCorr;

            // re-attach the sentence boundary symbol, separated by a space
            if (!sentence.endsWith(" "))
                sentenceBoundary = " " + sentenceBoundary;
            predictions.add(sentence + sentenceBoundary);

            if (hasError) {
                errors.add(sentence + sentenceBoundary);
                errors.add(tokSentence);
                errors.add("\n");
            }
        }

        final double ACC = (corrDecisions / (double) nrDecisions);
        final EvalResult er = new EvalResult();
        er.ACC = ACC;
        er.fn = fn;
        er.fp = fp;
        er.corrDecisions = corrDecisions;

        System.out.println("\n* ------------------------------------");
        System.out.println("* critical decisions: " + nrDecisions);
        System.out.println("* correct decisions: " + corrDecisions);
        System.out.println("* fp: " + fp);
        System.out.println("* fn: " + fn);
        System.out.println("* R: " + er.getR());
        System.out.println("* P: " + er.getP());
        System.out.println("* F: " + er.getF());
        System.out.println("* ACC = " + ACC);
        System.out.println("* ------------------------------------\n");
        return er;
    }

    /**
     * Tokenizes all documents in a directory and writes the results to files of
     * the same name in the output directory.
     *
     * @param inDir         the directory with the documents to be tokenized
     * @param outDir        the directory where the tokenized documents should be
     *                      written to
     * @param modelFilename the model to use for tokenization
     * @throws IOException if an input file cannot be read
     */
    public static void doPrediction(final File inDir, final File outDir, final String modelFilename)
            throws IOException {
        final Tokenizer tokenizer = new Tokenizer();
        try {
            tokenizer.readModel(new File(modelFilename));
        } catch (final Exception e) {
            e.printStackTrace();
        }

        // get list of all files in directory
        final File[] predictOrgFiles = inDir.listFiles();
        if (predictOrgFiles == null) {
            // listFiles() returns null for a non-directory or an I/O error
            System.err.println("Error: could not list files in: " + inDir.toString());
            return;
        }

        // loop over all files
        for (final File predictOrgFile : predictOrgFiles) {
            final long start = System.currentTimeMillis();
            final List<String> orgSentences = FileUtils.readLines(predictOrgFile, "utf-8");
            // TODO: ask Erik what he thinks of this
            final ArrayList<String> tokSentences = new ArrayList<String>();
            final ArrayList<String> predictions = new ArrayList<String>();

            // force empty labels
            for (int j = 0; j < orgSentences.size(); j++)
                tokSentences.add("");

            // make prediction data
            final InstanceList predData = tokenizer.makePredictionData(orgSentences, tokSentences);

            // predict
            for (int i = 0; i < predData.size(); i++) {
                final String orgSentence = orgSentences.get(i);
                // NOTE(review): assumes no empty input lines — charAt would throw
                // on an empty string; confirm upstream filtering of the input files
                final char lastChar = orgSentence.charAt(orgSentence.length() - 1);
                final Instance inst = predData.get(i);
                final ArrayList<Unit> units = tokenizer.predict(inst);

                final StringBuilder sentenceBuilder = new StringBuilder();
                for (int j = 0; j < units.size(); j++) {
                    final String sp = (units.get(j).label.equals("P")) ? " " : "";
                    sentenceBuilder.append(units.get(j).rep).append(sp);
                }
                String sentence = sentenceBuilder.toString();
                // re-attach the end-of-sentence symbol as a separate token
                if (EOSSymbols.contains(lastChar))
                    sentence += " " + lastChar;
                sentence = sentence.replaceAll(" +", " ");
                predictions.add(sentence);
            }

            // write predictions into a file of the same name in outDir
            // (File(parent, child) instead of hard-coded "/" for portability)
            final File fNew = new File(outDir, predictOrgFile.getName());
            writeFile(predictions, fNew);

            final long stop = System.currentTimeMillis();
            System.out.println("took: " + (stop - start));
        } // end loop over files

        System.out.println("Tokenized texts written to: " + outDir.toString());
    }

    /**
     * Trains a model and writes it to the given file.
     *
     * @param orgSentencesFile file with original sentences
     * @param tokSentencesFile file with gold-tokenized sentences
     * @param modelFilename    where to store the trained model
     */
    public static void doTraining(final File orgSentencesFile, final File tokSentencesFile,
            final String modelFilename) {
        final Tokenizer tokenizer = new Tokenizer();
        final ArrayList<String> trainTokSentences = readFile(tokSentencesFile);
        final ArrayList<String> trainOrgSentences = readFile(orgSentencesFile);

        // get training data
        final InstanceList trainData = tokenizer.makeTrainingData(trainOrgSentences, trainTokSentences);
        final Pipe myPipe = trainData.getPipe();

        // train a model
        System.out.println("training model...");
        tokenizer.train(trainData, myPipe);
        tokenizer.writeModel(modelFilename);
        System.out.println("\nmodel written to: " + modelFilename);
    }

    /**
     * Command line entry point: dispatches to the mode selected by the first
     * argument.
     *
     * @param args command line arguments, first one is the mode
     * @throws IOException if prediction mode fails to read input files
     */
    public static void main(final String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("usage: JTBD <mode> <mode-specific-parameters>");
            showModes();
            System.exit(-1);
        }

        final String mode = args[0];
        if (mode.equals("c"))
            startCheckMode(args);
        else if (mode.equals("s"))
            start9010ValidationMode(args);
        else if (mode.equals("x"))
            startXValidationMode(args);
        else if (mode.equals("t"))
            startTrainingMode(args);
        else if (mode.equals("p"))
            startPredictionMode(args);
        else if (mode.equals("e"))
            startCompareValidationMode(args);
        else {
            // unknown mode
            System.err.println("unknown mode");
            showModes();
        }
    }

    /**
     * Reads in all lines of a file and writes each line as a string into an
     * arraylist. The following lines are omitted: empty lines, those consisting
     * of spaces only, and lines with less than 2 characters. Runs of spaces are
     * collapsed to a single space and lines are trimmed.
     *
     * @param myFile the file to read
     * @return the filtered lines; exits the VM on a read error
     */
    static ArrayList<String> readFile(final File myFile) {
        final ArrayList<String> lines = new ArrayList<String>();
        try {
            final BufferedReader b = new BufferedReader(new FileReader(myFile));
            String line = "";
            while ((line = b.readLine()) != null) {
                line = line.replaceAll("[ ]+", " ");
                line = line.trim();
                // keep only lines with at least 2 characters after normalization
                if ((line.length() > 1) && !line.equals(" "))
                    lines.add(line);
            }
            b.close();
        } catch (final Exception e) {
            System.err.println("ERR: error reading file: " + myFile.toString());
            e.printStackTrace();
            System.exit(-1);
        }
        return lines;
    }

    /**
     * Shows available modes on stderr and exits.
     */
    private static void showModes() {
        System.err.println("\nAvailable modes:");
        System.err.println("c: check data ");
        System.err.println("s: 90-10 split evaluation");
        System.err.println("x: cross validation ");
        System.err.println("t: train a tokenizer ");
        System.err.println("p: predict with tokenizer ");
        System.err.println("e: evaluation on previously trained model");
        System.exit(-1);
    }

    /**
     * Entry point for 90-10 split validation mode.
     *
     * @param args the command line arguments
     */
    private static void start9010ValidationMode(final String[] args) {
        if (args.length != 5) {
            System.err.println("usage: JTBD s <sent-file> <tok-file> <predout-file> <errout-file>");
            System.exit(-1);
        }
        final File orgSentencesFile = new File(args[1]);
        final File tokSentencesFile = new File(args[2]);
        final File predOutFile = new File(args[3]);
        final File errOutFile = new File(args[4]);

        final ArrayList<String> errors = new ArrayList<String>();
        final ArrayList<String> predictions = new ArrayList<String>();
        do9010Evaluation(orgSentencesFile, tokSentencesFile, predictions, errors);
        writeFile(predictions, predOutFile);
        writeFile(errors, errOutFile);
    }

    /**
     * Entry point for file format check mode.
     *
     * @param args the command line arguments
     */
    private static void startCheckMode(final String[] args) {
        if (args.length != 3) {
            System.err.println("usage: JTBD c <sent-file> <tok-file>");
            System.exit(-1);
        }
        final File orgSentencesFile = new File(args[1]);
        final File tokSentencesFile = new File(args[2]);
        doCheck(orgSentencesFile, tokSentencesFile);
    }

    /**
     * Entry point for compare validation mode: evaluates a previously trained
     * model on the given sentence files.
     *
     * @param args the command line arguments
     */
    private static void startCompareValidationMode(final String[] args) {
        if (args.length != 6) {
            System.err.println("usage: JTBD e <modelFile> <sent-file> <tok-file> <predout-file> <errout-file>");
            System.exit(-1);
        }

        ObjectInputStream in;
        CRF crf = null;
        try {
            // load model
            in = new ObjectInputStream(new GZIPInputStream(new FileInputStream(args[1])));
            crf = (CRF) in.readObject();
            in.close();
        } catch (final Exception e) {
            e.printStackTrace();
        }

        final File orgSentencesFile = new File(args[2]);
        final File tokSentencesFile = new File(args[3]);
        final ArrayList<String> orgSentences = readFile(orgSentencesFile);
        final ArrayList<String> tokSentences = readFile(tokSentencesFile);
        final File predOutFile = new File(args[4]);
        final File errOutFile = new File(args[5]);

        final ArrayList<String> errors = new ArrayList<String>();
        final ArrayList<String> predictions = new ArrayList<String>();
        // BUGFIX: errors/predictions were passed swapped here, so the
        // predictions file received the error contexts and vice versa;
        // doEvaluation's signature is (..., errors, predictions).
        doEvaluation(crf, orgSentences, tokSentences, errors, predictions);
        writeFile(predictions, predOutFile);
        writeFile(errors, errOutFile);
    }

    /**
     * Entry point for prediction mode.
     *
     * @param args the command line arguments
     * @throws IOException if an input file cannot be read
     */
    private static void startPredictionMode(final String[] args) throws IOException {
        if (args.length != 4) {
            System.err.println("usage: JTBD p <inDir> <outDir> <model-file>");
            System.exit(-1);
        }
        final File inDir = new File(args[1]);
        if (!inDir.isDirectory()) {
            System.err.println("Error: the specified input directory does not exist.");
            System.exit(-1);
        }
        final File outDir = new File(args[2]);
        if (!outDir.isDirectory() || !outDir.canWrite()) {
            System.err.println("Error: the specified output directory does not exist or is not writable.");
            System.exit(-1);
        }
        final String modelFilename = args[3];
        doPrediction(inDir, outDir, modelFilename);
    }

    /**
     * Entry point for training mode.
     *
     * @param args the command line arguments
     */
    private static void startTrainingMode(final String[] args) {
        if (args.length != 4) {
            System.err.println("usage: JTBD t <sent-file> <tok-file> <model-file>");
            System.exit(-1);
        }
        final File orgSentencesFile = new File(args[1]);
        final File tokSentencesFile = new File(args[2]);
        final String modelFilename = args[3];
        doTraining(orgSentencesFile, tokSentencesFile, modelFilename);
    }

    /**
     * Entry point for cross-validation mode.
     *
     * @param args the command line arguments
     */
    private static void startXValidationMode(final String[] args) {
        if (args.length != 6) {
            System.err.println(
                    "usage: JTBD x <sent-file> <tok-file> <cross-val-rounds> <predout-file> <errout-file>");
            System.exit(-1);
        }
        final File orgSentencesFile = new File(args[1]);
        final File tokSentencesFile = new File(args[2]);
        // Integer.parseInt instead of the deprecated new Integer(...) boxing ctor
        final int n = Integer.parseInt(args[3]);
        final File predOutFile = new File(args[4]);
        final File errOutFile = new File(args[5]);

        final ArrayList<String> errors = new ArrayList<String>();
        final ArrayList<String> predictions = new ArrayList<String>();
        doCrossEvaluation(n, orgSentencesFile, tokSentencesFile, predictions, errors);
        writeFile(predictions, predOutFile);
        writeFile(errors, errOutFile);
    }

    /**
     * Writes an ArrayList of Strings to a file, one element per line.
     *
     * @param lines   the lines to write
     * @param outFile the destination file; exits the VM on a write error
     */
    static void writeFile(final ArrayList<String> lines, final File outFile) {
        try {
            final FileWriter fw = new FileWriter(outFile);
            for (int i = 0; i < lines.size(); i++)
                fw.write(lines.get(i) + "\n");
            fw.close();
        } catch (final Exception e) {
            System.err.println("ERR: error writing file: " + outFile.toString());
            e.printStackTrace();
            System.exit(-1);
        }
    }
}