com.joliciel.talismane.parser.TransitionBasedParserImpl.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.parser.TransitionBasedParserImpl.java

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2012 Assaf Urieli
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.parser;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeMap;
import java.util.Map.Entry;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.joliciel.talismane.TalismaneSession;
import com.joliciel.talismane.machineLearning.ClassificationObserver;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.features.FeatureResult;
import com.joliciel.talismane.machineLearning.features.FeatureService;
import com.joliciel.talismane.machineLearning.features.RuntimeEnvironment;
import com.joliciel.talismane.parser.features.ParseConfigurationFeature;
import com.joliciel.talismane.parser.features.ParserRule;
import com.joliciel.talismane.posTagger.PosTagSequence;
import com.joliciel.talismane.posTagger.PosTaggedToken;
import com.joliciel.talismane.tokeniser.TokenSequence;
import com.joliciel.talismane.utils.PerformanceMonitor;

/**
 * A non-deterministic parser implementing transition based parsing,
 * using a Shift-Reduce algorithm.<br/>
 * See Nivre 2008 for details on the algorithm for the deterministic case.<br/>
 * @author Assaf Urieli
 *
 */
class TransitionBasedParserImpl implements TransitionBasedParser {
    // General-purpose logger for the parsing loop.
    private static final Log LOG = LogFactory.getLog(TransitionBasedParserImpl.class);
    // Separate logger (class name + ".features") used only to trace individual feature results,
    // so that feature tracing can be switched on independently of the main logger.
    private static final Log LOG_FEATURES = LogFactory
            .getLog(TransitionBasedParserImpl.class.getName() + ".features");
    // Performance monitor used to time the named sub-tasks of parsing.
    private static final PerformanceMonitor MONITOR = PerformanceMonitor
            .getMonitor(TransitionBasedParserImpl.class);
    // Decisions whose probability is not strictly greater than this threshold are discarded.
    private static final double MIN_PROB_TO_STORE = 0.0001;
    // Formats scores and probabilities for log output.
    // NOTE(review): DecimalFormat is not synchronized and this instance is static - confirm that
    // parsers are never run concurrently with logging enabled.
    private static final DecimalFormat df = new DecimalFormat("0.0000");
    // Maximum number of configurations expanded per heap, and maximum number of results returned.
    private int beamWidth = 1;
    // When true, parsing stops as soon as the terminal heap holds at least beamWidth configurations.
    private boolean earlyStop = false;

    // Features evaluated on each configuration when no positive rule applies.
    private Set<ParseConfigurationFeature<?>> parseFeatures;

    // Creates the initial configuration for a pos-tag sequence and copies of existing configurations.
    private ParserServiceInternal parserServiceInternal;
    // Supplies the runtime environment passed to every rule and feature check.
    private FeatureService featureService;
    // Converts feature results into a list of candidate transitions, and provides the scoring strategy.
    private DecisionMaker<Transition> decisionMaker;
    // Stored and exposed through a getter/setter; not otherwise read by the code in this class.
    private TransitionSystem transitionSystem;
    // Provides the comparison index which determines the heap a new configuration is placed on.
    private ParseComparisonStrategy parseComparisonStrategy = new BufferSizeComparisonStrategy();

    // Observers notified after each call to the decision maker.
    private List<ClassificationObserver<Transition>> observers = new ArrayList<ClassificationObserver<Transition>>();
    // Maximum analysis time per sentence, in seconds; a value of 0 or less disables the check.
    private int maxAnalysisTimePerSentence = 60;
    // Minimum free memory required to continue parsing, in kilobytes; 0 or less disables the check.
    private int minFreeMemory = 64;
    // Multiplier used to convert minFreeMemory from kilobytes to bytes.
    private static final int KILOBYTE = 1024;

    // All rules as supplied to setParserRules, and the same rules split by polarity.
    // The two split lists remain null until setParserRules has been called.
    private List<ParserRule> parserRules;
    private List<ParserRule> parserPositiveRules;
    private List<ParserRule> parserNegativeRules;

    /**
     * Builds a parser around the supplied collaborators.
     *
     * @param decisionMaker    turns feature results into candidate transitions
     * @param transitionSystem the transition system stored on this parser
     * @param parseFeatures    the features evaluated on each parse configuration
     * @param beamWidth        the maximum number of configurations expanded per heap
     *                         and returned as results
     */
    public TransitionBasedParserImpl(DecisionMaker<Transition> decisionMaker, TransitionSystem transitionSystem,
            Set<ParseConfigurationFeature<?>> parseFeatures, int beamWidth) {
        // The implicit call to the superclass constructor is all that is needed here.
        this.beamWidth = beamWidth;
        this.parseFeatures = parseFeatures;
        this.transitionSystem = transitionSystem;
        this.decisionMaker = decisionMaker;
    }

