org.dragoneronca.nlp.wol.disambiguation.SenseSolver.java Source code

Java tutorial

Introduction

Here is the source code for org.dragoneronca.nlp.wol.disambiguation.SenseSolver.java

Source

/*
 * Copyright Paolo Dragone 2014.
 * Copyright Alessandro Ronca 2014.
 *
 * This file is part of Wiktionary Ontology.
 *
 * Wiktionary Ontology is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Wiktionary Ontology is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Wiktionary Ontology. If not, see <http://www.gnu.org/licenses/>.
 */

package org.dragoneronca.nlp.wol.disambiguation;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import org.apache.commons.configuration.PropertiesConfiguration;
import org.apache.log4j.Logger;
import org.dragoneronca.nlp.wol.WolConfiguration;
import org.dragoneronca.util.concurrent.Consumer;
import org.dragoneronca.util.graphs.Node;
import org.dragoneronca.util.graphs.Path;
import org.dragoneronca.util.graphs.PathFinder;
import org.dragoneronca.util.graphs.PathScorer;

import java.util.*;
import java.util.concurrent.BlockingQueue;

/**
 * This class computes a better probability distribution for the possible senses of the terms of a
 * given sense.
 * <p/>
 * It is an implementation of the CQC algorithm, it performs graph visits searching for cycles and
 * quasi-cycles for the current sense. Then it updates the probabilities of the outgoing edges of
 * the sense according to the scores of the found paths.
 * <p/>
 * It implements the <tt>Consumer</tt> interface so that it can be linked to a producer of senses,
 * performing a streaming computation.
 * <p/>
 * It implements the <tt>Runnable</tt> interface so that different instances can be executed in
 * parallel, since the algorithm in its whole is data parallel wrt to senses.
 *
 * @author Paolo Dragone
 * @author Alessandro Ronca
 */
public class SenseSolver implements Consumer<LightSense>, Runnable {

    public static final Logger LOG = Logger.getLogger(SenseSolver.class);

    /** Convergence threshold: a weight change smaller than this counts as "unchanged". */
    private static final double EPSILON;
    /** Maximum depth for the (quasi-)cycle graph searches. */
    private static final int MAX_DEPTH;

    static {
        // Tuning parameters are loaded once, shared by every solver instance.
        PropertiesConfiguration properties =
                WolConfiguration.getInstance().getConfiguration("environment");

        EPSILON = properties.getDouble("sense_solver.epsilon");
        MAX_DEPTH = properties.getInt("sense_solver.max_depth");
    }

    private final PathScorer cycleScorer;
    private final PathScorer quasiCycleScorer;
    private boolean executed = false;
    private BlockingQueue<LightSense> queue;
    private int processedTerms = 0;
    private int convergedTerms = 0;
    private int impossibleToDisambiguate = 0;

    /**
     * It constructs a <tt>SenseSolver</tt> given scorers for cycles and quasi-cycles.
     *
     * @param cycleScorer      a scorer for cycles.
     * @param quasiCycleScorer a scorer for quasi-cycles.
     */
    public SenseSolver(PathScorer cycleScorer, PathScorer quasiCycleScorer) {
        this.cycleScorer = cycleScorer;
        this.quasiCycleScorer = quasiCycleScorer;
    }

    @Override
    public void setInputQueue(BlockingQueue<LightSense> queue) {
        this.queue = queue;
    }

    /**
     * It returns the total number of processed terms.
     *
     * @return number of terms.
     * @throws IllegalStateException if the solver has not been executed yet.
     */
    public int getProcessedTerms() {
        if (!executed) {
            throw new IllegalStateException("Not yet executed");
        } else {
            return processedTerms;
        }
    }

    /**
     * It returns the number of terms whose probability distribution over senses has not changed
     * after the execution of this algorithm.
     *
     * @return number of terms.
     * @throws IllegalStateException if the solver has not been executed yet.
     */
    public int getConvergedTerms() {
        if (!executed) {
            throw new IllegalStateException("Not yet executed");
        } else {
            return convergedTerms;
        }
    }

