com.paolodragone.wsn.disambiguation.WsnTermDisambiguator.java Source code

Java tutorial

Introduction

Here is the source code for com.paolodragone.wsn.disambiguation.WsnTermDisambiguator.java

Source

/*
 * Copyright Paolo Dragone 2014
 *
 * This file is part of WiktionarySemanticNetwork.
 *
 * WiktionarySemanticNetwork is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * WiktionarySemanticNetwork is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with WiktionarySemanticNetwork.  If not, see <http://www.gnu.org/licenses/>.
 */

package com.paolodragone.wsn.disambiguation;

import com.google.common.base.Strings;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;
import com.paolodragone.util.Counter;
import com.paolodragone.util.DMath;
import com.paolodragone.util.nlp.POS;
import com.paolodragone.wsn.entities.SemanticEdge;
import com.paolodragone.wsn.entities.Sense;
import com.paolodragone.wsn.entities.Synonym;
import com.paolodragone.wsn.entities.Term;
import com.paolodragone.wsn.util.LexicalContexts;

import java.util.*;

/**
 * Disambiguates the terms found in sense glosses by linking each term to one of the candidate senses of
 * its lemma.
 * <p>
 * For every candidate sense, a bag of context words is built both for the term (from its parent sense)
 * and for the candidate; the two bags are compared with a weighted, smoothed Jaccard similarity, scaled by
 * part-of-speech agreement, and the best-scoring candidate wins. The choice made for each term is
 * remembered: later calls reuse already-disambiguated gloss terms when expanding deeper context levels,
 * and an earlier choice is only replaced by a strictly more confident one.
 * <p>
 * Not thread-safe: {@link #disambiguateTerm} mutates internal {@code HashMap}s without synchronization.
 * <p>
 * TODO: Possibly introduce a word blacklist:
 * generally, belong, manner, object, where, separate, relate, relating, consist, consisting, vary,
 * originate, originating, extent, tending, require, required, consider, considered, include, including,
 * operation, involve, involves, technique, whereby, area, production.
 */
public class WsnTermDisambiguator {

    /**
     * Candidate-count cutoff. Not currently enforced: an earlier version skipped terms having more
     * candidate senses than this.
     */
    private static final int MAX_TARGET_SENSES = 8;

    /** Multiplier applied to a candidate whose part of speech equals the term's. */
    private static final double SAME_POS_COEFFICIENT = 1;

    /** Multiplier applied to a candidate whose part of speech differs from the term's. */
    private static final double DIFFERENT_POS_COEFFICIENT = 0.01;

    /** Number of context levels expanded around a sense; values {@code <= 0} yield an empty context. */
    private final int contextDepth;

    /** Weight of each lemma; lemmas absent from this map contribute nothing to the similarity scores. */
    private final Map<String, Double> lemmaWeightMap;

    /** Candidate senses per word; this class always looks them up with the lower-cased word. */
    private final ListMultimap<String, Sense> wordSensesMap;

    /** Best sense chosen so far for each term. */
    private final Map<Term, Sense> disambiguatedTerms = new HashMap<>();

    /** Confidence of the choice stored for the same term in {@code disambiguatedTerms}. */
    private final Map<Term, Double> disambiguatedTermConfidenceMap = new HashMap<>();

    /**
     * Creates a disambiguator.
     *
     * @param wordSensesMap  candidate senses per word (looked up by lower-cased word)
     * @param lemmaWeightMap weight of each lemma used by the similarity score
     * @param contextDepth   number of context levels to expand around a sense
     * @throws NullPointerException if either map is {@code null}
     */
    public WsnTermDisambiguator(ListMultimap<String, Sense> wordSensesMap, Map<String, Double> lemmaWeightMap,
            int contextDepth) {
        this.wordSensesMap = Objects.requireNonNull(wordSensesMap, "wordSensesMap");
        this.lemmaWeightMap = Objects.requireNonNull(lemmaWeightMap, "lemmaWeightMap");
        this.contextDepth = contextDepth;
    }