    /**
     * Parses a sentence for which a single pos-tag sequence is available.
     * Delegates to the list-based overload and returns the first configuration it yields.
     *
     * @param posTagSequence the pos-tagged sentence to parse
     * @return the first configuration returned by the list-based overload
     */
    @Override
    public ParseConfiguration parseSentence(PosTagSequence posTagSequence) {
        List<PosTagSequence> singleSequence = new ArrayList<PosTagSequence>();
        singleSequence.add(posTagSequence);
        return this.parseSentence(singleSequence).get(0);
    }

    /**
     * Parses a sentence for which one or more alternative pos-tag sequences are available,
     * using a beam search over parse configurations.
     * <p>
     * Configurations are kept on heaps keyed by a comparison index; the heap with the lowest
     * index is always processed next, and at most {@code beamWidth} configurations per heap
     * are expanded. The search stops when the next heap's top configuration is terminal, when
     * early stop is enabled and enough terminal configurations exist, when the maximum analysis
     * time is exceeded, when free memory falls below the configured minimum, or when no heaps
     * remain.
     * </p>
     *
     * @param posTagSequences the alternative pos-tag sequences for the sentence; the first
     *                        element's token sequence is used for log messages, so an empty
     *                        list causes {@code posTagSequences.get(0)} to fail
     * @return at most {@code beamWidth} configurations, in the order the final heap yields them
     */
    @Override
    public List<ParseConfiguration> parseSentence(List<PosTagSequence> posTagSequences) {
        MONITOR.startTask("parseSentence");
        try {
            long startTime = System.currentTimeMillis();
            // Computed as long: the equivalent int products silently overflow for large settings
            // (e.g. a minimum free memory of 2 GB expressed in kilobytes becomes negative).
            long maxAnalysisTimeMilliseconds = maxAnalysisTimePerSentence * 1000L;
            long minFreeMemoryBytes = (long) minFreeMemory * KILOBYTE;

            TokenSequence tokenSequence = posTagSequences.get(0).getTokenSequence();

            // heaps are keyed by comparison index and always processed lowest index first
            TreeMap<Integer, PriorityQueue<ParseConfiguration>> heaps = new TreeMap<Integer, PriorityQueue<ParseConfiguration>>();

            PriorityQueue<ParseConfiguration> heap0 = new PriorityQueue<ParseConfiguration>();
            for (PosTagSequence posTagSequence : posTagSequences) {
                // add an initial ParseConfiguration for each postag sequence
                ParseConfiguration initialConfiguration = this.getParserServiceInternal()
                        .getInitialConfiguration(posTagSequence);
                initialConfiguration.setScoringStrategy(decisionMaker.getDefaultScoringStrategy());
                heap0.add(initialConfiguration);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Adding initial posTagSequence: " + posTagSequence);
                }
            }
            heaps.put(0, heap0);
            PriorityQueue<ParseConfiguration> backupHeap = null;

            PriorityQueue<ParseConfiguration> finalHeap = null;
            PriorityQueue<ParseConfiguration> terminalHeap = new PriorityQueue<ParseConfiguration>();
            while (!heaps.isEmpty()) {
                Entry<Integer, PriorityQueue<ParseConfiguration>> heapEntry = heaps.pollFirstEntry();
                PriorityQueue<ParseConfiguration> currentHeap = heapEntry.getValue();
                int currentHeapIndex = heapEntry.getKey();
                if (LOG.isTraceEnabled()) {
                    LOG.trace("##### Polling next heap: " + heapEntry.getKey() + ", size: "
                            + heapEntry.getValue().size());
                }

                boolean finished = false;
                // systematically set the final heap here, just in case we exit "naturally" with no more heaps
                finalHeap = currentHeap;
                backupHeap = new PriorityQueue<ParseConfiguration>();

                // we jump out when either (a) all tokens have been attached or (b) we go over the max allotted time
                ParseConfiguration topConf = currentHeap.peek();
                if (topConf.isTerminal()) {
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Exiting with terminal heap: " + heapEntry.getKey() + ", size: "
                                + heapEntry.getValue().size());
                    }
                    finished = true;
                }

