qa.experiment.SRLDSCombinedModelCrossValidation.java Source code

Java tutorial

Introduction

Here is the source code for qa.experiment.SRLDSCombinedModelCrossValidation.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package qa.experiment;

import Util.ClearParserUtil;
import Util.ProcessFrameUtil;
import Util.StdUtil;
import Util.StringUtil;
import clear.engine.SRLPredict;
import clear.engine.SRLTrain;
import clear.util.FileUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.io.FileUtils;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import qa.ProcessFrame;
import qa.ProcessFrameProcessor;
import qa.srl.SRLEvaluate;
import qa.srl.SRLWrapper;

/**
 *
 * @author samuellouvan
 */
public class SRLDSCombinedModelCrossValidation {

    ProcessFrameProcessor proc;

    private ArrayList<String> blackList;
    @Option(name = "-f", usage = "process file", required = true, metaVar = "REQUIRED")
    private String processTsvFileName;

    @Option(name = "-o", usage = "output directory name", required = true, metaVar = "REQUIRED")
    private String outDirName;

    @Option(name = "-d", usage = "directory where the ds files located", required = true, metaVar = "REQUIRED")
    private String dsDirName;

    @Option(name = "-srl", usage = "SRL type", required = true, metaVar = "OPTIONAL")
    private int srlType;

    @Option(name = "-df", usage = "ds file name", required = false, metaVar = "OPTIONAL")
    private String dsFileName = "ds_all_processes_w_pattern.tsv";

    @Option(name = "-k", usage = "number of fold", required = true, metaVar = "REQUIRED")
    private int fold;

    @Option(name = "-n", usage = "number of processes to test", required = false, metaVar = "OPTIONAL")
    private int nbProcess = 0;

    @Option(name = "-p", usage = "Process to test", required = false, metaVar = "OPTIONAL")
    private String processToTest = "";

    @Option(name = "-pi", usage = "predicate/trigger identification", required = false, metaVar = "OPTIONAL")
    private boolean pi = false;

    boolean limitedProcess = false;
    ArrayList<ProcessFrame> inverseData;
    private ArrayList<ProcessFrame> frameArr;
    private HashMap<String, Integer> processFold;
    ProcessFrameProcessor dsProc;
    private ArrayList<String> processNames;
    ArrayList<String> testFilePath;
    ArrayList<String> trainingModelFilePath;
    /* String[] blackListProcess = {"Salivating", "composted", "decant_decanting", "dripping", "magneticseparation", "loosening", "momentum", "seafloorspreadingtheory", "sedimentation",
     "spear_spearing", "retract", "distillation", "Feelsleepy", "filtering", "revising" "fertilization",
     "freeze_freezing", "germinating_germination", "inferring", "melt_melting", "reusing", "takeinnutrients_takinginnutrients", "sight",
     "upwelling", "write", "work", "vibrates_vibration_vibrations", "warming", "watercycle_thewatercycle", "weather_weathering", "whiten_becomewhiter", "windbreaking"};*/

    String[] blackListProcess = { "Salivating", "composted", "decant_decanting", "dripping", "magneticseparation",
            "loosening", "momentum", "seafloorspreadingtheory", "sedimentation", "spear_spearing", "retract",
            "drop_dropping", "Feelsleepy", "harden", "positivetropism", "Resting", "separated", "revising",
            "sight" };

    public SRLDSCombinedModelCrossValidation() throws FileNotFoundException {
        trainingModelFilePath = new ArrayList<String>();
        testFilePath = new ArrayList<String>();
        processFold = new HashMap<String, Integer>();
        processNames = new ArrayList<String>();
        blackList = new ArrayList<String>();
        frameArr = new ArrayList<ProcessFrame>();

    }

    public void init() throws FileNotFoundException, IOException, ClassNotFoundException {
        proc = new ProcessFrameProcessor(processTsvFileName);
        proc.loadProcessData();
        blackList = new ArrayList(Arrays.asList(blackListProcess));
        if (nbProcess > 0) {
            limitedProcess = true;
            for (int i = 0; i < proc.getProcArr().size() && processNames.size() < nbProcess; i++) {
                String normProcessName = ProcessFrameUtil
                        .normalizeProcessName(proc.getProcArr().get(i).getProcessName());
                if (!processNames.contains(normProcessName)) {
                    processNames.add(normProcessName);
                }
            }
        }
        if (processToTest.isEmpty()) {
            frameArr = proc.getProcArr();
            for (int i = 0; i < frameArr.size(); i++) {
                String normName = ProcessFrameUtil.normalizeProcessName(frameArr.get(i).getProcessName());
                if (!processNames.contains(normName)) {
                    processNames.add(normName);
                }
                processFold.put(normName, 0);
            }
        } else {
            for (int i = 0; i < proc.getProcArr().size(); i++) {
                String normName = ProcessFrameUtil.normalizeProcessName(proc.getProcArr().get(i).getProcessName());
                String[] normNames = normName.split("_");
                if (StringUtil.contains(processToTest, normNames)) {
                    frameArr.add(proc.getProcArr().get(i));
                    if (!processNames.contains(normName)) {
                        processNames.add(normName);
                    }
                    processFold.put(normName, 0);
                }
            }
            if (processFold.size() == 0) {
                System.out.println("Cannot find the process to test!");
                System.exit(0);
            }
        }
        File outDirHandler = new File(outDirName);
        if (outDirHandler.exists()) {
            //FileUtils.cleanDirectory(outDirHandler);
        } else {
            boolean success = outDirHandler.mkdir();
            if (!success) {
                System.out.println("FAILED to create output directory");
                System.exit(0);
            }
        }
    }

