// Java tutorial
/** * Copyright 2015 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universitt Darmstadt * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package de.tudarmstadt.ukp.dkpro.tc.crfsuite.task; import java.io.File; import java.io.FileInputStream; import java.io.InputStream; import java.util.ArrayList; import java.util.List; import java.util.Scanner; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.time.DurationFormatUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import de.tudarmstadt.ukp.dkpro.core.api.resources.PlatformDetector; import de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils; import de.tudarmstadt.ukp.dkpro.core.api.resources.RuntimeProvider; import de.tudarmstadt.ukp.dkpro.lab.engine.TaskContext; import de.tudarmstadt.ukp.dkpro.lab.storage.StorageService.AccessMode; import de.tudarmstadt.ukp.dkpro.lab.task.Discriminator; import de.tudarmstadt.ukp.dkpro.lab.task.impl.ExecutableTaskBase; import de.tudarmstadt.ukp.dkpro.tc.api.exception.TextClassificationException; import de.tudarmstadt.ukp.dkpro.tc.core.Constants; import de.tudarmstadt.ukp.dkpro.tc.core.ml.TCMachineLearningAdapter.AdapterNameEntries; import de.tudarmstadt.ukp.dkpro.tc.core.util.ReportConstants; import de.tudarmstadt.ukp.dkpro.tc.crfsuite.CRFSuiteAdapter; import de.tudarmstadt.ukp.dkpro.tc.crfsuite.writer.LabelSubstitutor; public class 
CRFSuiteTestTask extends ExecutableTaskBase implements Constants { @Discriminator private String learningMode; @Discriminator private String[] classificationArguments; public static final String MODELNAME = "model.crfsuite"; public static final String FILE_PER_CLASS_PRECISION_RECALL_F1 = "precisionRecallF1PerWordClass.txt"; Log logger = null; private String executablePath = null; private String modelLocation = null; private File trainFile = null; private File testFile = null; private static RuntimeProvider runtimeProvider = null; @Override public void execute(TaskContext aContext) throws Exception { boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL); if (multiLabel) { throw new TextClassificationException( "Multi-label requested, but CRFSuite only supports single label setups."); } sanityCheckOnClassificationArguments(); executablePath = getExecutablePath(); modelLocation = trainModel(aContext); String rawTextOutput = testModel(aContext); writePredictions2File(aContext, rawTextOutput); } private void sanityCheckOnClassificationArguments() throws Exception { if (classificationArguments == null || classificationArguments.length == 0) { log("No algorithm has been provided - will use CRFsuite default (lbfgs)"); return; } if (classificationArguments.length == 1) { return; } /* * At the moment only a pair of parameters is expected for provide the algorithm CRFsuite * uses */ throw new Exception("Unexpected amount of classification arguments: " + "[" + classificationArguments.length + "] expected either zero or one"); } public static String getExecutablePath() throws Exception { if (runtimeProvider == null) { PlatformDetector pd = new PlatformDetector(); String platform = pd.getPlatformId(); LogFactory.getLog(CRFSuiteTestTask.class.getName()) .info("Load binary for platform: [" + platform + "]"); runtimeProvider = new RuntimeProvider("classpath:/de/tudarmstadt/ukp/dkpro/tc/crfsuite/"); } String executablePath = 
runtimeProvider.getFile("crfsuite").getAbsolutePath(); LogFactory.getLog(CRFSuiteTestTask.class.getName()).info("Will use binary: [" + executablePath + "]"); return executablePath; } private void writePredictions2File(TaskContext aContext, String aRawTextOutput) throws Exception { writeCRFSuiteGeneratedReports2File(aContext); List<String> predictionValues = writeSelfGeneratedAccuracyReport2File(aContext, aRawTextOutput); writeFileWithPredictedLabels(aContext, predictionValues); } private void writeFileWithPredictedLabels(TaskContext aContext, List<String> predictionValues) throws Exception { File predictionsFile = new File(aContext.getStorageLocation(TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE), CRFSuiteAdapter.getInstance().getFrameworkFilename(AdapterNameEntries.predictionsFile)); StringBuilder sb = new StringBuilder(); sb.append("#Gold\tPrediction\n"); for (String p : predictionValues) { sb.append(LabelSubstitutor.undoLabelReplacement(p) + "\n"); // NOTE: CRFSuite has a bug when the label is ':' (as in // PennTreeBank Part-of-speech tagset for instance) // We perform a substitutions to something crfsuite can handle // correctly, see class // LabelSubstitutor for more details } FileUtils.writeStringToFile(predictionsFile, sb.toString()); } private List<String> writeSelfGeneratedAccuracyReport2File(TaskContext aContext, String aRawTextOutput) throws Exception { String[] lines = aRawTextOutput.split("\n"); int correct = 0; int incorrect = 0; List<String> predictionValues = new ArrayList<String>(); for (String line : lines) { predictionValues.add(line); String[] split = line.split("\t"); if (split.length < 2) { continue; } String actual = split[0]; String prediction = split[1]; if (actual.equals(prediction)) { correct++; } else { incorrect++; } } double denominator = correct + incorrect; double numerator = correct; double accuracy = 0; if (denominator > 0) { accuracy = numerator / denominator; } log("Accuracy: " + accuracy * 100 + " (" + correct + " correct, " + 
incorrect + " incorrect)"); // file to hold prediction results File evalFile = new File(aContext.getStorageLocation(TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE), CRFSuiteAdapter.getInstance().getFrameworkFilename(AdapterNameEntries.evaluationFile)); StringBuilder sb = new StringBuilder(); sb.append(ReportConstants.CORRECT + "=" + correct + "\n"); sb.append(ReportConstants.INCORRECT + "=" + incorrect + "\n"); sb.append(ReportConstants.PCT_CORRECT + "=" + accuracy + "\n"); FileUtils.writeStringToFile(evalFile, sb.toString()); return predictionValues; } private void writeCRFSuiteGeneratedReports2File(TaskContext aContext) throws Exception { String precRecF1perClass = getPrecisionRecallF1PerClass(); log(precRecF1perClass); File precRecF1File = new File(aContext.getStorageLocation(TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE), FILE_PER_CLASS_PRECISION_RECALL_F1); FileUtils.write(precRecF1File, "\n" + precRecF1perClass); } private String getPrecisionRecallF1PerClass() throws Exception { String executablePath = getExecutablePath(); List<String> evalCommand = new ArrayList<String>(); evalCommand.add(executablePath); evalCommand.add("tag"); evalCommand.add("-qt"); evalCommand.add("-m"); evalCommand.add(modelLocation); evalCommand.add(testFile.getAbsolutePath()); Process process = new ProcessBuilder().command(evalCommand).start(); String output = captureProcessOutput(process); return output; } private String testModel(TaskContext aContext) throws Exception { List<String> testModelCommand = buildTestCommand(aContext); log("Testing model"); String output = runTest(testModelCommand); log("Testing model finished"); return output; } public static String runTest(List<String> aTestModelCommand) throws Exception { Process process = new ProcessBuilder().command(aTestModelCommand).start(); String output = captureProcessOutput(process); return output; } private static String captureProcessOutput(Process aProcess) { InputStream src = aProcess.getInputStream(); Scanner sc = new 
Scanner(src); StringBuilder dest = new StringBuilder(); while (sc.hasNextLine()) { String l = sc.nextLine(); dest.append(l + "\n"); } sc.close(); return dest.toString(); } private List<String> buildTestCommand(TaskContext aContext) throws Exception { File tmpTest = new File(aContext.getStorageLocation(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY) .getPath() + "/" + CRFSuiteAdapter.getInstance().getFrameworkFilename(AdapterNameEntries.featureVectorsFile)); testFile = ResourceUtils.getUrlAsFile(tmpTest.toURI().toURL(), true); return wrapTestCommandAsList(testFile, executablePath, modelLocation); } public static List<String> wrapTestCommandAsList(File aTestFile, String aExecutablePath, String aModelLocation) { List<String> commandTestModel = new ArrayList<String>(); commandTestModel.add(aExecutablePath); commandTestModel.add("tag"); commandTestModel.add("-r"); commandTestModel.add("-m"); commandTestModel.add(aModelLocation); commandTestModel.add(aTestFile.getAbsolutePath()); return commandTestModel; } private String trainModel(TaskContext aContext) throws Exception { String tmpModelLocation = System.getProperty("java.io.tmpdir") + File.separator + MODELNAME; List<String> modelTrainCommand = buildTrainCommand(aContext, tmpModelLocation); log("Start training model"); long time = System.currentTimeMillis(); runTrain(modelTrainCommand); long completedIn = System.currentTimeMillis() - time; String formattedDuration = DurationFormatUtils.formatDuration(completedIn, "HH:mm:ss:SS"); log("Training finished after " + formattedDuration); return writeModel(aContext, tmpModelLocation); } private void runTrain(List<String> aModelTrainCommand) throws Exception { Process process = new ProcessBuilder().inheritIO().command(aModelTrainCommand).start(); process.waitFor(); } private String writeModel(TaskContext aContext, String aTmpModelLocation) throws Exception { aContext.storeBinary(MODELNAME, new FileInputStream(new File(aTmpModelLocation))); File modelLocation = 
aContext.getStorageLocation(MODELNAME, AccessMode.READONLY); return modelLocation.getAbsolutePath(); } private List<String> buildTrainCommand(TaskContext aContext, String aTmpModelLocation) throws Exception { File tmpTrain = new File(aContext.getStorageLocation(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY) .getPath() + "/" + CRFSuiteAdapter.getInstance().getFrameworkFilename(AdapterNameEntries.featureVectorsFile)); trainFile = ResourceUtils.getUrlAsFile(tmpTrain.toURI().toURL(), true); return getTrainCommand(aTmpModelLocation, trainFile.getAbsolutePath(), classificationArguments != null ? classificationArguments[0] : null); } public static List<String> getTrainCommand(String modelOutputLocation, String trainingFile, String algorithm) throws Exception { List<String> commandTrainModel = new ArrayList<String>(); commandTrainModel.add(getExecutablePath()); commandTrainModel.add("learn"); commandTrainModel.add("-m"); commandTrainModel.add(modelOutputLocation); // add algorithm if provided if (algorithm != null) { commandTrainModel.add("-a"); commandTrainModel.add(algorithm); } commandTrainModel.add(trainingFile); return commandTrainModel; } private void log(String text) { if (logger == null) { logger = LogFactory.getLog(getClass()); } logger.info(text); } }