org.dkpro.tc.ml.svmhmm.task.SVMHMMTestTask.java Source code

Java tutorial

Introduction

Here is the source code for org.dkpro.tc.ml.svmhmm.task.SVMHMMTestTask.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.dkpro.tc.ml.svmhmm.task;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.SortedSet;

import org.apache.commons.collections.BidiMap;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.time.DurationFormatUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.lab.task.Discriminator;
import org.dkpro.lab.task.impl.ExecutableTaskBase;

import de.tudarmstadt.ukp.dkpro.core.api.resources.RuntimeProvider;
import org.dkpro.tc.api.exception.TextClassificationException;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.core.ml.TCMachineLearningAdapter;
import org.dkpro.tc.ml.svmhmm.SVMHMMAdapter;
import org.dkpro.tc.ml.svmhmm.util.SVMHMMUtils;

/**
 * Lab task that trains an SVM-HMM model on the training feature-vector file and applies it to the
 * test feature-vector file by launching the external {@code svm_hmm_learn} and
 * {@code svm_hmm_classify} binaries as separate processes.
 * <p>
 * All files handed to the binaries are first copied to the Java temp directory as a workaround
 * for a long-path problem in svm_hmm; the temporary copies are removed again even when a step
 * fails.
 */
public class SVMHMMTestTask extends ExecutableTaskBase implements Constants {
    /*
     * for svm_hmm debugging purposes: when true, the external process inherits the JVM's
     * stdin/stdout/stderr instead of having its output captured and logged at debug level
     */
    public static boolean PRINT_STD_OUT = false;

    // base location from which the svm_hmm binaries are resolved
    private static final String BINARIES_BASE_LOCATION = "classpath:/org/dkpro/tc/ml/svmhmm/";

    /**
     * Learning mode discriminators; Only Constants.LM_SINGLE_LABEL is allowed
     */
    @Discriminator(name = DIM_LEARNING_MODE)
    private String learningMode = LM_SINGLE_LABEL;

    @Discriminator(name = DIM_CLASSIFICATION_ARGS)
    private List<String> classificationArguments;

    // The values below are defaults only; execute() overwrites all of them from
    // classificationArguments via processParameters().
    private double paramC = 5;

    private double paramEpsilon = 0.5;

    private int paramOrderT = 1;

    private int paramOrderE = 0;

    private int paramB = 0;

    // where the trained model is stored
    private static final String MODEL_NAME = "svm_struct.model";

    // logger
    private static final Log log = LogFactory.getLog(SVMHMMTestTask.class);

    // mapping outcome labels to integers
    protected BidiMap labelsToIntegersMapping;

    /**
     * Runs the complete train/test cycle: reads the classification arguments, maps the outcome
     * labels found in the training and test files to integers, stores that mapping in the task
     * context, trains the model and finally writes the predictions for the test file.
     *
     * @param taskContext
     *            context providing the input folders and receiving the outputs
     * @throws TextClassificationException
     *             if a learning mode other than single label was requested
     * @throws Exception
     *             if any file operation or external process invocation fails
     */
    @Override
    public void execute(TaskContext taskContext) throws Exception {
        processParameters(classificationArguments);

        // where the training data are located
        File trainingDataStorage = taskContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA,
                StorageService.AccessMode.READONLY);

        // where the test data are located
        File testDataStorage = taskContext.getFolder(TEST_TASK_INPUT_KEY_TEST_DATA,
                StorageService.AccessMode.READONLY);

        // file name of the data; THE FILES HAVE SAME NAME FOR BOTH TRAINING AND
        // TESTING!!!!!!
        String fileName = new SVMHMMAdapter()
                .getFrameworkFilename(TCMachineLearningAdapter.AdapterNameEntries.featureVectorsFile);

        File trainingFile = new File(trainingDataStorage, fileName);
        File testFile = new File(testDataStorage, fileName);

        if (!Constants.LM_SINGLE_LABEL.equals(learningMode)) {
            throw new TextClassificationException(
                    learningMode + " was requested but only single label setup is supported.");
        }

        // mapping outcome labels to integers
        SortedSet<String> outcomeLabels = SVMHMMUtils.extractOutcomeLabelsFromFeatureVectorFiles(trainingFile,
                testFile);
        labelsToIntegersMapping = SVMHMMUtils.mapVocabularyToIntegers(outcomeLabels);

