// Java tutorial
/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ package tpt.dbweb.cat; import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; import org.apache.commons.collections4.iterators.ReverseListIterator; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringEscapeUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.beust.jcommander.Parameter; import tpt.dbweb.cat.datatypes.EntityMention; import tpt.dbweb.cat.datatypes.MentionChains; import tpt.dbweb.cat.datatypes.MentionChains.Chain; import tpt.dbweb.cat.datatypes.TaggedText; import tpt.dbweb.cat.datatypes.iterators.CompareIterator; import tpt.dbweb.cat.datatypes.iterators.ComparePair; import tpt.dbweb.cat.datatypes.iterators.EntityMentionPos; import tpt.dbweb.cat.datatypes.iterators.EntityMentionPosIterator.PosType; import tpt.dbweb.cat.evaluation.ComparisonResult; import tpt.dbweb.cat.evaluation.EvaluationStatistics; import tpt.dbweb.cat.io.TaggedTextXMLReader; import tpt.dbweb.cat.tools.ExtractInitials; import tpt.dbweb.cat.tools.MentionChainAligner; import tpt.dbweb.cat.tools.Utility; /** * Compare one or more XML files with annotations to a goldstandard and output them as a self-contained XML file. * It uses src/main/resources/compare-template.xml to create the output file. Please change XML transformation, CSS and Javascript there. 
* * @author Thomas Rebele * */ public class Compare { private final static Logger log = LoggerFactory.getLogger(Compare.class); public enum InputFormat { CoNLL, XML }; /** * Command line options for compare */ public static class Options { @Parameter(description = "Input files, treat first as the gold standard") public List<String> input = new ArrayList<>(); @Parameter(names = "--format", description = "input format") public InputFormat inputFormat = InputFormat.CoNLL; @Parameter(names = "--out") public String outputFile = null; boolean replaceNewlineWithBR = false; /** * Only use the min mention for visualization */ boolean minOnly = true; /** * Transform the entities to a more human readable form (add string of first mention and chain number) */ public boolean humanReadableMentions = false; /** * remove non-mention-entities from the input */ public boolean filterNMEEntities = true; } private final Options options; public Compare(Options options) { this.options = options; } /** * Saves the evaluation of a mark (correct, missing, wrong, toomuch) and chain information, e.g. "(1" or "2" or "3)" */ private class MarkEval { String eval; String chainBefore; String chainAfter; } public static void compare(Options options, List<ComparisonResult> evaluations) throws IOException { if (options.outputFile != null && options.input != null && options.input.size() > 0) { Compare compare = new Compare(new Options()); List<Path> paths = options.input.stream().map(str -> Paths.get(str)).collect(Collectors.toList()); compare.compareXML(paths, Paths.get(options.outputFile), evaluations); } } /** * Do some cleanup on the text, e.g. removing unwanted entities * @param tt */ private void cleanUp(TaggedText tt) { // tt.mentions.removeIf(em -> // options.filterEntities.contains(em.entity)); if (options.filterNMEEntities) { tt.mentions.removeIf(em -> Utility.isNME(em.entity)); } tt.mentions.sort(null); } /** * Check whether we can accept the input, i.e. 
all the tagged texts have the same text. * @param files list of filenames to output more useful information to the user * @param tts list of tagged texts * @return true if tagged texts have the right format */ private boolean checkTaggedTexts(List<String> infos, List<TaggedText> tts) { // print message when article ids are not the same text for (int i = 1; i < tts.size(); i++) { TaggedText tt0 = tts.get(0), ttI = tts.get(i); if (!tt0.id.equals(ttI.id)) { StringBuilder sb = new StringBuilder(); sb.append("article id is not the same (" + infos.get(0) + ", id " + tt0.id + " and " + infos.get(i) + ", id " + ttI.id + ")"); sb.append("\n>>>"); sb.append(tt0.text); sb.append("\n<<<\n>>>"); sb.append(ttI.text); sb.append("\n<<<\n"); log.warn(sb.toString()); return false; } // print message when article texts are not the same if (!tt0.text.equals(ttI.text)) { if (log.isWarnEnabled()) { StringBuilder sb = new StringBuilder(); sb.append("text of article is not the same (" + infos.get(0) + ", id " + tt0.id + " and " + infos.get(i) + ", id " + ttI.id + ")"); sb.append(", common prefix: '"); int prefixLen = Utility.getCommonPrefixLength(tt0.text, ttI.text); sb.append(tt0.text.substring(0, prefixLen)); sb.append("'"); log.warn(sb.toString()); log.warn("1st text continues with " + tt0.text.substring(prefixLen, Math.min(prefixLen + 10, tt0.text.length()))); log.warn("2nd text continues with " + ttI.text.substring(prefixLen, Math.min(prefixLen + 10, ttI.text.length()))); log.warn("length 1st text: " + tt0.text.length()); log.warn("length 2nd text: " + ttI.text.length()); } return false; } } return true; } /** * Load XML files, compare them and write the output XML files to out * @param files * @param out * @param evaluations * @throws IOException */ public void compareXML(List<Path> files, Path out, List<ComparisonResult> evaluations) throws IOException { log.info("comparing {}; writing output to {}", files, out); List<Iterator<TaggedText>> ttIts = new ArrayList<>(); List<String> 
info = new ArrayList<>(); try { TaggedTextXMLReader ttxr = new TaggedTextXMLReader(); for (int i = 0; i < files.size(); i++) { ttIts.add(ttxr.iteratePath(files.get(i))); info.add(files.get(i).toString()); } compare(ttIts, info, out, evaluations); } catch (FileNotFoundException e) { log.error("file not found: {}", e.getMessage()); } } public void compare(List<Iterator<TaggedText>> ttIts, List<String> infos, Path out, List<ComparisonResult> evaluations) throws IOException { boolean docEvaluationNotFound = false; StringWriter sw = new StringWriter(); PrintWriter ps = new PrintWriter(sw); ps.append("<annotators>\n"); // load files and print annotator info for (int i = 0; i < ttIts.size(); i++) { ps.append("\t<annotator id='" + i + "' file='"); ps.append(StringEscapeUtils.escapeXml11(infos.get(i))); ps.append("'/>\n"); } ps.append("</annotators>\n"); // print evaluation List<Map<String, EvaluationStatistics>> evals = new ArrayList<>(); Map<String, EvaluationStatistics> eval; if (evaluations != null) { for (int i = 0; i < evaluations.size(); i++) { ComparisonResult combinedEvaluations = evaluations.get(i).combine(); eval = new TreeMap<>(); evals.add(eval); // type is macro / micro for (String type : combinedEvaluations.docidToMetricToResult.keySet()) { Map<String, EvaluationStatistics> metricToResult = combinedEvaluations.docidToMetricToResult .get(type); for (String metric : metricToResult.keySet()) { eval.put(metric + " (" + type + ")", metricToResult.get(metric)); } } } ps.print(printMetrics(evals)); evals.clear(); } // print comparison of articles while (ttIts.stream().allMatch(it -> it.hasNext())) { List<TaggedText> tts = ttIts.stream().map(it -> it.next()).collect(Collectors.toList()); tts.forEach(tt -> cleanUp(tt)); if (!checkTaggedTexts(infos, tts)) { break; } // do comparison and write to output ps.print(" <article id='"); ps.print(tts.get(0).id); ps.println("'>"); ps.print(compare(tts)); ps.println(); // print evaluation of article if (evaluations != null) { 
docEvaluationNotFound = true; for (int i = 0; i < evaluations.size(); i++) { eval = evaluations.get(i).docidToMetricToResult.get(tts.get(0).id); evals.add(eval); if (eval != null) { docEvaluationNotFound = false; } } if (docEvaluationNotFound == false) { ps.print(printMetrics(evals)); } else { } } else { docEvaluationNotFound = true; } if (docEvaluationNotFound) { log.warn("evaluation not found for {}", tts.get(0).id); } ps.println(" </article>"); } if (docEvaluationNotFound && evaluations != null && evaluations.size() > 0) { log.warn("available evaluations: {}", evaluations.get(0).docidToMetricToResult.keySet()); } // load template and replace <article/> String template = Utility.readResourceAsString("compare-template.xml"); sw.toString(); String output = template.replace("<article/>", sw.toString()); FileUtils.writeStringToFile(out.toFile(), output); } private String printMetrics(List<Map<String, EvaluationStatistics>> evaluations) { StringBuilder sb = new StringBuilder(); sb.append("<metrics>\n"); for (int i = 0; i < evaluations.size(); i++) { Map<String, EvaluationStatistics> map = evaluations.get(i); sb.append("<annotator id='" + (i + 1) + "'>\n"); for (String name : map.keySet()) { sb.append(" <metric name='" + name + "'"); EvaluationStatistics es = map.get(name); sb.append(" recall='" + es.getRecall() + "'"); sb.append(" precision='" + es.getPrecision() + "'/>\n"); } sb.append("</annotator>\n"); } sb.append("</metrics>\n"); return sb.toString(); } public String compare(List<TaggedText> tts) { StringBuilder builder = new StringBuilder(); List<List<EntityMention>> mentions = new ArrayList<>(); // track open marks and entity mentions List<String> openMarks = new ArrayList<>(); // list is overkill // replace mentions by their minimum for (int i = 0; i < tts.size(); i++) { if (options.minOnly) { mentions.add(new ArrayList<>()); for (EntityMention em : tts.get(i).mentions) { mentions.get(i).add(em.getMinMention()); } } else { mentions.add(tts.get(i).mentions); } } 
// chains for both documents List<MentionChains> chains = mentions.stream().map(eml -> new MentionChains(eml)).collect(Collectors.toList()); // align chains to chain0 for (int i = 1; i < chains.size(); i++) { Map<String, String> map = new MentionChainAligner().guessEntityMapGreedy(tts.get(0), tts.get(i)); int unmappedIdx = chains.get(0).entityToChain.size() + 1; for (String entity : chains.get(i).entityToChain.keySet()) { String mappedEntity = map.get(entity); Chain entityChainI = chains.get(i).entityToChain.get(entity); if (mappedEntity == null) { entityChainI.idx = unmappedIdx++; } else { Chain entityChain0 = chains.get(0).entityToChain.get(mappedEntity); if (entityChain0 == null) { entityChainI.idx = unmappedIdx++; } else { entityChainI.idx = entityChain0.idx; } } } } // generate new entity names for human output List<Map<String, String>> entityMentionToOutput = new ArrayList<>(); entityMentionToOutput.add(getEntityRenameMap(mentions.get(0), options.humanReadableMentions, null)); for (int i = 1; i < mentions.size(); i++) { entityMentionToOutput.add(getEntityRenameMap(mentions.get(i), options.humanReadableMentions, chains.get(0))); } // map chains to abbreviations Map<String, Set<String>> shortnameToEntryList = new HashMap<>(); for (int i = 0; i < mentions.size(); i++) { for (Entry<String, String> e : entityMentionToOutput.get(i).entrySet()) { String shortName = ExtractInitials.getInitials(e.getKey()); shortnameToEntryList.computeIfAbsent(shortName, -> new HashSet<>(1)).add(e.getKey()); } } Map<String, String> entryToShortname = new HashMap<>(); Map<String, String> shortnameToEntry = new TreeMap<>(); for (Entry<String, Set<String>> e : shortnameToEntryList.entrySet()) { int i = 0; for (String c : e.getValue()) { String shortname = e.getKey() + (e.getValue().size() <= 1 ? 
"" : (++i)); shortname = c; entryToShortname.put(c, shortname); shortnameToEntry.put(shortname, c); } } // output abbreviation legend builder.append("<entity-list>\n"); for (Entry<String, String> e : shortnameToEntry.entrySet()) { builder.append("<entry>"); builder.append(StringEscapeUtils.escapeXml10(e.getValue())); builder.append("</entry>\n"); } builder.append("</entity-list>\n"); // iterate over mentions builder.append("<content>"); CompareIterator cmpIt = new CompareIterator(tts.get(0).text, tts.get(0).id, mentions); ComparePair last = null; for (ComparePair pair : Utility.iterable(cmpIt)) { log.trace("{}, text {}", pair, tts.get(0).text.substring(pair.start, pair.end)); // escape span that was compared String escaped = StringEscapeUtils.escapeXml10(tts.get(0).text.substring(pair.start, pair.end)); if (options.replaceNewlineWithBR) { escaped = escaped.replace("\n\n\n", "<br/>"); escaped = escaped.replace("\n\n", "<br/>"); escaped = escaped.replace("\n", "<br/>"); } boolean hasEntity = false; List<EntityMention> principalMentions = new ArrayList<>(); for (int i = 0; i < mentions.size(); i++) { EntityMention em = pair.getPrincipalMention(i); // filter AIDA out-of-knowledge-base-entities if (em != null && "--OOKBE--".equals(em.entity)) { em = null; } hasEntity |= em != null; principalMentions.add(em); } // create mark tag List<MarkEval> evals = null; if (hasEntity) { openMarks.add("entities " + principalMentions); builder.append("<mark "); evals = new ArrayList<>(); boolean split = evaluateMark(last, pair, principalMentions, chains, evals); builder.append(" split='" + Boolean.toString(split) + "'"); // add entity and other information EntityMention em = principalMentions.get(0); printEntityAttributes(builder, "0", pair, em, entityMentionToOutput.get(0), entryToShortname); addChainInfo("0", evals.get(0), builder); builder.append(">"); // print out individual annotator evaluations for (int i = 1; i < mentions.size(); i++) { builder.append("<annotator index='" + i + 
"'"); em = principalMentions.get(i); printEntityAttributes(builder, "", pair, em, entityMentionToOutput.get(i), entryToShortname); if (evals.get(i).eval != null) { builder.append(" eval='" + evals.get(i).eval + "'"); addAnnotatorInfo(i, evals, principalMentions, builder); } addChainInfo(null, evals.get(i), builder); // Note: newline character introduces a space between a mark and its before chain annotations builder.append("/>\n"); } } // generate chain indices for super/subscript // doChainAnnotation(pair, chains, builder); // escape and print builder.append(escaped); // close mark tags while (openMarks.size() > 0) { openMarks.remove(openMarks.size() - 1); builder.append("</mark>\n"); } last = pair; } builder.append("</content>"); return builder.toString().trim(); } private void printEntityAttributes(StringBuilder builder, String attributeSuffix, ComparePair pair, EntityMention em, Map<String, String> entityMentionToOutput, Map<String, String> entryToShortname) { if (em != null) { String entity = StringEscapeUtils.escapeXml11(entityMentionToOutput.get(em.entity)); builder.append(" entity" + attributeSuffix + "='" + entity + "'"); String shortName = entryToShortname.get(em.entity); int length = pair.end - pair.start; if (shortName != null && shortName.length() > length + 5) { shortName = shortName.substring(0, length + 5) + ""; } builder.append(" short" + attributeSuffix + "='" + Utility.orElse(shortName, "[none]") + "'"); } else { //builder.append(" entity='-'"); builder.append(" short" + attributeSuffix + "='[none]'"); } } private void addAnnotatorInfo(int idx, List<MarkEval> evals, List<EntityMention> principalMentions, StringBuilder builder) { if (evals == null || evals.get(idx) == null) { return; } EntityMention emI = principalMentions.get(idx); if (emI != null && emI.info() != null) { for (Entry<String, String> info : emI.info().entrySet()) { builder.append(" " + info.getKey() + "='" + StringEscapeUtils.escapeXml10(info.getValue()) + "'"); } } return; } /** * 
* @param evals * @param principalMentions * @return true if mark should be splitted */ boolean evaluateMark(ComparePair lastPair, ComparePair pair, List<EntityMention> principalMentions, List<MentionChains> chains, List<MarkEval> evals) { MarkEval me = new MarkEval(); me.chainBefore = chainAnnotationAttr(0, lastPair, PosType.START, chains); me.chainAfter = chainAnnotationAttr(0, pair, PosType.END, chains); evals.add(me); EntityMention em0 = principalMentions.get(0); String principalEvaluation = null; for (int i = 1; i < principalMentions.size(); i++) { EntityMention emI = principalMentions.get(i); String eval = null; if (em0 != null && em0.entity != null) { if (emI == null || emI.entity == null) { eval = "missing"; } } if (emI != null && emI.entity != null) { if (em0 != null && emI.entity.equals(em0.entity)) { eval = "correct"; } else { if (em0 == null) { eval = "toomuch"; } else { eval = "wrong"; } } } if (eval == null) { eval = ""; } me = new MarkEval(); me.eval = eval; me.chainBefore = chainAnnotationAttr(i, lastPair, PosType.START, chains); me.chainAfter = chainAnnotationAttr(i, pair, PosType.END, chains); evals.add(me); if (principalEvaluation == null) { principalEvaluation = eval; } else if (!principalEvaluation.equals(eval)) { principalEvaluation = "split"; } } return "split".equals(principalEvaluation); } private void addChainInfo(String idx, MarkEval eval, StringBuilder builder) { if (idx == null) { idx = ""; } if (eval.chainBefore != null) { builder.append(" chain-before" + idx + "='" + eval.chainBefore + "'"); } if (eval.chainAfter != null) { builder.append(" chain-after" + idx + "='" + eval.chainAfter + "'"); } } /** * Generate "(chainidx" or "chainidx" or "chainidx)" strings * * @param docIdx * @param pair * @param chains * @return */ private String chainAnnotationAttr(int docIdx, ComparePair pair, PosType posType, List<MentionChains> chains) { if (pair == null || pair.getPos(docIdx) == null) { return null; } List<EntityMention> mentions = 
pair.getMentions(docIdx); if (mentions == null || mentions.size() == 0) { return null; } StringBuilder sb = new StringBuilder(); boolean printIntermediates = !pair.emps.get(docIdx).stream() .filter(emp -> (emp.posType == PosType.START || emp.posType == PosType.END)).findAny().isPresent(); for (EntityMentionPos emp : Utility.iterable(new ReverseListIterator<>(pair.emps.get(docIdx)))) { if (emp == null) { continue; } if (emp.posType != posType) { continue; } PosType pt = emp.posType; Chain c0 = chains.get(docIdx).mentionToChain.get(emp.em); if (c0 == null) { return null; } String chainStr = "" /*+ c0.idx*/; switch (pt) { case START: chainStr = "(" + chainStr; break; case END: chainStr = chainStr + ")"; break; case INTERMEDIATE: if (!printIntermediates) { chainStr = null; } break; default: log.warn("chainToStr cannot deal with pos type {} for compare pair", pt, pair); } if (chainStr != null) { sb.append(chainStr); } } return sb.toString(); } private Map<String, String> getEntityRenameMap(List<EntityMention> mentions0, boolean rename, MentionChains chains) { Map<String, String> entityMentionToOutput0 = new HashMap<>(); for (EntityMention m : mentions0) { String entity = m.entity; if (rename) { int idx = (entityMentionToOutput0.size() + 1); if (chains != null) { Chain c = chains.entityToChain.get(m); if (c != null) { idx = c.idx; } } int fidx = idx; entityMentionToOutput0.computeIfAbsent(entity, k -> StringEscapeUtils.escapeXml10(m.getMinMention().spanString()) + "; " + fidx + "; " + StringEscapeUtils.escapeXml10(entity)); } else { entityMentionToOutput0.computeIfAbsent(entity, k -> StringEscapeUtils.escapeXml10(entity)); } } return entityMentionToOutput0; } }