sbu.srl.rolextract.SBURolePredict.java Source code

Java tutorial

Introduction

Here is the source code for sbu.srl.rolextract.SBURolePredict.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package sbu.srl.rolextract;

import Util.ArgProcessAnnotationDataUtil;
import Util.Constant;
import Util.GlobalV;
import Util.LibSVMUtil;
import Util.ProcessFrameUtil;
import Util.SentenceUtil;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import edu.uw.easysrl.main.Argument;
import edu.uw.easysrl.main.ParseResult;
import edu.uw.easysrl.main.Predicate;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static java.util.stream.Collectors.toList;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import org.apache.commons.lang3.StringUtils;
import qa.ProcessFrame;
import qa.ProcessFrameProcessor;
import qa.StanfordDepParserSingleton;
import qa.StanfordTokenizerSingleton;
import qa.dep.DependencyNode;
import qa.dep.DependencyTree;
import qa.srl.SRLWrapper;
import qa.util.FileUtil;
import sbu.srl.datastructure.ArgProcessAnnotationData;
import sbu.srl.datastructure.ArgumentSpan;
import sbu.srl.datastructure.Sentence;
import sbu.srl.ml.LibLinearWrapper;

/**
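 * Predicts role labels for the argument spans of sentences, either with
 * liblinear models (one model per role, or a single multiclass model) or by
 * mapping EasySRL output onto the role labels.
 *
 * INPUT : Process Frame (already preprocessed)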
 *
 * @author slouvan
 */
public class SBURolePredict {

    // Reader that supplies the sentences and the set of role labels.
    SpockDataReader dataReader;
    // Role labels taken from dataReader.getRoleLabels(); the four-argument
    // constructor additionally adds "NONE" when running in multiclass mode.
    Set<String> classLabels = new HashSet<String>();
    // Feature extractors keyed by role name, or by "Multi" in multiclass mode.
    HashMap<String, FeatureExtractor> fExtractors;
    // liblinear models, keyed the same way as fExtractors.
    HashMap<String, Model> models;
    // Not assigned or read within this class.
    String[] annotations;
    // Set by the four-argument constructor; only referenced from commented-out code in this class.
    String configFileName;
    // Set by the four-argument constructor; only referenced from commented-out code in this class.
    String predictionFileName;
    // true: one model stored under "Multi" scores all roles; false: one model per role.
    private boolean isMulticlass = false;
    // When true, performPrediction(String) only keeps spans that have an annotation.
    boolean knownAnnotation = true;

    /**
     * Creates a predictor whose sentences and role labels are read from a test file.
     * In multiclass mode the label "NONE" is added to the role labels.
     *
     * @param testFileName       file handed to {@link SpockDataReader}
     * @param configFileName     configuration file handed to {@link SpockDataReader}
     * @param modelDir           directory holding {@code <role>.featureExtract} and
     *                           {@code <role>.model} files
     * @param predictionFileName stored in {@link #predictionFileName}
     * @param isMultiClass       {@code true} to load the single model stored under "Multi",
     *                           {@code false} to load one model per role label
     * @throws IOException            if the data or a model cannot be read
     * @throws FileNotFoundException  if a model file cannot be opened
     * @throws ClassNotFoundException if a feature extractor cannot be deserialized
     */
    public SBURolePredict(String testFileName, String configFileName, String modelDir, String predictionFileName,
            boolean isMultiClass) throws IOException, FileNotFoundException, ClassNotFoundException {
        dataReader = new SpockDataReader(testFileName, configFileName, false);
        dataReader.readProcessData();
        this.configFileName = configFileName;
        this.predictionFileName = predictionFileName;
        this.isMulticlass = isMultiClass;
        classLabels = dataReader.getRoleLabels();
        fExtractors = new HashMap<String, FeatureExtractor>();
        models = new HashMap<String, Model>();

        if (isMultiClass) {
            classLabels.add("NONE");
        }
        loadModels(modelDir, isMultiClass);
    }

    /**
     * Creates a predictor whose sentences come from a serialized sentence list.
     * Unlike the four-argument constructor, this one does not add "NONE" to the role labels.
     *
     * @param modelDir     directory holding {@code <role>.featureExtract} and
     *                     {@code <role>.model} files
     * @param serFile      file containing a serialized {@code ArrayList<Sentence>}
     * @param isMultiClass {@code true} to load the single model stored under "Multi",
     *                     {@code false} to load one model per role label
     * @throws IOException            if the sentences or a model cannot be read
     * @throws FileNotFoundException  if a model file cannot be opened
     * @throws ClassNotFoundException if the sentences or a feature extractor cannot be deserialized
     */
    public SBURolePredict(String modelDir, String serFile, boolean isMultiClass)
            throws IOException, FileNotFoundException, ClassNotFoundException {
        @SuppressWarnings("unchecked") // the file is expected to contain an ArrayList<Sentence>
        ArrayList<Sentence> sentences = (ArrayList<Sentence>) FileUtil.deserializeFromFile(serFile);
        dataReader = new SpockDataReader();
        dataReader.sentences = sentences;
        classLabels = dataReader.getRoleLabels();
        fExtractors = new HashMap<String, FeatureExtractor>();
        models = new HashMap<String, Model>();
        this.isMulticlass = isMultiClass;
        loadModels(modelDir, isMultiClass);
    }

    /** Loads either the single "Multi" model or one model per entry of {@link #classLabels}. */
    private void loadModels(String modelDir, boolean multiClass) throws IOException, ClassNotFoundException {
        if (multiClass) {
            loadModelFor("Multi", modelDir);
            return;
        }
        for (String roleName : classLabels) {
            loadModelFor(roleName, modelDir);
        }
    }

    /**
     * Loads the feature extractor and model stored under {@code roleName}; does nothing when
     * the {@code .featureExtract} file does not exist.
     */
    private void loadModelFor(String roleName, String modelDir) throws IOException, ClassNotFoundException {
        String basePath = modelDir + "/" + roleName;
        if (!FileUtil.isFileExist(basePath + ".featureExtract")) {
            return;
        }
        fExtractors.put(roleName, (FeatureExtractor) FileUtil.deserializeFromFile(basePath + ".featureExtract"));
        // try-with-resources guarantees the reader is closed even if loading the model fails.
        try (FileReader modelReader = new FileReader(basePath + ".model")) {
            models.put(roleName, Linear.loadModel(modelReader));
        }
    }

    /**
     * Tells whether two texts have at least one token in common after both are tokenized
     * with {@link StanfordTokenizerSingleton}.
     *
     * @param text      first text
     * @param otherText second text
     * @return {@code true} if any token of {@code text} also occurs in {@code otherText}
     */
    public static boolean isOverlapping(String text, String otherText) {
        final List<String> leftTokens = StanfordTokenizerSingleton.getInstance().tokenize(text);
        final List<String> rightTokens = StanfordTokenizerSingleton.getInstance().tokenize(otherText);
        return leftTokens.stream().anyMatch(token -> rightTokens.contains(token));
    }

    /**
     * Picks the predicate or argument of an EasySRL parse that best matches a target text.
     * <p>
     * Only spans sharing at least one token with {@code targetText} (see
     * {@link #isOverlapping(String, String)}) are candidates. Among them, the spans with the
     * smallest Levenshtein distance to {@code targetText} are kept; ties are broken by the
     * highest score, and on equal scores the span met first wins (predicates are visited
     * before arguments).
     *
     * @param parseResult parse whose predicates and their arguments are searched
     * @param targetText  text of the span to match
     * @return the best matching {@link Predicate} or {@link Argument}, or {@code null} when
     *         no predicate or argument overlaps {@code targetText}
     */
    public static Object getBestArgument(ParseResult parseResult, String targetText) {
        List<Predicate> predicates = parseResult.getPredicates();
        List<Argument> arguments = new ArrayList<Argument>();
        for (Predicate predicate : predicates) {
            arguments.addAll(predicate.getArguments());
        }

        // Single pass: keep only the overlapping spans that are closest to the target text.
        List<Object> closestSpans = new ArrayList<>();
        int minimumDistance = Integer.MAX_VALUE;
        for (Predicate predicate : predicates) {
            minimumDistance = keepIfClosest(predicate, predicate.getText(), targetText, minimumDistance,
                    closestSpans);
        }
        for (Argument argument : arguments) {
            minimumDistance = keepIfClosest(argument, argument.getText(), targetText, minimumDistance,
                    closestSpans);
        }
        if (closestSpans.isEmpty()) {
            return null; // NONE
        }

        // Seed with the first candidate instead of Double.MIN_VALUE (the smallest POSITIVE
        // double), so that zero or negative scores still yield a winner.
        Object bestSpan = closestSpans.get(0);
        double maxScore = scoreOf(bestSpan);
        for (int i = 1; i < closestSpans.size(); i++) {
            Object candidate = closestSpans.get(i);
            double candidateScore = scoreOf(candidate);
            if (candidateScore > maxScore) {
                maxScore = candidateScore;
                bestSpan = candidate;
            }
        }
        return bestSpan;
    }

    /**
     * Adds {@code span} to {@code closestSpans} when it overlaps the target text and is at least
     * as close as the best distance so far; a strictly closer span replaces the previous ones.
     *
     * @return the (possibly improved) best distance
     */
    private static int keepIfClosest(Object span, String spanText, String targetText, int bestDistance,
            List<Object> closestSpans) {
        if (!isOverlapping(targetText, spanText)) {
            return bestDistance;
        }
        int distance = StringUtils.getLevenshteinDistance(targetText, spanText);
        if (distance < bestDistance) {
            closestSpans.clear();
            closestSpans.add(span);
            return distance;
        }
        if (distance == bestDistance) {
            closestSpans.add(span);
        }
        return bestDistance;
    }

    /** Returns the score of a span that is either a {@link Predicate} or an {@link Argument}. */
    private static double scoreOf(Object span) {
        if (span instanceof Predicate) {
            return ((Predicate) span).getScore();
        }
        return ((Argument) span).getArgScore();
    }

    /**
     * Labels the annotated argument spans of a serialized sentence list using EasySRL output.
     * <p>
     * The sentences are deserialized from {@code testObjFile}, EasySRL is run through
     * {@link SRLWrapper} with its output written to {@code foldDir + "/test/easySrlOut"}, and
     * every annotated span is given the role of the best matching EasySRL predicate/argument
     * (see {@link #getBestArgument(ParseResult, String)}): predicates become "trigger",
     * ARG0/CAU "enabler", ARG1 "undergoer", ARG2/PNC "result", everything else "NONE".
     * Spans of sentences whose parse score is -1.0, and spans without any overlapping
     * predicate/argument, get "NONE" with probability 1.0. Sentences without annotated spans
     * are left untouched. The parse result at index {@code i} is used for sentence {@code i}.
     * <p>
     * The updated sentences are serialized to {@code testObjFile} with "gold" replaced by
     * "easysrlpredict", or otherwise with "test." replaced by "easysrlpredict.".
     *
     * @param testObjFile          file containing a serialized {@code ArrayList<Sentence>}
     * @param testSentenceListFile passed through to {@code SRLWrapper.doPredictProcessRoleCCG}
     * @param outputFileName       passed through to {@code SRLWrapper.doPredictProcessRoleCCG}
     * @param modelFileName        passed through to {@code SRLWrapper.doPredictProcessRoleCCG}
     * @param foldDir              directory under which {@code /test/easySrlOut} is written and read
     */
    public static void performPredictionEasySRL(String testObjFile, String testSentenceListFile,
            String outputFileName, String modelFileName, String foldDir)
            throws IOException, FileNotFoundException, ClassNotFoundException, NoSuchMethodException,
            IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        @SuppressWarnings("unchecked") // the file is expected to contain an ArrayList<Sentence>
        ArrayList<Sentence> sentences = (ArrayList<Sentence>) FileUtil.deserializeFromFile(testObjFile);
        final String easySrlOutput = foldDir.concat("/test/easySrlOut");
        new SRLWrapper().doPredictProcessRoleCCG(testSentenceListFile, outputFileName, modelFileName,
                easySrlOutput, Constant.SRL_CCG, true, false);

        // Read lazily and only once: the EasySRL output does not change between sentences.
        List<ParseResult> parseResults = null;
        for (int i = 0; i < sentences.size(); i++) {
            System.out.println("SENTENCE : " + i + sentences.get(i).getRawText());
            final Sentence currentSentence = sentences.get(i);
            ArrayList<ArgumentSpan> annotatedSpans = currentSentence.getAllAnnotatedArgumentSpan();
            final HashMap<String, String> annotatedSpanKeys = currentSentence
                    .getAllArgumentsThatHaveAnnotation();
            // Collect into a List: Collectors.toList() does not promise to return an ArrayList.
            List<ArgumentSpan> spans = annotatedSpans.stream()
                    .distinct()
                    .filter(s -> annotatedSpanKeys
                            .get(currentSentence.getId() + "_" + s.getStartIdx() + "_" + s.getEndIdx()) != null)
                    .collect(toList());
            if (spans.isEmpty()) {
                continue;
            }

            if (parseResults == null) {
                parseResults = SentenceUtil.readEasySRLJSONdata(easySrlOutput);
            }
            ParseResult sentParseResult = parseResults.get(i);
            HashMap<String, ArrayList<ArgumentSpan>> roleArgPrediction = new HashMap<String, ArrayList<ArgumentSpan>>();
            for (ArgumentSpan currentSpan : spans) {
                // Default outcome: no usable parse or no overlapping predicate/argument.
                String rolePredicted = "NONE";
                Double score = 1.0;
                if (sentParseResult.getParseScore() != -1.0) {
                    Object bestOverlap = getBestArgument(sentParseResult, currentSpan.getText());
                    if (bestOverlap instanceof Predicate) {
                        rolePredicted = "trigger";
                        score = ((Predicate) bestOverlap).getScore();
                    } else if (bestOverlap instanceof Argument) {
                        Argument argument = (Argument) bestOverlap;
                        System.out.println(argument.getLabel());
                        rolePredicted = mapEasySrlLabel(argument.getLabel());
                        score = argument.getArgScore();
                    }
                }

                HashMap<String, Double> roleProbPair = new HashMap<String, Double>();
                roleProbPair.put(rolePredicted, score);
                currentSpan.predictRoleType(true);
                currentSpan.setRolePredicted(rolePredicted);
                currentSpan.setRoleProbPair(roleProbPair);

                roleArgPrediction
                        .computeIfAbsent(currentSpan.getRolePredicted(), role -> new ArrayList<ArgumentSpan>())
                        .add(currentSpan);
            }
            currentSentence.setRoleArgPrediction(roleArgPrediction);
        }

        // populateProbabilityILP(procDataAnnArr);
        // make unique of the SAME arguments,  based on startID and endID
        if (testObjFile.contains("gold")) {
            FileUtil.serializeToFile(sentences, testObjFile.replace("gold", "easysrlpredict"));
        } else {
            FileUtil.serializeToFile(sentences, testObjFile.replace("test.", "easysrlpredict."));
        }
    }

    /**
     * Maps an EasySRL argument label onto a role name. A {@code null} label is treated like
     * any other unknown label and yields "NONE" (instead of an empty role name).
     */
    private static String mapEasySrlLabel(String label) {
        if (label == null) {
            return "NONE";
        }
        if (label.equalsIgnoreCase("ARG0") || label.equalsIgnoreCase("CAU")) {
            return "enabler";
        }
        if (label.equalsIgnoreCase("ARG1")) {
            return "undergoer";
        }
        if (label.equalsIgnoreCase("ARG2") || label.equalsIgnoreCase("PNC")) {
            return "result";
        }
        return "NONE";
    }

    /**
     * Predicts a role for the argument spans of every sentence in a serialized sentence list
     * and writes the updated list back to disk.
     * <p>
     * When {@link #knownAnnotation} is set, only spans that have an annotation are scored;
     * otherwise all distinct annotated spans are scored and sentences without spans are skipped.
     * In multiclass mode the model stored under "Multi" yields one probability per label of its
     * feature extractor; otherwise every role with a loaded feature extractor is scored with its
     * own model and the scores are normalized on the span.
     * <p>
     * The result is serialized to {@code testingFileName} with "gold" replaced by "predict",
     * or otherwise with "test." replaced by "predict.".
     *
     * @param testingFileName file containing a serialized {@code ArrayList<Sentence>}
     * @throws IOException            if reading or writing the sentence list fails
     * @throws FileNotFoundException  if the sentence list cannot be opened
     * @throws ClassNotFoundException if the sentence list cannot be deserialized
     * @throws IllegalStateException  in multiclass mode, if no model or feature extractor was
     *                                loaded under "Multi"
     */
    public void performPrediction(String testingFileName)
            throws IOException, FileNotFoundException, ClassNotFoundException {
        @SuppressWarnings("unchecked") // the file is expected to contain an ArrayList<Sentence>
        ArrayList<Sentence> sentences = (ArrayList<Sentence>) FileUtil.deserializeFromFile(testingFileName);
        for (int i = 0; i < sentences.size(); i++) {
            final Sentence currentSentence = sentences.get(i);
            System.out.println("[Prediction] " + i + " / " + sentences.size());
            ArrayList<ArgumentSpan> annotatedSpans = currentSentence.getAllAnnotatedArgumentSpan();
            final HashMap<String, String> annotatedSpanKeys = currentSentence
                    .getAllArgumentsThatHaveAnnotation();
            // Collect into a List: Collectors.toList() does not promise to return an ArrayList.
            List<ArgumentSpan> spans = annotatedSpans.stream().distinct().collect(toList());
            if (knownAnnotation) {
                spans = spans.stream()
                        .filter(s -> annotatedSpanKeys
                                .get(currentSentence.getId() + "_" + s.getStartIdx() + "_" + s.getEndIdx()) != null)
                        .collect(toList());
            }
            HashMap<String, ArrayList<ArgumentSpan>> roleArgPrediction = new HashMap<String, ArrayList<ArgumentSpan>>();
            if (!knownAnnotation && spans.isEmpty()) {
                continue;
            }
            System.out.println(currentSentence.getRawText());

            // Parse the sentence once instead of once per span and role label.
            // NOTE(review): assumes parse() is deterministic and that feature extraction does
            // not modify the returned tree - confirm.
            DependencyTree depTree = null;
            for (ArgumentSpan currentSpan : spans) {
                if (depTree == null) {
                    depTree = StanfordDepParserSingleton.getInstance().parse(currentSentence.getRawText());
                }
                if (isMulticlass) {
                    predictMulticlass(currentSentence, currentSpan, depTree);
                } else {
                    predictPerRole(currentSentence, currentSpan, depTree);
                }

                // Group the span under the role that was predicted for it.
                roleArgPrediction
                        .computeIfAbsent(currentSpan.getRolePredicted(), role -> new ArrayList<ArgumentSpan>())
                        .add(currentSpan);
            }
            currentSentence.setRoleArgPrediction(roleArgPrediction);
        }

        // populateProbabilityILP(procDataAnnArr);
        // make unique of the SAME arguments,  based on startID and endID
        if (testingFileName.contains("gold")) {
            FileUtil.serializeToFile(sentences, testingFileName.replace("gold", "predict"));
        } else {
            FileUtil.serializeToFile(sentences, testingFileName.replace("test.", "predict."));
        }
    }

    /** Scores one span with the single multiclass model and stores the per-label probabilities on it. */
    private void predictMulticlass(Sentence sentence, ArgumentSpan span, DependencyTree depTree)
            throws IOException, ClassNotFoundException {
        final String roleLabel = "Multi";
        FeatureExtractor fExtractor = fExtractors.get(roleLabel);
        Model model = models.get(roleLabel);
        if (fExtractor == null || model == null) {
            throw new IllegalStateException(
                    "No feature extractor/model loaded under \"" + roleLabel + "\"; check the model directory");
        }

        // Check whether there is a common ancestor located within the span's token indices.
        DependencyNode headNode = depTree.getHeadNode(span.getRoleIdx());
        String rawVector = fExtractor.extractFeatureVectorValue(headNode.getId(), sentence, span, false,
                isMulticlass);
        FeatureNode[] x = LibLinearWrapper.toFeatureNode(rawVector, model);
        // Size by the model's class count: that is how many entries predictProbability fills.
        double[] probs = new double[model.getNrClass()];
        Linear.predictProbability(model, x, probs);
        int[] labels = model.getLabels();

        HashMap<String, Double> roleProbPair = new HashMap<String, Double>();
        for (String label : fExtractor.multiClassLabel.keySet()) {
            int labelID = fExtractor.multiClassLabel.get(label);
            // A label the model does not know gets probability 0 instead of indexing probs[-1].
            double probability = 0.0;
            for (int k = 0; k < labels.length; k++) {
                if (labels[k] == labelID) {
                    probability = probs[k];
                    break;
                }
            }
            roleProbPair.put(label, probability);
        }

        HashMap<String, String> roleVectorPair = new HashMap<String, String>();
        roleVectorPair.put(roleLabel, rawVector);
        span.setRoleProbPair(roleProbPair);
        span.setRoleFeatureVector(roleVectorPair);
        span.predictRoleType(true);
    }

    /**
     * Scores one span with every per-role model that has a loaded feature extractor, stores the
     * probability of label 1 for each role on the span and lets the span normalize the scores.
     */
    private void predictPerRole(Sentence sentence, ArgumentSpan span, DependencyTree depTree)
            throws IOException, ClassNotFoundException {
        HashMap<String, Double> roleProbPair = new HashMap<String, Double>();
        HashMap<String, String> roleVectorPair = new HashMap<String, String>();
        DependencyNode headNode = null;
        for (String roleLabel : classLabels) {
            FeatureExtractor fExtractor = fExtractors.get(roleLabel);
            if (fExtractor == null) {
                continue;
            }
            if (headNode == null) {
                // Check whether there is a common ancestor located within the span's token indices.
                headNode = depTree.getHeadNode(span.getRoleIdx());
            }
            Model model = models.get(roleLabel);
            String rawVector = fExtractor.extractFeatureVectorValue(headNode.getId(), sentence, span, false,
                    isMulticlass);
            FeatureNode[] x = LibLinearWrapper.toFeatureNode(rawVector, model);
            double[] probs = new double[2];
            Linear.predictProbability(model, x, probs);
            // probs follows the order of model.getLabels(); pick the slot of label 1.
            int positiveIdx = model.getLabels()[0] == 1 ? 0 : 1;
            roleProbPair.put(roleLabel, probs[positiveIdx]);
            roleVectorPair.put(roleLabel, rawVector);
        }
        span.setRoleProbPair(roleProbPair);
        span.setRoleFeatureVector(roleVectorPair);
        span.normalizeProbScore();
        span.predictRoleType(false);
    }

    /**
     * Currently a no-op: the whole body is commented out. The commented code is an earlier
     * implementation that worked on {@code ArgProcessAnnotationData} and is kept for reference.
     *
     * @throws IOException declared by the signature; never thrown by the current (empty) body
     */
    public void performPrediction() throws IOException {
        /*ArrayList<ArgProcessAnnotationData> procDataAnnArr = null; //;argAnnotationReader.getProcDataArr();
         for (int i = 0; i < procDataAnnArr.size(); i++) {
         ArgProcessAnnotationData currentProcData = procDataAnnArr.get(i);
         for (String roleName : classLabels) {
         if (fExtractors.get(roleName) != null) {
         FeatureExtractor fExtractor = fExtractors.get(roleName);
         ArrayList<Integer> tokenIdx = currentProcData.getRoleIdx(roleName);
         if (tokenIdx.size() == 0) {
         continue;
         }
         ArrayList<Integer> incorrectIdx = new ArrayList<Integer>();
         DependencyTree depTree = StanfordDepParserSingleton.getInstance().parse(currentProcData.getRawText());
         // Check whether there is a common ancestor located within tokenIdx
         DependencyNode headNode = depTree.getHeadNode(tokenIdx);
         //for (int k = 0; k < tokenIdx.size(); k++) {
         String rawVector = fExtractor.extractFeatureVectorValue(headNode.getId(), currentProcData, roleName, false);// IMPLEMENT THIS
         //liblinear.Linear.predictProbability(;, x, prob_estimates)
         FeatureNode[] x = LibLinearWrapper.toFeatureNode(rawVector, models.get(roleName));
         int prediction = liblinear.Linear.predict(models.get(roleName), x);
         double probs[] = new double[2];
         liblinear.Linear.predictProbability(models.get(roleName), x, probs);
         System.out.println("PREDICTION " + prediction);
         if (prediction == 1) {
         System.out.println("CORRECT");
         } else if (prediction == -1) {
         System.out.println("NOT CORRECT");
         //incorrectIdx.add(tokenIdx.get(k));
         currentProcData.clearRoleFiller(roleName);
         }
            
         } else {
         // clear role fillers
         currentProcData.clearRoleFiller(roleName);
         }
         }
         }
         // dumpFrameToFile
         ArgProcessAnnotationDataUtil.dumpRolePredictionToFile(procDataAnnArr, this.configFileName, this.predictionFileName);
         //ProcessFrameUtil.dumpFramesToFile(frames, GlobalV.PROJECT_DIR + "/data/predicted.tsv");*/
    }

    /**
     * Entry point that currently does nothing: the body only contains commented-out examples
     * of constructing a predictor and calling {@code performPrediction()}.
     *
     * @param args command-line arguments; not read
     */
    public static void main(String[] args) throws IOException, FileNotFoundException, ClassNotFoundException {
        /*SBURolePredict predictor = new SBURolePredict(GlobalV.PROJECT_DIR + "/data/undergoer_20testing.frame.tsv",
         GlobalV.PROJECT_DIR + "/data/undergoer_20testing.cleaned.tsv",
         GlobalV.PROJECT_DIR + "/data/modelDebug");*/
        /*SBURolePredict predictor = new SBURolePredict(GlobalV.PROJECT_DIR + "/data/training_w_pattern.tsv",
         GlobalV.PROJECT_DIR + "/configFrameFile/config.txt",
         GlobalV.PROJECT_DIR + "/data/model-03-11-2015-full",
         GlobalV.PROJECT_DIR + "/data/out/predicted.tsv");
         predictor.performPrediction();*/
    }
}