    /**
     * It returns the number of terms for which the algorithm has not been able to perform
     * disambiguation.
     *
     * @return number of terms.
     * @throws IllegalStateException if the solver has not been executed yet.
     */
    public int getImpossibleToDisambiguate() {
        // Consistency fix: guard like the sibling getters instead of returning a
        // meaningless zero before execution.
        if (!executed) {
            throw new IllegalStateException("Not yet executed");
        } else {
            return impossibleToDisambiguate;
        }
    }

    @Override
    public void run() {
        // A solver instance is single-shot: a second invocation is a no-op.
        if (executed) {
            return;
        } else {
            executed = true;
        }

        LightSense sense;
        try {
            // Consume senses until the producer signals end-of-stream.
            while (!(sense = queue.take()).isEndOfStream()) {
                disambiguateSense(sense);
            }
        } catch (InterruptedException e) {
            // Restore the interrupt status so the owning thread/executor can
            // observe the interruption; swallowing it would hide the signal.
            Thread.currentThread().interrupt();
            LOG.warn("Exception while consuming", e);
        }
    }

    /**
     * Disambiguates every term referred by the outgoing semantic edges of the given sense,
     * then writes the updated edge weights back into the light WOL graph.
     */
    private void disambiguateSense(LightSense sense) {
        Set<LightSemanticEdge> outEdges = sense.getOutEdges();

        // group semantic edges by referred term
        Multimap<Integer, LightSemanticEdge> wordMap = LinkedHashMultimap.create();
        for (LightSemanticEdge lightSemanticEdge : outEdges) {
            wordMap.put(lightSemanticEdge.getTargetWordHash(), lightSemanticEdge);
        }

        // disambiguate each term
        for (int wordHash : wordMap.keySet()) {
            processedTerms++;
            if (disambiguateSenseTerm(sense, wordHash, wordMap.get(wordHash))) {
                convergedTerms++;
            }
        }

        // update weights in the light wol graph
        // NOTE(review): this assumes the graph's EdgeList for this sense is ordered by
        // edge id, matching semanticEdgeList after the sort — confirm against
        // LightWolGraph's construction.
        ArrayList<LightSemanticEdge> semanticEdgeList = new ArrayList<>(outEdges);
        Collections.sort(semanticEdgeList, new Comparator<LightSemanticEdge>() {
            @Override
            public int compare(LightSemanticEdge lightSemanticEdge, LightSemanticEdge lightSemanticEdge2) {
                return Integer.compare(lightSemanticEdge.getId(), lightSemanticEdge2.getId());
            }
        });

        EdgeList outEdgesOf = sense.getLightWolGraph().getOutEdgesOf(sense.getId());
        for (int i = 0; i < outEdges.size(); i++) {
            outEdgesOf.setWeight(i, semanticEdgeList.get(i).getWeight());
        }
    }

    /**
     * Disambiguates a single term of the given sense by searching for cycles and
     * quasi-cycles through the term's candidate target senses.
     *
     * @param sense    the source sense being disambiguated.
     * @param wordHash hash of the term under consideration (excluded from quasi-cycle visits).
     * @param edges    the semantic edges from {@code sense} to the term's candidate senses.
     * @return true if the term's probability distribution did not change (converged).
     */
    private boolean disambiguateSenseTerm(LightSense sense, int wordHash, Collection<LightSemanticEdge> edges) {

        Set<LightSense> targetSenses = mapEdgesToSenses(edges);

        // find cycles: paths from each candidate target sense back to the source sense
        HashSet<Path> cycles = new HashSet<>();
        for (LightSense targetSense : targetSenses) {
            if (!targetSense.equals(sense)) {
                HashSet<LightSense> sinks = Sets.newHashSet(sense);
                PathFinder cycleFinder = new WolPathFinder(targetSense, sinks, MAX_DEPTH, new HashSet<Integer>());
                cycleFinder.run();
                cycles.addAll(cycleFinder.getResult());
            }
        }

        // find quasi-cycles: paths from the source sense to any candidate target sense,
        // avoiding the term itself (wordHash is excluded from the visit)
        HashSet<Node> sinks = new HashSet<>();
        for (LightSense targetSense : targetSenses) {
            if (!targetSense.equals(sense)) {
                sinks.add(targetSense);
            }
        }
        PathFinder quasiCycleFinder = new WolPathFinder(sense, sinks, MAX_DEPTH, Sets.newHashSet(wordHash));
        quasiCycleFinder.run();
        Set<Path> quasiCycles = quasiCycleFinder.getResult();

        return updateProbabilities(edges, cycles, quasiCycles);
    }