        // save mapping to file
        File mappingFolder = taskContext.getFolder("", StorageService.AccessMode.READWRITE);
        File mappingFile = new File(mappingFolder, SVMHMMUtils.LABELS_TO_INTEGERS_MAPPING_FILE_NAME);
        SVMHMMUtils.saveMapping(labelsToIntegersMapping, mappingFile);

        File augmentedTrainingFile = SVMHMMUtils.replaceLabelsWithIntegers(trainingFile, labelsToIntegersMapping);

        File augmentedTestFile = SVMHMMUtils.replaceLabelsWithIntegers(testFile, labelsToIntegersMapping);

        // train the model
        trainModel(taskContext, augmentedTrainingFile, paramC, paramOrderE, paramOrderT, paramEpsilon, paramB);

        // test the model
        testModel(taskContext, augmentedTestFile);
    }

    /**
     * Fills the parameter fields of this task from the given classification arguments.
     *
     * @param classificationArguments
     *            arguments handed to the {@link SVMHMMUtils} parameter getters
     */
    private void processParameters(List<String> classificationArguments) {
        paramC = SVMHMMUtils.getParameterC(classificationArguments);
        paramEpsilon = SVMHMMUtils.getParameterEpsilon(classificationArguments);
        paramOrderE = SVMHMMUtils.getParameterOrderE_dependencyOfEmissions(classificationArguments);
        paramOrderT = SVMHMMUtils.getParameterOrderT_dependencyOfTransitions(classificationArguments);
        paramB = SVMHMMUtils.getParameterBeamWidth(classificationArguments);
    }

    /**
     * Tests the model against the test data and stores the outcomes in the
     * {@link org.dkpro.tc.core.ml.TCMachineLearningAdapter.AdapterNameEntries#predictionsFile}
     * file.
     *
     * @param taskContext
     *            context
     * @param testFile
     *            test file
     * @throws Exception
     *             if copying files or running the classifier fails
     */
    public void testModel(TaskContext taskContext, File testFile) throws Exception {
        // file to hold prediction results
        String predictionFileName = new SVMHMMAdapter()
                .getFrameworkFilename(TCMachineLearningAdapter.AdapterNameEntries.predictionsFile);
        File predictionsFile = taskContext.getFile(predictionFileName, AccessMode.READWRITE);

        // location of the trained model
        File modelFile = taskContext.getFile(MODEL_NAME, StorageService.AccessMode.READONLY);

        callTestCommand(predictionsFile, modelFile, testFile);
    }

    /**
     * Runs {@code svm_hmm_classify} on temporary copies of the test and model files and copies
     * the resulting predictions to {@code predictionsFile}.
     *
     * @param predictionsFile
     *            destination of the predictions
     * @param modelFile
     *            trained model
     * @param testFile
     *            test data
     * @throws Exception
     *             if copying files or running the classifier fails
     */
    public static void callTestCommand(File predictionsFile, File modelFile, File testFile) throws Exception {
        // create tmp files as workaround to long path bug in svm_hmm
        // if java temp dir is not in /tmp but in a long path dir, this won't
        // work
        File tmpTestFile = null;
        File tmpModelFile = null;
        File tmpPredictionsFile = null;
        try {
            tmpTestFile = File.createTempFile("tmp_svm_hmm_test", ".txt");
            FileUtils.copyFile(testFile, tmpTestFile);

            tmpModelFile = File.createTempFile("tmp_svm_hmm", ".model");
            FileUtils.copyFile(modelFile, tmpModelFile);

            tmpPredictionsFile = File.createTempFile("tmp_svm_hmm_predictions", ".txt");

            // command
            List<String> testCommand = buildTestCommand(tmpTestFile, tmpModelFile.getAbsolutePath(),
                    tmpPredictionsFile.getAbsolutePath());

            runCommand(testCommand);

            // copy tmp predictions back to the predictions file
            // Note: the prediction file contains the predicted labels starting from index 1 - the
            // id2outcome.txt requires a start at zero i.e. the predictions here and the one you'll
            // find in the id2outcome report are higher by +1 this is not a bug but rather an ugly
            // necessity.
            FileUtils.copyFile(tmpPredictionsFile, predictionsFile);
        } finally {
            // clean up, also when one of the steps above failed; deleteQuietly tolerates null
            FileUtils.deleteQuietly(tmpTestFile);
            FileUtils.deleteQuietly(tmpPredictionsFile);
            FileUtils.deleteQuietly(tmpModelFile);
        }
    }

    /**
     * Trains the model and stores it into the task context
     *
     * @param taskContext
     *            context
     * @param trainingFile
     *            training data
     * @param paramC
     *            value passed to the learner as {@code -c}
     * @param paramOrderE
     *            value passed to the learner as {@code --e}
     * @param paramOrderT
     *            value passed to the learner as {@code --t}
     * @param paramEpsilon
     *            value passed to the learner as {@code -e}
     * @param paramB
     *            value passed to the learner as {@code --b}
     * @throws Exception
     *             if training or storing the model fails
     */
    public void trainModel(TaskContext taskContext, File trainingFile, double paramC, int paramOrderE,
            int paramOrderT, double paramEpsilon, int paramB) throws Exception {
        File tmpModel = trainModelAtTemporaryLocation(trainingFile, paramC, paramOrderE, paramOrderT, paramEpsilon,
                paramB);

        // the stream is closed before the finally block runs, so the file can be deleted safely
        try (FileInputStream stream = new FileInputStream(tmpModel)) {
            taskContext.storeBinary(MODEL_NAME, stream);
        } finally {
            // clean-up
            FileUtils.deleteQuietly(tmpModel);
        }
    }

    /**
     * Trains the model and copies it to the given output file.
     *
     * @param modelOutputFile
     *            destination of the trained model
     * @param trainingFile
     *            training data
     * @param paramC
     *            value passed to the learner as {@code -c}
     * @param paramOrderE
     *            value passed to the learner as {@code --e}
     * @param paramOrderT
     *            value passed to the learner as {@code --t}
     * @param paramEpsilon
     *            value passed to the learner as {@code -e}
     * @param paramB
     *            value passed to the learner as {@code --b}
     * @throws Exception
     *             if training or copying the model fails
     */
    public void trainModel(File modelOutputFile, File trainingFile, double paramC, int paramOrderE, int paramOrderT,
            double paramEpsilon, int paramB) throws Exception {
        File tmpModel = trainModelAtTemporaryLocation(trainingFile, paramC, paramOrderE, paramOrderT, paramEpsilon,
                paramB);
        try {
            FileUtils.copyFile(tmpModel, modelOutputFile);
        } finally {
            FileUtils.deleteQuietly(tmpModel);
        }
    }

    /**
     * Runs {@code svm_hmm_learn} on a temporary copy of the training file and returns the
     * temporary file the model was written to. The caller is responsible for deleting the
     * returned file; if training fails the temporary model file is removed here.
     */
    private File trainModelAtTemporaryLocation(File trainingFile, double paramC, int paramOrderE, int paramOrderT,
            double paramEpsilon, int paramB) throws Exception {
        File tmpModelFile = File.createTempFile("tmp_svm_hmm", ".model");
        File tmpTrainingFile = null;
        boolean trained = false;
        try {
            // we have to copy the training file to tmp to prevent long path issue
            // in svm_hmm
            tmpTrainingFile = File.createTempFile("tmp_svm_hmm_training", ".txt");
            FileUtils.copyFile(trainingFile, tmpTrainingFile);

            List<String> modelTrainCommand = buildTrainCommand(tmpTrainingFile, tmpModelFile.getPath(), paramC,
                    paramOrderE, paramOrderT, paramEpsilon, paramB);

            LogFactory.getLog(getClass()).info("Use following parametrization: " + toString(modelTrainCommand));

            log.debug("Start training model");
            long time = System.currentTimeMillis();
            runCommand(modelTrainCommand);
            long completedIn = System.currentTimeMillis() - time;
            String formattedDuration = DurationFormatUtils.formatDuration(completedIn, "HH:mm:ss:SS");
            log.info("Training finished after " + formattedDuration);

            trained = true;
            return tmpModelFile;
        } finally {
            FileUtils.deleteQuietly(tmpTrainingFile);
            if (!trained) {
                // nobody will receive the model file, so do not leave it behind
                FileUtils.deleteQuietly(tmpModelFile);
            }
        }
    }

    /**
     * Joins the command elements with single spaces for logging.
     */
    private String toString(List<String> modelTrainCommand) {
        StringBuilder sb = new StringBuilder();
        for (String s : modelTrainCommand) {
            sb.append(s).append(' ');
        }
        return sb.toString().trim();
    }

    /**
     * Builds command line parameters for testing predictions
     *
     * @param testFile
     *            test file
     * @param modelLocation
     *            trained model file
     * @param outputPredictions
     *            where the results will be stored
     * @return command as a list of Strings
     */
    private static List<String> buildTestCommand(File testFile, String modelLocation, String outputPredictions) {
        List<String> result = new ArrayList<>();

        result.add(resolveSVMHmmClassifyCommand());
        result.add(testFile.getAbsolutePath());
        result.add(modelLocation);
        result.add(outputPredictions);

        return result;
    }

    /**
     * Returns absolute path to svm_hmm_classify binary
     *
     * @return binary path
     * @throws RuntimeException
     *             wrapping the {@link IOException} raised while resolving the binary
     */
    public static String resolveSVMHmmClassifyCommand() {
        try {
            return new RuntimeProvider(BINARIES_BASE_LOCATION).getFile("svm_hmm_classify").getAbsolutePath();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns absolute path to svm_hmm_learn binary
     *
     * @return binary path
     * @throws RuntimeException
     *             wrapping the {@link IOException} raised while resolving the binary
     */
    public static String resolveSVMHmmLearnCommand() {
        try {
            return new RuntimeProvider(BINARIES_BASE_LOCATION).getFile("svm_hmm_learn").getAbsolutePath();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the cmd with parameters to run svm_hmm_learn. The command is built from the
     * arguments of this method, not from the fields of the task, so that callers of the public
     * {@code trainModel} methods get exactly the parametrization they asked for.
     *
     * @param trainingFile
     *            training file
     * @param targetModelLocation
     *            where the trained model will be stored
     * @param paramC
     *            value for {@code -c}
     * @param paramOrderE
     *            value for {@code --e}
     * @param paramOrderT
     *            value for {@code --t}
     * @param paramEpsilon
     *            value for {@code -e}
     * @param paramB
     *            value for {@code --b}
     * @return command as list of Strings
     */
    private List<String> buildTrainCommand(File trainingFile, String targetModelLocation, double paramC,
            int paramOrderE, int paramOrderT, double paramEpsilon, int paramB) {
        List<String> result = new ArrayList<>();
        result.add(resolveSVMHmmLearnCommand());

        // svm struct params; Locale.ENGLISH keeps the decimal separator a dot on every platform
        result.add("-c");
        result.add(String.format(Locale.ENGLISH, "%f", paramC));
        result.add("--e");
        result.add(Integer.toString(paramOrderE));
        result.add("--t");
        result.add(Integer.toString(paramOrderT));
        result.add("-e");
        result.add(String.format(Locale.ENGLISH, "%f", paramEpsilon));
        result.add("--b");
        result.add(Integer.toString(paramB));

        // training file
        result.add(trainingFile.getAbsolutePath());

        // output model
        result.add(targetModelLocation);

        return result;
    }

    /**
     * Executes the command (runs a new process outside the JVM and waits for its completion).
     * Unless {@link #PRINT_STD_OUT} is set, stdout and stderr of the process are merged, captured
     * in a temporary file and logged at debug level afterwards. A non-zero exit code is reported
     * as a warning; it does not abort the task.
     *
     * @param command
     *            command as list of Strings
     * @throws Exception
     *             if the process cannot be started, waiting for it is interrupted or its output
     *             cannot be read
     */
    private static void runCommand(List<String> command) throws Exception {
        String commandLine = StringUtils.join(command, " ");
        log.info(commandLine);

        int exitCode;
        if (PRINT_STD_OUT) {
            Process process = new ProcessBuilder().inheritIO().command(command).start();
            exitCode = process.waitFor();
        } else {
            // create temp file for capturing output instead of printing to
            // stdout/stderr
            File tmpOutLog = File.createTempFile("tmp.out.", ".log");
            try {
                ProcessBuilder processBuilder = new ProcessBuilder(command);

                // merge stderr into stdout so that both end up in one file through a single
                // redirect; two independent redirects to the same file can overwrite each other
                processBuilder.redirectErrorStream(true);
                processBuilder.redirectOutput(tmpOutLog);

                // run the process
                Process process = processBuilder.start();
                exitCode = process.waitFor();

                // re-read the output and debug it
                try (BufferedReader br = new BufferedReader(
                        new InputStreamReader(new FileInputStream(tmpOutLog)))) {
                    String line;
                    while ((line = br.readLine()) != null) {
                        log.debug(line);
                    }
                }
            } finally {
                // delete files
                FileUtils.deleteQuietly(tmpOutLog);
            }
        }

        if (exitCode != 0) {
            log.warn("Command exited with code " + exitCode + ": " + commandLine);
        }
    }
}