                if (earlyStop && terminalHeap.size() >= beamWidth) {
                    if (LOG.isDebugEnabled()) {
                        LOG.debug(
                                "Early stop activated and terminal heap contains " + beamWidth + " entries. Exiting.");
                    }
                    finalHeap = terminalHeap;
                    finished = true;
                }

                long analysisTime = System.currentTimeMillis() - startTime;
                if (maxAnalysisTimePerSentence > 0 && analysisTime > maxAnalysisTimeMilliseconds) {
                    LOG.info("Parse tree analysis took too long for sentence: " + tokenSequence.getText());
                    LOG.info("Breaking out after " + maxAnalysisTimePerSentence + " seconds.");
                    finished = true;
                }

                if (minFreeMemory > 0) {
                    long freeMemory = Runtime.getRuntime().freeMemory();
                    if (freeMemory < minFreeMemoryBytes) {
                        LOG.info("Not enough memory left to parse sentence: " + tokenSequence.getText());
                        LOG.info("Min free memory (bytes):" + minFreeMemoryBytes);
                        LOG.info("Current free memory (bytes): " + freeMemory);
                        finished = true;
                    }
                }

                if (finished) {
                    break;
                }

                // limit the breadth to K
                int maxSequences = Math.min(currentHeap.size(), this.beamWidth);

                int j = 0;
                while (!currentHeap.isEmpty()) {
                    ParseConfiguration history = currentHeap.poll();
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("### Next configuration on heap " + heapEntry.getKey() + ":");
                        LOG.trace(history.toString());
                        LOG.trace("Score: " + df.format(history.getScore()));
                        LOG.trace(history.getPosTagSequence());
                    }

                    // a positive rule, if one applies, short-circuits the statistical decision
                    List<Decision<Transition>> decisions;
                    Decision<Transition> positiveRuleDecision = this.findPositiveRuleDecision(history);
                    if (positiveRuleDecision != null) {
                        decisions = new ArrayList<Decision<Transition>>();
                        decisions.add(positiveRuleDecision);
                    } else {
                        List<FeatureResult<?>> parseFeatureResults = this.analyseParseFeatures(history);
                        decisions = this.decideTransitions(history, parseFeatureResults);
                        decisions = this.filterByNegativeRules(history, decisions);
                    }

                    boolean transitionApplied = false;
                    // add new configuration to the heap, one for each valid transition
                    MONITOR.startTask("heap sort");
                    try {
                        // Why apply all decisions here? Why not just the top N (where N = beamwidth)?
                        // Answer: because we're not always adding solutions to the same heap
                        // And yet: a decision here can only do one of two things: process a token (heap+1000), or add a non-processing transition (heap+1)
                        // So, if we've already applied N decisions of each type, we should be able to stop
                        for (Decision<Transition> decision : decisions) {
                            Transition transition = decision.getOutcome();
                            if (LOG.isTraceEnabled())
                                LOG.trace("Outcome: " + transition.getCode() + ", " + decision.getProbability());

                            if (transition.checkPreconditions(history)) {
                                transitionApplied = true;
                                ParseConfiguration configuration = this.parserServiceInternal
                                        .getConfiguration(history);
                                if (decision.isStatistical())
                                    configuration.addDecision(decision);
                                transition.apply(configuration);

                                int nextHeapIndex = parseComparisonStrategy.getComparisonIndex(configuration)
                                        * 1000;
                                if (configuration.isTerminal()) {
                                    nextHeapIndex = Integer.MAX_VALUE;
                                } else {
                                    // always move strictly forward; the heap currently being processed
                                    // is never the terminal one, so currentHeapIndex + 1 cannot overflow
                                    nextHeapIndex = Math.max(nextHeapIndex, currentHeapIndex + 1);
                                }

                                PriorityQueue<ParseConfiguration> nextHeap = heaps.get(nextHeapIndex);
                                if (nextHeap == null) {
                                    if (configuration.isTerminal())
                                        nextHeap = terminalHeap;
                                    else
                                        nextHeap = new PriorityQueue<ParseConfiguration>();
                                    heaps.put(nextHeapIndex, nextHeap);
                                    if (LOG.isTraceEnabled())
                                        LOG.trace("Created heap with index: " + nextHeapIndex);
                                }
                                nextHeap.add(configuration);
                                if (LOG.isTraceEnabled()) {
                                    LOG.trace("Added configuration with score " + configuration.getScore()
                                            + " to heap: " + nextHeapIndex + ", total size: " + nextHeap.size());
                                }

                                configuration.clearMemory();
                            } else {
                                if (LOG.isTraceEnabled())
                                    LOG.trace("Cannot apply transition: doesn't meet pre-conditions");
                            } // does transition meet pre-conditions?
                        } // next transition
                    } finally {
                        MONITOR.endTask("heap sort");
                    }

