Util.SentenceUtil.java Source code

Java tutorial

Introduction

Here is the source code for Util.SentenceUtil.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package Util;

import static java.util.stream.Collectors.toList;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.reflect.TypeToken;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.stream.JsonReader;
import edu.cmu.cs.lti.ark.fn.parsing.SemaforParseResult;
import edu.cmu.cs.lti.ark.fn.parsing.SemaforParseResult.Frame;
import edu.cmu.cs.lti.ark.fn.parsing.SemaforParseResult.Frame.ScoredRoleAssignment;
import edu.cmu.cs.lti.ark.fn.parsing.SemaforParseResult.Frame.Span;
import edu.cmu.cs.lti.ark.util.ds.Pair;
import edu.uw.easysrl.main.ParseResult;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import qa.util.FileUtil;
import sbu.srl.datastructure.ArgumentSpan;
import sbu.srl.datastructure.ArgumentSpanDeserializer;
import sbu.srl.datastructure.ArgumentSpanSerializer;
import sbu.srl.datastructure.ILPSRLDataDeserializer;
import sbu.srl.datastructure.ILPSRLDataSerializer;
import sbu.srl.datastructure.JSONData;
import sbu.srl.datastructure.Sentence;
import sbu.srl.datastructure.SentenceDeserializer;
import sbu.srl.datastructure.SentenceSerializer;
import scala.actors.threadpool.Arrays;

/**
 * Static helpers for the SRL JSON pipeline: grouping sentences into {@link JSONData}, reading and
 * writing that JSON format with the project's Gson adapters, and rewriting SRL predictions with the
 * role labels produced by SEMAFOR (and, once implemented, EasySRL).
 *
 * @author slouvan
 */
public class SentenceUtil {

    /** Role label used when an external parser offers no label for a span. */
    private static final String NONE_LABEL = "NONE";

    /**
     * Number of sentences for which a transform method fell back to {@code "NONE"} labels (no
     * matching SEMAFOR frame, or an EasySRL parse score of -1.0). Printed on every increment.
     */
    public static int counter = 0;

    /**
     * Wraps each process and its sentences into a {@link JSONData} item.
     *
     * @param mapByProcess sentences grouped by process name
     * @return one {@link JSONData} per map entry, in the map's iteration order
     */
    public static ArrayList<JSONData> generateJSONData(Map<String, List<Sentence>> mapByProcess) {
        ArrayList<JSONData> ilpDataArr = new ArrayList<>(mapByProcess.size());
        for (Map.Entry<String, List<Sentence>> entry : mapByProcess.entrySet()) {
            JSONData ilpDataItem = new JSONData();
            ilpDataItem.setProcessName(entry.getKey());
            List<Sentence> sentences = entry.getValue();
            // JSONData stores an ArrayList: reuse the caller's list when it already is one (as before),
            // otherwise copy instead of failing with a ClassCastException.
            ArrayList<Sentence> sentenceList = (sentences == null || sentences instanceof ArrayList)
                    ? (ArrayList<Sentence>) sentences
                    : new ArrayList<Sentence>(sentences);
            ilpDataItem.setSentence(sentenceList);
            ilpDataArr.add(ilpDataItem);
        }
        return ilpDataArr;
    }