    /** Collects the distinct destination senses of the given semantic edges. */
    private Set<LightSense> mapEdgesToSenses(Collection<LightSemanticEdge> edges) {
        HashSet<LightSense> targetSenses = new HashSet<>();
        for (LightSemanticEdge edge : edges) {
            targetSenses.add(edge.getDestination());
        }
        return targetSenses;
    }

    /*
     * takes the outgoing semanticEdges of a term, it computes cycles and quasi-cycles for each
     * semantic edge, updating its probability.
     *
     * Returns true when the term converged (weights unchanged within EPSILON). A term with
     * no supporting paths, or a zero normalization sum, is counted as impossible to
     * disambiguate AND reported as converged (its weights are left untouched).
     */
    private boolean updateProbabilities(Collection<LightSemanticEdge> edges, Set<Path> cycles,
            Set<Path> quasiCycles) {

        // group cycles by the first sense
        Multimap<Node, Path> cycleMap = ArrayListMultimap.create();
        for (Path cycle : cycles) {
            cycleMap.put(cycle.getPathOrigin(), cycle);
        }

        // group quasi-cycles by the last sense
        Multimap<Node, Path> quasiCycleMap = ArrayListMultimap.create();
        for (Path quasiCycle : quasiCycles) {
            quasiCycleMap.put(quasiCycle.getPathDestination(), quasiCycle);
        }

        // no evidence at all: leave the distribution as it is
        if (cycles.isEmpty() && quasiCycles.isEmpty()) {
            impossibleToDisambiguate++;
            return true;
        }

        double normalizationSum = 0;
        HashMap<LightSemanticEdge, Double> scoreMap = new HashMap<>();
        for (LightSemanticEdge semanticEdge : edges) {

            // compute scores for the cycles including this semantic edge
            Collection<Path> targetCycles = cycleMap.get(semanticEdge.getDestination());
            double totalCycleScore = 0;
            for (Path targetCycle : targetCycles) {
                totalCycleScore += cycleScorer.getScore(targetCycle);
            }

            // compute scores for the quasi-cycles including this semantic edge
            Collection<Path> targetQuasiCycles = quasiCycleMap.get(semanticEdge.getDestination());
            double totalQuasiCycleScore = 0;
            for (Path targetQuasiCycle : targetQuasiCycles) {
                totalQuasiCycleScore += quasiCycleScorer.getScore(targetQuasiCycle);
            }

            // compute the score for this semantic edge
            double totalEdgeScore = totalCycleScore + totalQuasiCycleScore;
            scoreMap.put(semanticEdge, totalEdgeScore);
            normalizationSum += totalEdgeScore;
        }

        // all paths scored zero: cannot normalize, leave the distribution as it is
        if (normalizationSum == 0) {
            impossibleToDisambiguate++;
            return true;
        }

        // normalize and update weights in the light semantic edges
        boolean convergence = true;
        for (LightSemanticEdge edge : edges) {
            double normScore = scoreMap.get(edge) / normalizationSum;
            convergence &= Math.abs(normScore - edge.getWeight()) < EPSILON;
            edge.setWeight(normScore);
        }
        return convergence;
    }

}