    /**
     * Links {@code term} to the candidate sense whose context best matches the context of the term's
     * parent sense.
     * <p>
     * If the term was already disambiguated with an equal or higher confidence, the earlier choice is kept
     * and returned; otherwise the new choice is remembered and returned.
     *
     * @param term the gloss term to disambiguate
     * @return an edge from the term to the chosen sense carrying its confidence, or {@code null} when there
     *         is no candidate or the best candidate did not score above zero
     */
    public SemanticEdge disambiguateTerm(Term term) {
        List<Sense> targetSenses = getTargetSenses(term);
        if (targetSenses.isEmpty()) {
            return null;
        }

        // The term's own context does not depend on the candidate, so it is built only once.
        List<String> termContext = getTermContext(term, term.getParentSense());

        // Insertion order is kept so that ties in getBestSense go to the earliest candidate,
        // independently of Sense hash codes.
        Map<Sense, Double> targetSenseSimilarityMap = new LinkedHashMap<>();
        for (Sense targetSense : targetSenses) {
            List<String> targetSenseContext = getTermContext(term, targetSense);
            double similarity = weightedSimilarityScore(termContext, targetSenseContext);
            similarity = adjustTargetSenseSimilarity(term, targetSense, targetSenses, similarity);
            targetSenseSimilarityMap.put(targetSense, similarity);
        }

        Sense bestSense = getBestSense(targetSenseSimilarityMap);
        // Written as !(x > 0) so that a NaN score is rejected as well.
        if (bestSense == null || !(targetSenseSimilarityMap.get(bestSense) > 0.0)) {
            return null;
        }

        double confidence = getConfidence(bestSense, targetSenseSimilarityMap);
        if (disambiguatedTerms.containsKey(term)) {
            double formerConfidence = disambiguatedTermConfidenceMap.get(term);
            if (confidence <= formerConfidence) {
                // The earlier choice is at least as confident: keep it.
                return newSemanticEdge(term, disambiguatedTerms.get(term), formerConfidence);
            }
        }

        disambiguatedTerms.put(term, bestSense);
        disambiguatedTermConfidenceMap.put(term, confidence);
        return newSemanticEdge(term, bestSense, confidence);
    }

    /** Builds the edge linking {@code term} to {@code targetSense} with the given confidence. */
    private static SemanticEdge newSemanticEdge(Term term, Sense targetSense, double confidence) {
        SemanticEdge semanticEdge = new SemanticEdge();
        semanticEdge.setTermId(term.getId());
        semanticEdge.setTargetSenseId(targetSense.getId());
        semanticEdge.setConfidence(confidence);
        return semanticEdge;
    }

    /**
     * Scales a raw similarity by the part-of-speech agreement between the term and the candidate.
     * Sense-number weighting ({@link #getSenseNumberCoefficient}) is currently not applied, which is why
     * {@code targetSenses} is unused.
     */
    private double adjustTargetSenseSimilarity(Term term, Sense targetSense, List<Sense> targetSenses,
            double similarity) {
        return similarity * getPosSimilarityCoefficient(term.getPos(), targetSense.getPos());
    }

    /**
     * Returns {@link #SAME_POS_COEFFICIENT} when both parts of speech are equal (including both
     * {@code null}), {@link #DIFFERENT_POS_COEFFICIENT} otherwise.
     */
    private double getPosSimilarityCoefficient(POS termPos, POS targetSensePos) {
        return Objects.equals(termPos, targetSensePos) ? SAME_POS_COEFFICIENT : DIFFERENT_POS_COEFFICIENT;
    }

    /**
     * Sense-number weighting, computed as {@code DMath.log(size + 1, 1 + 1 / senseNumber)}.
     * Currently unused (see {@link #adjustTargetSenseSimilarity}).
     */
    private double getSenseNumberCoefficient(int senseNumber, int size) {
        return DMath.log(size + 1.0, 1.0 + 1.0 / senseNumber);
    }