    /**
     * Serializes {@code predictionData} as pretty-printed JSON (UTF-8) into {@code fileName}.
     *
     * @param predictionData data to write
     * @param fileName output path; overwritten if it exists
     * @param predict forwarded to {@link SentenceSerializer}
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws UncheckedIOException if writing the JSON fails
     */
    public static void flushDataToJSON(ArrayList<JSONData> predictionData, String fileName, boolean predict)
            throws FileNotFoundException {
        final Gson gson = new GsonBuilder()
                .registerTypeAdapter(JSONData.class, new ILPSRLDataSerializer())
                .registerTypeAdapter(Sentence.class, new SentenceSerializer(predict))
                .registerTypeAdapter(ArgumentSpan.class, new ArgumentSpanSerializer())
                .setPrettyPrinting()
                .create();
        Type listType = new com.google.gson.reflect.TypeToken<ArrayList<JSONData>>() {
        }.getType();
        String jsonString = gson.toJson(predictionData, listType);
        try (PrintWriter writer = new PrintWriter(
                new OutputStreamWriter(new FileOutputStream(fileName), StandardCharsets.UTF_8))) {
            writer.println(jsonString);
            // PrintWriter swallows IOExceptions; surface them instead of silently writing a truncated file.
            if (writer.checkError()) {
                throw new UncheckedIOException(new IOException("Failed to write JSON to " + fileName));
            }
        }
    }

    /**
     * Reads a JSON array of {@link JSONData} (UTF-8) written in the format of {@link #flushDataToJSON}.
     *
     * @param fileName input path
     * @param isPrediction forwarded to the sentence and argument-span deserializers
     * @return the parsed items (may be {@code null} if the file holds no JSON value)
     * @throws IOException if the file cannot be opened or read
     */
    public static ArrayList<JSONData> readJSONData(String fileName, boolean isPrediction)
            throws FileNotFoundException, IOException {
        final Gson gson = new GsonBuilder()
                .registerTypeAdapter(JSONData.class, new ILPSRLDataDeserializer())
                .registerTypeAdapter(Sentence.class, new SentenceDeserializer(isPrediction))
                .registerTypeAdapter(ArgumentSpan.class, new ArgumentSpanDeserializer(isPrediction))
                .create();
        Type listType = new com.google.gson.reflect.TypeToken<ArrayList<JSONData>>() {
        }.getType();
        // try-with-resources: the reader is closed even when parsing throws.
        try (Reader reader = new InputStreamReader(new FileInputStream(fileName), StandardCharsets.UTF_8)) {
            return gson.fromJson(reader, listType);
        }
    }

    /**
     * Reads SEMAFOR output that is expected to hold one JSON object per line.
     *
     * @param fileName SEMAFOR output file
     * @return one {@link SemaforParseResult} per non-blank line, in file order (fixed-size list)
     * @throws IOException if the file cannot be read or parsed
     */
    public static List<SemaforParseResult> readSemaforJSONdata(String fileName)
            throws FileNotFoundException, IOException {
        ObjectMapper mapper = new ObjectMapper();
        // Blank lines would otherwise become empty array elements and break parsing.
        List<String> records = Files.readAllLines(Paths.get(fileName)).stream()
                .filter(line -> !line.trim().isEmpty())
                .collect(toList());
        String jsonString = "[" + String.join(",", records) + "]";
        SemaforParseResult[] parse = mapper.readValue(jsonString, SemaforParseResult[].class);
        return java.util.Arrays.asList(parse);
    }

    /**
     * Reads EasySRL output whose content is a single JSON array of {@link ParseResult}.
     *
     * @param fileName EasySRL output file
     * @return the parse results; an empty list when the file contains no JSON value
     */
    public static List<ParseResult> readEasySRLJSONdata(String fileName) throws FileNotFoundException {
        Type listOfObject = new com.google.gson.reflect.TypeToken<List<ParseResult>>() {
        }.getType();
        String[] lines = FileUtil.readLinesFromFile(fileName);
        List<ParseResult> results = new Gson().fromJson(String.join("\n", lines), listOfObject);
        return results != null ? results : new ArrayList<ParseResult>();
    }