                    if (transitionApplied) {
                        j++;
                    } else {
                        LOG.trace("No transitions could be applied: not counting this history as part of the beam");
                        // Just in case we run out of both heaps and analyses, we build this backup heap.
                        // The history is added exactly once: adding it once per rejected decision
                        // would return the same configuration several times, and a history without
                        // any decision at all would never be backed up.
                        backupHeap.add(history);
                    }

                    // beam width test
                    if (j == maxSequences)
                        break;
                } // next history
            } // next atomic index

            // return the best sequences on the heap
            List<ParseConfiguration> bestConfigurations = new ArrayList<ParseConfiguration>();
            int i = 0;

            if (finalHeap.isEmpty())
                finalHeap = backupHeap;

            while (!finalHeap.isEmpty()) {
                bestConfigurations.add(finalHeap.poll());
                i++;
                if (i >= this.getBeamWidth())
                    break;
            }
            if (LOG.isDebugEnabled()) {
                this.logBestConfigurations(bestConfigurations);
            }
            return bestConfigurations;
        } finally {
            MONITOR.endTask("parseSentence");
        }
    }

    /**
     * Tests the positive rules, in order, against a configuration.
     *
     * @return the default decision for the first rule whose condition holds, or
     *         <code>null</code> if no positive rules are set or none applies
     */
    private Decision<Transition> findPositiveRuleDecision(ParseConfiguration history) {
        if (parserPositiveRules == null) {
            return null;
        }
        MONITOR.startTask("check rules");
        try {
            for (ParserRule rule : parserPositiveRules) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Checking rule: " + rule.toString());
                }
                RuntimeEnvironment env = this.featureService.getRuntimeEnvironment();
                FeatureResult<Boolean> ruleResult = rule.getCondition().check(history, env);
                if (ruleResult != null && ruleResult.getOutcome()) {
                    Decision<Transition> positiveRuleDecision = TalismaneSession.getTransitionSystem()
                            .createDefaultDecision(rule.getTransition());
                    positiveRuleDecision.addAuthority(rule.getCondition().getName());
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Rule applies. Setting transition to: " + rule.getTransition().getCode());
                    }
                    return positiveRuleDecision;
                }
            }
            return null;
        } finally {
            MONITOR.endTask("check rules");
        }
    }

    /**
     * Evaluates every parse feature against a configuration.
     *
     * @return the non-null feature results, in feature iteration order
     */
    private List<FeatureResult<?>> analyseParseFeatures(ParseConfiguration history) {
        List<FeatureResult<?>> parseFeatureResults = new ArrayList<FeatureResult<?>>();
        MONITOR.startTask("feature analyse");
        try {
            for (ParseConfigurationFeature<?> feature : this.parseFeatures) {
                MONITOR.startTask(feature.getName());
                try {
                    RuntimeEnvironment env = this.featureService.getRuntimeEnvironment();
                    FeatureResult<?> featureResult = feature.check(history, env);
                    if (featureResult != null)
                        parseFeatureResults.add(featureResult);
                } finally {
                    MONITOR.endTask(feature.getName());
                }
            }
            if (LOG_FEATURES.isTraceEnabled()) {
                for (FeatureResult<?> featureResult : parseFeatureResults) {
                    LOG_FEATURES.trace(featureResult.toString());
                }
            }
        } finally {
            MONITOR.endTask("feature analyse");
        }
        return parseFeatureResults;
    }

    /**
     * Asks the decision maker for candidate transitions, notifies the observers with the
     * full list of decisions, and then drops the decisions that are not probable enough.
     *
     * @return the decisions whose probability is strictly greater than the storage threshold
     */
    private List<Decision<Transition>> decideTransitions(ParseConfiguration history,
            List<FeatureResult<?>> parseFeatureResults) {
        MONITOR.startTask("make decision");
        try {
            List<Decision<Transition>> decisions = this.decisionMaker.decide(parseFeatureResults);

            for (ClassificationObserver<Transition> observer : this.observers) {
                observer.onAnalyse(history, parseFeatureResults, decisions);
            }

            List<Decision<Transition>> decisionShortList = new ArrayList<Decision<Transition>>(decisions.size());
            for (Decision<Transition> decision : decisions) {
                if (decision.getProbability() > MIN_PROB_TO_STORE)
                    decisionShortList.add(decision);
            }
            return decisionShortList;
        } finally {
            MONITOR.endTask("make decision");
        }
    }

    /**
     * Removes the decisions whose transition is eliminated by a negative rule that applies
     * to the configuration.
     *
     * @return the filtered decisions, or the decisions passed in if no negative rules are set,
     *         no rule applies, or filtering would eliminate every decision
     */
    private List<Decision<Transition>> filterByNegativeRules(ParseConfiguration history,
            List<Decision<Transition>> decisions) {
        if (parserNegativeRules == null) {
            return decisions;
        }
        Set<Transition> eliminatedTransitions = new HashSet<Transition>();
        MONITOR.startTask("check negative rules");
        try {
            for (ParserRule rule : parserNegativeRules) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Checking negative rule: " + rule.toString());
                }
                RuntimeEnvironment env = this.featureService.getRuntimeEnvironment();
                FeatureResult<Boolean> ruleResult = rule.getCondition().check(history, env);
                if (ruleResult != null && ruleResult.getOutcome()) {
                    eliminatedTransitions.addAll(rule.getTransitions());
                    if (LOG.isTraceEnabled()) {
                        for (Transition eliminatedTransition : rule.getTransitions())
                            LOG.trace("Rule applies. Eliminating transition: " + eliminatedTransition.getCode());
                    }
                }
            }

            if (eliminatedTransitions.isEmpty()) {
                return decisions;
            }

            List<Decision<Transition>> decisionShortList = new ArrayList<Decision<Transition>>();
            for (Decision<Transition> decision : decisions) {
                if (!eliminatedTransitions.contains(decision.getOutcome())) {
                    decisionShortList.add(decision);
                } else if (LOG.isTraceEnabled()) {
                    LOG.trace("Eliminating decision: " + decision.toString());
                }
            }
            if (decisionShortList.isEmpty()) {
                LOG.debug("All decisions eliminated! Restoring original decisions.");
                return decisions;
            }
            return decisionShortList;
        } finally {
            MONITOR.endTask("check negative rules");
        }
    }

    /**
     * Writes the retained configurations to the log: a summary at debug level, and a
     * breakdown of the score components at trace level.
     */
    private void logBestConfigurations(List<ParseConfiguration> bestConfigurations) {
        for (ParseConfiguration finalConfiguration : bestConfigurations) {
            LOG.debug(df.format(finalConfiguration.getScore()) + ": " + finalConfiguration.toString());
            LOG.debug("Pos tag sequence: " + finalConfiguration.getPosTagSequence());
            LOG.debug("Transitions: " + finalConfiguration.getTransitions());
            LOG.debug("Decisions: " + finalConfiguration.getDecisions());
            if (LOG.isTraceEnabled()) {
                StringBuilder sb = new StringBuilder();
                for (Decision<Transition> decision : finalConfiguration.getDecisions()) {
                    sb.append(" * ");
                    sb.append(df.format(decision.getProbability()));
                }
                sb.append(" root ");
                sb.append(finalConfiguration.getTransitions().size());
                LOG.trace(sb.toString());

                sb = new StringBuilder();
                sb.append(" * PosTag sequence score ");
                sb.append(df.format(finalConfiguration.getPosTagSequence().getScore()));
                sb.append(" = ");
                for (PosTaggedToken posTaggedToken : finalConfiguration.getPosTagSequence()) {
                    sb.append(" * ");
                    sb.append(df.format(posTaggedToken.getDecision().getProbability()));
                }
                sb.append(" root ");
                sb.append(finalConfiguration.getPosTagSequence().size());
                LOG.trace(sb.toString());

                sb = new StringBuilder();
                sb.append(" * Token sequence score = ");
                sb.append(df.format(finalConfiguration.getPosTagSequence().getTokenSequence().getScore()));
                LOG.trace(sb.toString());
            }
        }
    }

    /**
     * @return the maximum number of configurations expanded per heap and returned as results
     */
    @Override
    public int getBeamWidth() {
        return this.beamWidth;
    }

    /**
     * @return the internal parser service used to create and copy parse configurations
     */
    public ParserServiceInternal getParserServiceInternal() {
        return this.parserServiceInternal;
    }

    /**
     * @param parserServiceInternal the internal parser service used to create and copy
     *                              parse configurations
     */
    public void setParserServiceInternal(final ParserServiceInternal parserServiceInternal) {
        this.parserServiceInternal = parserServiceInternal;
    }

    /**
     * Registers an observer to be notified each time the decision maker has been
     * consulted for a configuration.
     *
     * @param observer the observer to append to the list of observers
     */
    @Override
    public void addObserver(ClassificationObserver<Transition> observer) {
        observers.add(observer);
    }

    /**
     * @return the transition system stored on this parser
     */
    public TransitionSystem getTransitionSystem() {
        return this.transitionSystem;
    }

    /**
     * @param transitionSystem the transition system to store on this parser
     */
    public void setTransitionSystem(final TransitionSystem transitionSystem) {
        this.transitionSystem = transitionSystem;
    }

    /**
     * @return the maximum analysis time per sentence, in seconds; a value of 0 or less
     *         means no time limit is enforced
     */
    public int getMaxAnalysisTimePerSentence() {
        return this.maxAnalysisTimePerSentence;
    }

    /**
     * @param maxAnalysisTimePerSentence the maximum analysis time per sentence, in seconds;
     *                                   a value of 0 or less disables the time limit
     */
    public void setMaxAnalysisTimePerSentence(final int maxAnalysisTimePerSentence) {
        this.maxAnalysisTimePerSentence = maxAnalysisTimePerSentence;
    }

    /**
     * @return the minimum free memory required to continue parsing, in kilobytes;
     *         a value of 0 or less means the memory check is skipped
     */
    public int getMinFreeMemory() {
        return this.minFreeMemory;
    }

    /**
     * @param minFreeMemory the minimum free memory required to continue parsing, in kilobytes;
     *                      a value of 0 or less disables the memory check
     */
    public void setMinFreeMemory(final int minFreeMemory) {
        this.minFreeMemory = minFreeMemory;
    }

    /**
     * Stores the parser rules and splits them into negative and positive rules,
     * preserving their relative order within each group.
     *
     * @param parserRules the rules to apply during parsing
     */
    @Override
    public void setParserRules(List<ParserRule> parserRules) {
        this.parserRules = parserRules;
        this.parserPositiveRules = new ArrayList<ParserRule>();
        this.parserNegativeRules = new ArrayList<ParserRule>();
        for (ParserRule rule : parserRules) {
            // route each rule to the list matching its polarity
            List<ParserRule> target = rule.isNegative() ? this.parserNegativeRules : this.parserPositiveRules;
            target.add(rule);
        }
    }

    /**
     * @return the rules exactly as last passed to {@code setParserRules}, or
     *         <code>null</code> if none were ever set
     */
    public List<ParserRule> getParserRules() {
        return this.parserRules;
    }

    /**
     * @return the feature service supplying the runtime environment for rule and feature checks
     */
    public FeatureService getFeatureService() {
        return this.featureService;
    }

    /**
     * @param featureService the feature service supplying the runtime environment for
     *                       rule and feature checks
     */
    public void setFeatureService(final FeatureService featureService) {
        this.featureService = featureService;
    }

    /**
     * @return the strategy providing the comparison index that selects the heap
     *         on which a new configuration is placed
     */
    public ParseComparisonStrategy getParseComparisonStrategy() {
        return this.parseComparisonStrategy;
    }

    /**
     * @param parseComparisonStrategy the strategy providing the comparison index that selects
     *                                the heap on which a new configuration is placed
     */
    public void setParseComparisonStrategy(final ParseComparisonStrategy parseComparisonStrategy) {
        this.parseComparisonStrategy = parseComparisonStrategy;
    }

    /**
     * @return <code>true</code> if parsing stops as soon as the terminal heap holds at least
     *         as many configurations as the beam width
     */
    public boolean isEarlyStop() {
        return this.earlyStop;
    }

    /**
     * @param earlyStop <code>true</code> to stop parsing as soon as the terminal heap holds
     *                  at least as many configurations as the beam width
     */
    public void setEarlyStop(final boolean earlyStop) {
        this.earlyStop = earlyStop;
    }

    /**
     * @return the features evaluated on each configuration when no positive rule applies
     */
    @Override
    public Set<ParseConfigurationFeature<?>> getParseFeatures() {
        return this.parseFeatures;
    }

    /**
     * @param parseFeatures the features to evaluate on each configuration when no
     *                      positive rule applies
     */
    @Override
    public void setParseFeatures(final Set<ParseConfigurationFeature<?>> parseFeatures) {
        this.parseFeatures = parseFeatures;
    }

}