    /**
     * Returns the candidate senses for {@code term}: the senses listed under its lower-cased lemma,
     * narrowed by the lexical context of the term's parent sense.
     *
     * @param term the term whose candidates are wanted
     * @return a new mutable list; {@code wordSensesMap} itself is not modified
     */
    public List<Sense> getTargetSenses(Term term) {
        List<Sense> targetSenses = new ArrayList<>(wordSensesMap.get(term.getLemma().toLowerCase()));
        lexicalContextFilter(term.getParentSense(), targetSenses);
        return targetSenses;
    }

    /**
     * Keeps only the candidates sharing at least one valid lexical-context part with {@code sense}.
     * <p>
     * Nothing is removed when {@code sense} has no (valid) lexical context, or when no candidate shares a
     * valid part with it; consequently the filter never empties a non-empty candidate list.
     */
    private void lexicalContextFilter(Sense sense, Collection<Sense> targetSenses) {
        String lexicalContext = sense.getLexicalContext();
        if (Strings.isNullOrEmpty(lexicalContext) || !LexicalContexts.isValidLexicalContext(lexicalContext)) {
            return;
        }

        Set<String> lexicalContextParts = getValidLexicalContextParts(lexicalContext);
        // The guard and the filter must compare the same (validated) parts; otherwise a match on a part
        // that validation discards would pass the guard and still let the filter remove every candidate.
        if (targetSenses.stream().noneMatch(targetSense -> sharesLexicalContext(lexicalContextParts, targetSense))) {
            return;
        }
        targetSenses.removeIf(targetSense -> !sharesLexicalContext(lexicalContextParts, targetSense));
    }

    /**
     * Splits {@code lexicalContext} into parts and applies
     * {@code LexicalContexts.filterValidLexicalContextParts} to the resulting set in place.
     */
    private static Set<String> getValidLexicalContextParts(String lexicalContext) {
        Set<String> parts = LexicalContexts.getLexicalContextParts(lexicalContext);
        LexicalContexts.filterValidLexicalContextParts(parts);
        return parts;
    }

    /** Tells whether the validated lexical-context parts of {@code targetSense} intersect {@code parts}. */
    private static boolean sharesLexicalContext(Set<String> parts, Sense targetSense) {
        Set<String> targetParts = getValidLexicalContextParts(targetSense.getLexicalContext());
        return !Sets.intersection(parts, targetParts).isEmpty();
    }

    /** Builds the context of {@code sense} for {@code term}, expanded {@code contextDepth} levels deep. */
    private List<String> getTermContext(Term term, Sense sense) {
        return getTermContext(term, sense, contextDepth);
    }

    /**
     * Collects the context words of {@code sense} up to {@code depth} levels deep, leaving out
     * {@code term}'s own lemma.
     * <p>
     * Level one is {@link #getFirstLevelContext}. Deeper levels recurse into the sense already chosen for
     * each gloss term and into every sense of each synonym.
     *
     * @return the collected words, duplicates included; empty when {@code depth <= 0}
     */
    private List<String> getTermContext(Term term, Sense sense, int depth) {
        List<String> context = new ArrayList<>();
        if (depth <= 0) {
            return context;
        }

        context.addAll(getFirstLevelContext(term, sense));

        String lemma = term.getLemma();
        for (Term glossTerm : sense.getGlossTerms()) {
            if (!glossTerm.getLemma().equals(lemma) && disambiguatedTerms.containsKey(glossTerm)) {
                context.addAll(getTermContext(term, disambiguatedTerms.get(glossTerm), depth - 1));
            }
        }

        // TODO: target senses of synonyms are added without disambiguation.
        String lemmaKey = lemma.toLowerCase();
        for (Synonym synonym : sense.getSynonyms()) {
            // Same key normalization as getTargetSenses; comparing normalized keys also prevents
            // expanding the term's own word through a case variant of it.
            String targetWordKey = synonym.getTargetWord().toLowerCase();
            if (!targetWordKey.equals(lemmaKey)) {
                for (Sense targetSense : wordSensesMap.get(targetWordKey)) {
                    context.addAll(getTermContext(term, targetSense, depth - 1));
                }
            }
        }
        return context;
    }