    /**
     * Picks the SEMAFOR frame-element label with the highest score among spans whose start token lies
     * in {@code [startIdx - 1, endIdx - 1]}.
     * <p>
     * The -1 shift presumably converts the SRL JSON's 1-based indices to SEMAFOR's 0-based token
     * indices (TODO confirm). NOTE(review): only the span start is tested, so SEMAFOR spans that begin
     * before the argument but extend into it are ignored.
     *
     * @param spanSet frame elements of one SEMAFOR annotation set
     * @param startIdx argument start index (SRL JSON convention)
     * @param endIdx argument end index, inclusive (SRL JSON convention)
     * @return the best label and its score, or {@code ("NONE", Double.MIN_VALUE)} when no span qualifies
     */
    public static Pair<String, Double> getSEMAFORLabel(List<Frame.NamedSpanSet> spanSet, int startIdx, int endIdx) {
        String label = NONE_LABEL;
        double maxScore = Double.MIN_VALUE;
        // Track whether any span was accepted: Double.MIN_VALUE is the smallest *positive* double,
        // so comparing against it alone would reject zero or negative scores.
        boolean found = false;
        for (Frame.NamedSpanSet frameElement : spanSet) {
            for (Span span : frameElement.spans) {
                boolean startsInside = span.start >= startIdx - 1 && span.start <= endIdx - 1;
                if (startsInside && (!found || span.probScore > maxScore)) {
                    maxScore = span.probScore;
                    label = frameElement.name;
                    found = true;
                }
            }
        }
        return new Pair<>(label, maxScore);
    }

    /**
     * Placeholder for EasySRL role lookup. Not implemented: always returns {@code null} and ignores its
     * arguments.
     *
     * @return {@code null}
     */
    public static Pair<String, Double> getEasySRLLabel(String text, int startIdx, int endIdx) {
        return null;
    }

    /**
     * Sets {@code label} as the predicted role of {@code semSpan}, stores it with {@code score} in the
     * span's role/probability map, then removes every entry of {@code labels} that does not match
     * {@code label} case-insensitively. Map keys not in {@code labels} are left untouched.
     */
    private static void overrideRole(ArgumentSpan semSpan, String label, Double score, Set<String> labels) {
        semSpan.setRolePredicted(label);
        Map<String, Double> probPair = semSpan.getRoleProbPair();
        System.out.println("Size before : " + probPair.size());
        probPair.put(label, score);
        for (String role : labels) {
            if (!role.equalsIgnoreCase(label)) {
                probPair.remove(role);
            }
        }
        System.out.println("Size : " + probPair.size());
    }

    /**
     * Returns {@code parses.get(sentId)}, throwing an {@link IndexOutOfBoundsException} that names the
     * unmatched sentence when {@code sentId} is out of range.
     */
    private static <T> T parseAt(List<T> parses, int sentId, String rawText) {
        if (sentId < 0 || sentId >= parses.size()) {
            throw new IndexOutOfBoundsException(
                    "No parse at index " + sentId + " (of " + parses.size() + ") for sentence: " + rawText);
        }
        return parses.get(sentId);
    }