    public void doTrain(String trainingFileName, String modelFileName) throws IOException, FileNotFoundException,
            NoSuchMethodException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        new SRLWrapper().doTrain(trainingFileName, modelFileName, srlType, false);
    }

    public void doPredict() throws NoSuchMethodException, IllegalAccessException, IllegalArgumentException,
            InvocationTargetException {
        for (int i = 0; i < testFilePath.size(); i++) {
            new SRLWrapper().doPredict(testFilePath.get(i),
                    testFilePath.get(i).replace("test.", "dscombined.predict."), trainingModelFilePath.get(i),
                    srlType, pi, false);
        }
    }

    public void trainAndPredict()
            throws FileNotFoundException, IOException, InterruptedException, ClassNotFoundException,
            NoSuchMethodException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {

        testFilePath.clear();
        trainingModelFilePath.clear();
        dsProc = new ProcessFrameProcessor(dsDirName + "/" + dsFileName);
        dsProc.loadProcessData();
        int cnt = 0;
        for (String processName : processNames) {

            if (!blackList.contains(processName)) {
                ArrayList<ProcessFrame> processData = proc.getProcessFrameByNormalizedName(processName);
                Collections.shuffle(processData);
                inverseData = proc.getInverseProcessFrameByNormalizedName(processName);
                if (processData.size() < 5) // Special case
                {
                    System.out.println("Less than 5 " + processName + " size " + processData.size());
                    doCrossValidation(processName, processData, processData.size());
                } else {
                    doCrossValidation(processName, processData, fold);
                }
                cnt++;
            }
        }
        //doPredict();
    }

    public void doCrossValidation(String processName, ArrayList<ProcessFrame> selectedProcessFrame, int foldSize)
            throws IOException, InterruptedException, FileNotFoundException, NoSuchMethodException,
            IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        int startIdx = 0;
        int testSize = selectedProcessFrame.size() / foldSize;
        int endIdx = testSize;
        for (int currentFold = 0; currentFold < foldSize; currentFold++) {
            ArrayList<ProcessFrame> testingFrames = new ArrayList<ProcessFrame>(
                    selectedProcessFrame.subList(startIdx, endIdx));
            ArrayList<ProcessFrame> trainingFrames = new ArrayList<ProcessFrame>(
                    selectedProcessFrame.subList(0, startIdx));
            trainingFrames.addAll(
                    new ArrayList<ProcessFrame>(selectedProcessFrame.subList(endIdx, selectedProcessFrame.size())));
            trainingFrames.addAll(inverseData);
            trainingFrames.addAll(dsProc.getProcArr());

            String trainingFileName = outDirName + "/" + processName + ".train.dscombined.cv." + currentFold;
            String testingFileName = outDirName + "/" + processName + ".test.cv." + currentFold;
            String modelName = outDirName + "/" + processName + ".dscombinedmodel.cv." + currentFold;

            testFilePath.add(testingFileName);
            trainingModelFilePath.add(modelName);
            System.out.println(testingFileName);
            boolean train = false;
            if (!(new File(trainingFileName).exists())) {
                ProcessFrameUtil.toParserFormat(trainingFrames, trainingFileName, srlType);
                train = true;
            }
            if (!(new File(testingFileName).exists())) {
                ProcessFrameUtil.toParserFormat(testingFrames, testingFileName, srlType);
            }

            if (train) {
                doTrain(trainingFileName, modelName);
            }
            startIdx = endIdx;
            if (currentFold == foldSize - 2) {
                endIdx = selectedProcessFrame.size();
            } else {
                endIdx = startIdx + testSize;
            }
        }
    }

    /**
     * Compute the precision, recall, F1 of the predictions by executing
     * combine.py and evaluate.py
     */
    public void evaluate() throws FileNotFoundException, IOException {
        new SRLEvaluate().evaluateOverall(testFilePath, "test.", "dscombined.predict.", srlType);
    }

    public static void main(String[] args) throws FileNotFoundException {
        SRLDSCombinedModelCrossValidation srlExp = new SRLDSCombinedModelCrossValidation();
        CmdLineParser cmd = new CmdLineParser(srlExp);

        try {
            cmd.parseArgument(args);
            srlExp.init();
            srlExp.trainAndPredict();
            Thread.sleep(5000);
            srlExp.evaluate();
            System.out.println("FINISH");
        } catch (CmdLineException e) {
            System.err.println(e.getMessage());
            cmd.printUsage(System.err);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}