    /**
     * Returns the direct context of {@code sense}: its lexical-context parts, then its synonyms, then its
     * gloss lemmas (the latter two excluding {@code term}'s lemma).
     */
    private List<String> getFirstLevelContext(Term term, Sense sense) {
        String lemma = term.getLemma();
        List<String> context = new ArrayList<>(getFirstLevelLexicalContext(sense));
        context.addAll(getFirstLevelSynonymContext(lemma, sense));
        context.addAll(getFirstLevelGlossContext(lemma, sense));
        return context;
    }

    /** Returns the lexical-context parts of {@code sense}, or an empty list if it has no lexical context. */
    private List<String> getFirstLevelLexicalContext(Sense sense) {
        List<String> context = new ArrayList<>();
        if (!Strings.isNullOrEmpty(sense.getLexicalContext())) {
            context.addAll(LexicalContexts.getLexicalContextParts(sense.getLexicalContext()));
        }
        return context;
    }

    /** Returns the lower-cased synonyms of {@code sense}, skipping those equal to {@code lemma}. */
    private List<String> getFirstLevelSynonymContext(String lemma, Sense sense) {
        List<String> context = new ArrayList<>();
        for (Synonym synonym : sense.getSynonyms()) {
            String targetWord = synonym.getTargetWord();
            if (!targetWord.equals(lemma)) {
                context.add(targetWord.toLowerCase());
            }
        }
        return context;
    }

    /** Returns the lower-cased lemmas of the gloss terms of {@code sense}, skipping {@code lemma}. */
    private List<String> getFirstLevelGlossContext(String lemma, Sense sense) {
        List<String> context = new ArrayList<>();
        for (Term glossTerm : sense.getGlossTerms()) {
            String glossLemma = glossTerm.getLemma();
            if (!glossLemma.equals(lemma)) {
                context.add(glossLemma.toLowerCase());
            }
        }
        return context;
    }

    /**
     * Scores two contexts with {@link #smoothedJaccardSimilarity} over their distinct words
     * ({@link #smoothedCosineSimilarity} is an unused alternative).
     */
    private double weightedSimilarityScore(List<String> terms1, List<String> terms2) {
        return smoothedJaccardSimilarity(new LinkedHashSet<>(terms1), new LinkedHashSet<>(terms2));
    }

    /**
     * Smoothed, weighted cosine similarity over the weighted lemmas of both bags (counting repetitions):
     * {@code (dot + 1) / (|v1| * |v2| + vocabularySize)}, where {@code vocabularySize} is the size of
     * {@code lemmaWeightMap}. Currently unused.
     */
    private double smoothedCosineSimilarity(List<String> terms1, List<String> terms2) {
        Counter<String> counterTerms1 = new Counter<>();
        Counter<String> counterTerms2 = new Counter<>();

        terms1.stream().filter(lemmaWeightMap::containsKey).forEach(counterTerms1::add);
        terms2.stream().filter(lemmaWeightMap::containsKey).forEach(counterTerms2::add);

        Set<String> intersection = Sets.intersection(counterTerms1.getItemSet(), counterTerms2.getItemSet());

        double dotProd = 0.0;
        for (String term : intersection) {
            dotProd += Math.pow(lemmaWeightMap.get(term), 2) * counterTerms1.getCount(term)
                    * counterTerms2.getCount(term);
        }

        double magnitude1 = getMagnitude(counterTerms1.getItemSet(), counterTerms1.getItemCounts());
        double magnitude2 = getMagnitude(counterTerms2.getItemSet(), counterTerms2.getItemCounts());

        return (dotProd + 1) / ((magnitude1 * magnitude2) + lemmaWeightMap.size());
    }