    /**
     * Copies the SRL predictions in {@code srlPredictionJSON} and replaces every predicted argument
     * role with the SEMAFOR label for the same span, writing the result to {@code outFile}.
     * <p>
     * Sentences are matched to SEMAFOR parses through {@code rawSentenceSemafor}. When a sentence's
     * SEMAFOR parse has no frame whose target name equals the process name, all its spans get
     * {@code "NONE"} with score 1.0 and {@link #counter} is incremented.
     *
     * @throws IOException if any input cannot be read or the output cannot be written
     */
    public static void transformSemaforPrediction(String srlPredictionJSON, String rawSentenceSemafor,
            String semaforJSONPrediction, String outFile) throws FileNotFoundException, IOException {
        List<String> rawSentences = java.util.Arrays.asList(FileUtil.readLinesFromFile(rawSentenceSemafor));
        List<SemaforParseResult> semaforResults = readSemaforJSONdata(semaforJSONPrediction);
        List<JSONData> srlJsonData = readJSONData(srlPredictionJSON, false);
        // Relabel a deep copy so the original SRL predictions stay available for printing/comparison.
        @SuppressWarnings("unchecked")
        List<JSONData> semaforJSONManipulated = (List<JSONData>) FileUtil.deepClone(srlJsonData);

        Set<String> labels = JSONDataUtil.getAllUniqueRoleLabelFromJSON(srlJsonData);
        labels.add(NONE_LABEL);
        for (int i = 0; i < srlJsonData.size(); i++) {
            JSONData currentData = srlJsonData.get(i);
            JSONData currentSemData = semaforJSONManipulated.get(i);
            String processName = currentData.getProcessName();
            System.out.println(processName);
            ArrayList<Sentence> sentencesInProcess = currentData.getSentence();
            ArrayList<Sentence> sentencesInProcessSemafor = currentSemData.getSentence();
            for (int j = 0; j < sentencesInProcess.size(); j++) {
                Sentence goldSentence = sentencesInProcess.get(j);
                Sentence semSentence = sentencesInProcessSemafor.get(j);
                String rawText = goldSentence.getRawText();
                System.out.println(rawText);
                // Find the SEMAFOR parse for this sentence, then the frame evoked by the process name.
                int sentId = ArrUtil.getMatchIdx(rawSentences, rawText);
                SemaforParseResult semaforPrediction = parseAt(semaforResults, sentId, rawText);
                Frame targetFrame = semaforPrediction.frames.stream()
                        .filter(f -> f.target.name.equals(processName))
                        .findFirst()
                        .orElse(null);
                System.out.println(sentId);
                System.out.println(rawText);

                ArrayList<ArgumentSpan> spans = goldSentence.getPredictedArgumentSpanJSON();
                ArrayList<ArgumentSpan> semSpans = semSentence.getPredictedArgumentSpanJSON();
                if (targetFrame != null) {
                    ScoredRoleAssignment roleAssignments = targetFrame.annotationSets.get(0);
                    List<Frame.NamedSpanSet> spanSet = roleAssignments.frameElements;
                    for (int k = 0; k < spans.size(); k++) {
                        ArgumentSpan span = spans.get(k);
                        Pair<String, Double> labelScorePair =
                                getSEMAFORLabel(spanSet, span.getStartIdxJSON(), span.getEndIdxJSON());
                        System.out.println("Text : " + span.getTextJSON() + " ROLE SEMAFOR : "
                                + labelScorePair.first + " ROLE SRL : " + span.getRolePredicted());
                        overrideRole(semSpans.get(k), labelScorePair.first, labelScorePair.second, labels);
                    }
                } else {
                    // SEMAFOR did not evoke the process frame: every argument becomes NONE.
                    System.out.println(counter++);
                    for (int k = 0; k < spans.size(); k++) {
                        overrideRole(semSpans.get(k), NONE_LABEL, 1.0, labels);
                    }
                }
            }
        }
        flushDataToJSON(new ArrayList<JSONData>(semaforJSONManipulated), outFile, true);
    }