    /** Euclidean norm of the vector whose components are {@code weight(term) * count(term)}. */
    private double getMagnitude(Set<String> terms, Map<String, Integer> counts) {
        double magnitude = 0.0;
        for (String term : terms) {
            magnitude += Math.pow(lemmaWeightMap.get(term) * counts.get(term), 2);
        }
        return Math.sqrt(magnitude);
    }

    /**
     * Smoothed, weighted Jaccard similarity:
     * {@code (weight(terms1 ∩ terms2) + 1) / (weight(terms1 ∪ terms2) + vocabularySize)}, where a set's
     * weight sums the weights of its lemmas present in {@code lemmaWeightMap} and {@code vocabularySize}
     * is the size of that map.
     */
    private double smoothedJaccardSimilarity(Set<String> terms1, Set<String> terms2) {
        double intersectionWeight = 0;
        for (String term : Sets.intersection(terms1, terms2)) {
            if (lemmaWeightMap.containsKey(term)) {
                intersectionWeight += lemmaWeightMap.get(term);
            }
        }

        double unionWeight = 0;
        for (String term : Sets.union(terms1, terms2)) {
            if (lemmaWeightMap.containsKey(term)) {
                unionWeight += lemmaWeightMap.get(term);
            }
        }

        return (intersectionWeight + 1.0) / (unionWeight + lemmaWeightMap.size());
    }

    /**
     * Returns the sense with the highest score; on ties the first one in iteration order wins.
     *
     * @return {@code null} only when the map is empty
     */
    private Sense getBestSense(Map<Sense, Double> targetSenseSimilarityMap) {
        Sense bestSense = null;
        double bestSimilarity = 0;
        for (Map.Entry<Sense, Double> targetSenseSimilarityEntry : targetSenseSimilarityMap.entrySet()) {
            double similarity = targetSenseSimilarityEntry.getValue();
            if (bestSense == null || similarity > bestSimilarity) {
                bestSimilarity = similarity;
                bestSense = targetSenseSimilarityEntry.getKey();
            }
        }
        return bestSense;
    }

    /** Returns the best sense once {@code bestSense} is excluded, or {@code null} if none is left. */
    private Sense getSecondBestSense(Map<Sense, Double> targetSenseSimilarityMap, Sense bestSense) {
        Map<Sense, Double> targetSenseSimilarityMapCopy = new LinkedHashMap<>(targetSenseSimilarityMap);
        targetSenseSimilarityMapCopy.remove(bestSense);
        return getBestSense(targetSenseSimilarityMapCopy);
    }

    /**
     * Confidence of choosing {@code bestSense}: the margin over the second-best score divided by the sum
     * of all scores. The second-best score counts as 0 when there is a single candidate.
     */
    private double getConfidence(Sense bestSense, Map<Sense, Double> targetSenseSimilarityMap) {
        Sense secondBestSense = getSecondBestSense(targetSenseSimilarityMap, bestSense);
        double bestSimilarity = targetSenseSimilarityMap.get(bestSense);
        double secondBestSimilarity = 0.0;
        if (secondBestSense != null) {
            secondBestSimilarity = targetSenseSimilarityMap.get(secondBestSense);
        }
        return (bestSimilarity - secondBestSimilarity) / getSumSimilarity(targetSenseSimilarityMap);
    }

    /** Sums all scores of the map, in iteration order. */
    private double getSumSimilarity(Map<Sense, Double> targetSenseSimilarityMap) {
        double sum = 0;
        for (double similarity : targetSenseSimilarityMap.values()) {
            sum += similarity;
        }
        return sum;
    }
}