    /**
     * EasySRL counterpart of {@link #transformSemaforPrediction}. Sentences whose EasySRL parse score
     * is -1.0 (presumably EasySRL's "no parse" marker; TODO confirm) get {@code "NONE"} for every span.
     * <p>
     * {@link #getEasySRLLabel} is not implemented yet, so any sentence with a usable parse makes this
     * method fail with {@link UnsupportedOperationException}.
     *
     * @throws IOException if any input cannot be read or the output cannot be written
     */
    public static void transformEasySrlPrediction(String srlPredictionJSON, String rawSentenceEasySRL,
            String easySrlJSONPrediction, String outFile) throws FileNotFoundException, IOException {
        List<String> rawSentences = java.util.Arrays.asList(FileUtil.readLinesFromFile(rawSentenceEasySRL));
        List<ParseResult> easySrlResults = readEasySRLJSONdata(easySrlJSONPrediction);
        List<JSONData> srlJsonData = readJSONData(srlPredictionJSON, false);
        // Relabel a deep copy so the original SRL predictions stay available for printing/comparison.
        @SuppressWarnings("unchecked")
        List<JSONData> easySrlJSONManipulated = (List<JSONData>) FileUtil.deepClone(srlJsonData);

        Set<String> labels = JSONDataUtil.getAllUniqueRoleLabelFromJSON(srlJsonData);
        labels.add(NONE_LABEL);
        for (int i = 0; i < srlJsonData.size(); i++) {
            JSONData currentData = srlJsonData.get(i);
            JSONData currentEasySrlData = easySrlJSONManipulated.get(i);
            System.out.println(currentData.getProcessName());
            ArrayList<Sentence> sentencesInProcess = currentData.getSentence();
            ArrayList<Sentence> sentencesInProcessEasySrl = currentEasySrlData.getSentence();
            for (int j = 0; j < sentencesInProcess.size(); j++) {
                Sentence goldSentence = sentencesInProcess.get(j);
                Sentence easySrlSentence = sentencesInProcessEasySrl.get(j);
                String rawText = goldSentence.getRawText();
                System.out.println(rawText);

                int sentId = ArrUtil.getMatchIdx(rawSentences, rawText);
                ParseResult easySrlPrediction = parseAt(easySrlResults, sentId, rawText);

                ArrayList<ArgumentSpan> spans = goldSentence.getPredictedArgumentSpanJSON();
                ArrayList<ArgumentSpan> easySrlSpans = easySrlSentence.getPredictedArgumentSpanJSON();
                if (easySrlPrediction.getParseScore() != -1.0) {
                    for (int k = 0; k < spans.size(); k++) {
                        ArgumentSpan span = spans.get(k);
                        // Ask EasySRL for the label of the span covering this argument.
                        Pair<String, Double> labelScorePair = getEasySRLLabel(
                                span.getTextJSON(), span.getStartIdxJSON(), span.getEndIdxJSON());
                        if (labelScorePair == null) {
                            throw new UnsupportedOperationException(
                                    "EasySRL role labelling is not implemented (getEasySRLLabel returned null)");
                        }
                        System.out.println("Text : " + span.getTextJSON() + " ROLE EASYSRL : "
                                + labelScorePair.first + " ROLE SRL : " + span.getRolePredicted());
                        overrideRole(easySrlSpans.get(k), labelScorePair.first, labelScorePair.second, labels);
                    }
                } else {
                    // No usable EasySRL parse: every argument becomes NONE.
                    System.out.println(counter++);
                    for (int k = 0; k < spans.size(); k++) {
                        overrideRole(easySrlSpans.get(k), NONE_LABEL, 1.0, labels);
                    }
                }
            }
        }
        flushDataToJSON(new ArrayList<JSONData>(easySrlJSONManipulated), outFile, true);
    }

    /**
     * Prints every EasySRL parse in the given file.
     *
     * @param args optional: {@code args[0]} overrides the default EasySRL output path
     */
    public static void main(String[] args) throws FileNotFoundException, IOException {
        /*String crossValDir = "/home/slouvan/NetBeansProjects/SRL-Integrated/data/cross-val-08-01-2016-byprocess-fold-process";
         for (int i = 1; i <= 5; i++) {
         transformSemaforPrediction(crossValDir+"/fold-"+i+"/test/test.srlout.json",
         crossValDir+"/fold-"+i+"/test/cv."+i+".test.sentence.sbu",
         crossValDir+"/fold-"+i+"/test/semaforOut",
         crossValDir+"/fold-"+i+"/test/test.semaforpredict.json");
         System.out.println("COUNTER:"+SentenceUtil.counter);
         }*/
        String easySrlOutput = args.length > 0 ? args[0] : "/home/slouvan/NetBeansProjects/EasySRL/easySrlOut";
        for (ParseResult result : readEasySRLJSONdata(easySrlOutput)) {
            System.out.println(result);
        }
    }
}