Java tutorial
/////////////////////////////////////////////////////////////////////////////// //Copyright (C) 2012 Assaf Urieli // //This file is part of Talismane. // //Talismane is free software: you can redistribute it and/or modify //it under the terms of the GNU Affero General Public License as published by //the Free Software Foundation, either version 3 of the License, or //(at your option) any later version. // //Talismane is distributed in the hope that it will be useful, //but WITHOUT ANY WARRANTY; without even the implied warranty of //MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //GNU Affero General Public License for more details. // //You should have received a copy of the GNU Affero General Public License //along with Talismane. If not, see <http://www.gnu.org/licenses/>. ////////////////////////////////////////////////////////////////////////////// package com.joliciel.talismane.parser; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import com.joliciel.talismane.filters.Sentence; import com.joliciel.talismane.machineLearning.Decision; import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy; import com.joliciel.talismane.machineLearning.RankingSolution; import com.joliciel.talismane.machineLearning.ScoringStrategy; import com.joliciel.talismane.machineLearning.Solution; import com.joliciel.talismane.machineLearning.features.Feature; import com.joliciel.talismane.machineLearning.features.FeatureResult; import com.joliciel.talismane.machineLearning.features.RuntimeEnvironment; import com.joliciel.talismane.posTagger.PosTag; import com.joliciel.talismane.posTagger.PosTagSequence; import com.joliciel.talismane.posTagger.PosTaggedToken; import com.joliciel.talismane.posTagger.PosTaggedTokenLeftToRightComparator; final class ParseConfigurationImpl implements ParseConfigurationInternal { private static final Log LOG = LogFactory.getLog(ParseConfigurationImpl.class); /** * */ private static final long serialVersionUID = 1L; private PosTagSequence posTagSequence; private double score; private double rankingScore; private boolean scoreCalculated = false; private boolean useGeometricMeanForProbs = true; private Deque<PosTaggedToken> buffer; private Deque<PosTaggedToken> stack; private List<Transition> transitions; private Set<DependencyArc> dependencies; private Map<PosTaggedToken, DependencyArc> governingDependencyMap = null; private Map<PosTaggedToken, List<PosTaggedToken>> leftDependentMap = null; private Map<PosTaggedToken, List<PosTaggedToken>> rightDependentMap = null; private Map<PosTaggedToken, List<PosTaggedToken>> dependentMap = null; private Map<PosTaggedToken, Map<String, List<PosTaggedToken>>> dependentByLabelMap = null; private Map<PosTaggedToken, Transition> dependentTransitionMap = new HashMap<PosTaggedToken, Transition>(); private ParserServiceInternal parserServiceInternal; private DependencyNode parseTree = null; private List<Decision<Transition>> decisions = new ArrayList<Decision<Transition>>(); private int lastProbApplied = 0; private List<Solution> underlyingSolutions = new ArrayList<Solution>(); @SuppressWarnings("rawtypes") private ScoringStrategy scoringStrategy; private List<List<FeatureResult<?>>> incrementalFeatureResults = new ArrayList<List<FeatureResult<?>>>(); private Map<String, FeatureResult<?>> featureCache = new HashMap<String, FeatureResult<?>>(); private long createDate = System.currentTimeMillis(); public ParseConfigurationImpl(PosTagSequence posTagSequence) { super(); this.posTagSequence = posTagSequence; PosTaggedToken rootToken = posTagSequence.prependRoot(); this.underlyingSolutions.add(this.posTagSequence); this.buffer = new ArrayDeque<PosTaggedToken>(posTagSequence.size()); for (PosTaggedToken posTaggedToken : posTagSequence) this.buffer.add(posTaggedToken); this.buffer.remove(rootToken); this.stack = new ArrayDeque<PosTaggedToken>(); this.stack.push(rootToken); this.dependencies = new TreeSet<DependencyArc>(); this.transitions = new ArrayList<Transition>(); this.scoringStrategy = new GeometricMeanScoringStrategy<Transition>(); } public ParseConfigurationImpl(ParseConfiguration history) { super(); ParseConfigurationInternal iHistory = (ParseConfigurationInternal) history; this.transitions = new ArrayList<Transition>(history.getTransitions()); this.dependencies = new TreeSet<DependencyArc>(iHistory.getDependenciesInternal()); this.posTagSequence = history.getPosTagSequence(); posTagSequence.prependRoot(); this.underlyingSolutions.add(this.posTagSequence); this.buffer = new ArrayDeque<PosTaggedToken>(history.getBuffer()); this.stack = new ArrayDeque<PosTaggedToken>(history.getStack()); this.dependentTransitionMap = new HashMap<PosTaggedToken, Transition>( ((ParseConfigurationInternal) history).getDependentTransitionMap()); this.decisions = new ArrayList<Decision<Transition>>(history.getDecisions()); this.lastProbApplied = (((ParseConfigurationInternal) history).getLastProbApplied()); this.scoringStrategy = history.getScoringStrategy(); } @Override public PosTagSequence getPosTagSequence() { return this.posTagSequence; } @SuppressWarnings("unchecked") @Override public double getScore() { if (!scoreCalculated) { score = this.scoringStrategy.calculateScore(this); scoreCalculated = true; } return score; } public double getRankingScore() { return rankingScore; } public void setRankingScore(double rankingScore) { this.rankingScore = rankingScore; } @Override public Deque<PosTaggedToken> getBuffer() { return this.buffer; } @Override public Deque<PosTaggedToken> getStack() { return this.stack; } @Override public int compareTo(ParseConfiguration o) { // order by descending score if possible, otherwise by create date, otherwise by hash code if (this == o) return 0; else if (this.getScore() < o.getScore()) return 1; else if (this.getScore() > o.getScore()) return -1; else if (o instanceof ParseConfigurationInternal) return new Long(this.getCreateDate() - ((ParseConfigurationInternal) o).getCreateDate()).intValue(); else return o.hashCode() - this.hashCode(); } @Override public boolean isTerminal() { return this.buffer.isEmpty(); } public List<Transition> getTransitions() { return transitions; } public Set<DependencyArc> getDependenciesInternal() { return dependencies; } public Set<DependencyArc> getDependencies() { return dependencies; } public Set<DependencyArc> getRealDependencies() { Set<DependencyArc> realDependencies = new TreeSet<DependencyArc>(); for (DependencyArc arc : dependencies) { if (arc.getHead().getTag().equals(PosTag.ROOT_POS_TAG) && (arc.getLabel() == null || arc.getLabel().length() == 0)) { // do nothing } else { realDependencies.add(arc); } } return realDependencies; } @Override public PosTaggedToken getHead(PosTaggedToken dependent) { this.updateDependencyMaps(); DependencyArc arc = this.governingDependencyMap.get(dependent); PosTaggedToken head = null; if (arc != null) head = arc.getHead(); return head; } @Override public DependencyArc getGoverningDependency(PosTaggedToken dependent) { this.updateDependencyMaps(); DependencyArc arc = this.governingDependencyMap.get(dependent); return arc; } public Transition getTransition(DependencyArc arc) { PosTaggedToken dependent = arc.getDependent(); Transition transition = this.dependentTransitionMap.get(dependent); return transition; } @Override public List<PosTaggedToken> getLeftDependents(PosTaggedToken head) { this.updateDependencyMaps(); List<PosTaggedToken> dependents = this.leftDependentMap.get(head); return dependents; } @Override public List<PosTaggedToken> getRightDependents(PosTaggedToken head) { this.updateDependencyMaps(); List<PosTaggedToken> dependents = this.rightDependentMap.get(head); return dependents; } @Override public List<PosTaggedToken> getDependents(PosTaggedToken head) { this.updateDependencyMaps(); List<PosTaggedToken> dependentList = this.dependentMap.get(head); return dependentList; } public List<PosTaggedToken> getDependents(PosTaggedToken head, String label) { this.updateDependencyMaps(); List<PosTaggedToken> deps = null; Map<String, List<PosTaggedToken>> labelMap = this.dependentByLabelMap.get(head); if (labelMap != null) { deps = labelMap.get(label); } if (deps == null) deps = new ArrayList<PosTaggedToken>(0); return deps; } void updateDependencyMaps() { if (this.governingDependencyMap == null) { this.governingDependencyMap = new HashMap<PosTaggedToken, DependencyArc>(); this.rightDependentMap = new HashMap<PosTaggedToken, List<PosTaggedToken>>(); this.leftDependentMap = new HashMap<PosTaggedToken, List<PosTaggedToken>>(); this.dependentMap = new HashMap<PosTaggedToken, List<PosTaggedToken>>(); this.dependentByLabelMap = new HashMap<PosTaggedToken, Map<String, List<PosTaggedToken>>>(); Map<PosTaggedToken, Set<PosTaggedToken>> leftDependentSetMap = new HashMap<PosTaggedToken, Set<PosTaggedToken>>(); Map<PosTaggedToken, Set<PosTaggedToken>> rightDependentSetMap = new HashMap<PosTaggedToken, Set<PosTaggedToken>>(); Map<PosTaggedToken, Map<String, Set<PosTaggedToken>>> dependentSetByLabelMap = new HashMap<PosTaggedToken, Map<String, Set<PosTaggedToken>>>(); for (DependencyArc arc : this.dependencies) { this.governingDependencyMap.put(arc.getDependent(), arc); Map<PosTaggedToken, Set<PosTaggedToken>> dependentMap = null; if (arc.getDependent().getToken().getIndex() < arc.getHead().getToken().getIndex()) dependentMap = leftDependentSetMap; else dependentMap = rightDependentSetMap; Set<PosTaggedToken> dependents = dependentMap.get(arc.getHead()); if (dependents == null) { dependents = new TreeSet<PosTaggedToken>(new PosTaggedTokenLeftToRightComparator()); dependentMap.put(arc.getHead(), dependents); } dependents.add(arc.getDependent()); Map<String, Set<PosTaggedToken>> labelMap = dependentSetByLabelMap.get(arc.getHead()); if (labelMap == null) { labelMap = new HashMap<String, Set<PosTaggedToken>>(); dependentSetByLabelMap.put(arc.getHead(), labelMap); } Set<PosTaggedToken> dependentsByLabel = labelMap.get(arc.getLabel()); if (dependentsByLabel == null) { dependentsByLabel = new TreeSet<PosTaggedToken>(new PosTaggedTokenLeftToRightComparator()); labelMap.put(arc.getLabel(), dependentsByLabel); } dependentsByLabel.add(arc.getDependent()); } for (PosTaggedToken head : leftDependentSetMap.keySet()) { List<PosTaggedToken> leftDeps = new ArrayList<PosTaggedToken>(leftDependentSetMap.get(head)); this.leftDependentMap.put(head, leftDeps); } for (PosTaggedToken head : rightDependentSetMap.keySet()) { List<PosTaggedToken> rightDeps = new ArrayList<PosTaggedToken>(rightDependentSetMap.get(head)); this.rightDependentMap.put(head, rightDeps); } for (PosTaggedToken head : this.getPosTagSequence()) { List<PosTaggedToken> leftDeps = this.leftDependentMap.get(head); if (leftDeps == null) { leftDeps = new ArrayList<PosTaggedToken>(0); this.leftDependentMap.put(head, leftDeps); } List<PosTaggedToken> rightDeps = this.rightDependentMap.get(head); if (rightDeps == null) { rightDeps = new ArrayList<PosTaggedToken>(0); this.rightDependentMap.put(head, rightDeps); } List<PosTaggedToken> allDeps = new ArrayList<PosTaggedToken>(leftDeps.size() + rightDeps.size()); allDeps.addAll(leftDeps); allDeps.addAll(rightDeps); this.dependentMap.put(head, allDeps); } for (PosTaggedToken head : dependentSetByLabelMap.keySet()) { Map<String, Set<PosTaggedToken>> depSetMap = dependentSetByLabelMap.get(head); Map<String, List<PosTaggedToken>> labelMap = new HashMap<String, List<PosTaggedToken>>( depSetMap.size()); this.dependentByLabelMap.put(head, labelMap); for (String label : depSetMap.keySet()) { List<PosTaggedToken> deps = new ArrayList<PosTaggedToken>(depSetMap.get(label)); labelMap.put(label, deps); } } } } @Override public DependencyArc addDependency(PosTaggedToken head, PosTaggedToken dependent, String label, Transition transition) { DependencyArc arc = this.parserServiceInternal.getDependencyArc(head, dependent, label); this.addDependency(arc); this.dependentTransitionMap.put(dependent, transition); // calculate probability based on decisions if (LOG.isTraceEnabled()) LOG.trace("Prob for " + arc.toString()); double probLog = 0.0; int numDecisions = 0; for (int i = lastProbApplied; i < this.decisions.size(); i++) { Decision<Transition> decision = decisions.get(i); probLog += decision.getProbabilityLog(); if (LOG.isTraceEnabled()) { LOG.trace(decision.getOutcome() + ", *= " + decision.getProbability()); } numDecisions++; } if (useGeometricMeanForProbs) { if (numDecisions > 0) probLog /= numDecisions; } arc.setProbability(Math.exp(probLog)); this.lastProbApplied = this.decisions.size(); if (LOG.isTraceEnabled()) LOG.trace("prob=" + arc.getProbability()); return arc; } void addDependency(DependencyArc arc) { PosTaggedToken ancestor = arc.getHead(); while (ancestor != null) { if (ancestor.equals(arc.getDependent())) { throw new CircularDependencyException(this, arc.getHead(), arc.getDependent()); } ancestor = this.getHead(ancestor); } this.dependencies.add(arc); // force update of dependency maps this.governingDependencyMap = null; } public ParserServiceInternal getParserServiceInternal() { return parserServiceInternal; } public void setParserServiceInternal(ParserServiceInternal parserServiceInternal) { this.parserServiceInternal = parserServiceInternal; } @Override public String toString() { StringBuilder sb = new StringBuilder(); Iterator<PosTaggedToken> stackIterator = this.stack.iterator(); if (stackIterator.hasNext()) sb.insert(0, stackIterator.next().toString()); if (stackIterator.hasNext()) sb.insert(0, stackIterator.next().toString() + ","); if (stackIterator.hasNext()) sb.insert(0, stackIterator.next().toString() + ","); if (stackIterator.hasNext()) sb.insert(0, "...,"); sb.insert(0, "Stack["); sb.append("]. Buffer["); Iterator<PosTaggedToken> bufferIterator = this.buffer.iterator(); if (bufferIterator.hasNext()) sb.append(bufferIterator.next().toString()); if (bufferIterator.hasNext()) sb.append("," + bufferIterator.next().toString()); if (bufferIterator.hasNext()) sb.append("," + bufferIterator.next().toString()); if (bufferIterator.hasNext()) sb.append(",..."); sb.append("]"); sb.append(" Deps["); for (DependencyArc arc : this.dependencies) { sb.append(arc.toString() + ","); } sb.append("]"); return sb.toString(); } public Map<PosTaggedToken, Transition> getDependentTransitionMap() { return dependentTransitionMap; } @Override public DependencyNode getParseTree() { if (parseTree == null) { PosTaggedToken root = null; for (PosTaggedToken token : this.posTagSequence) { if (token.getTag().equals(PosTag.ROOT_POS_TAG)) { root = token; break; } } parseTree = this.parserServiceInternal.getDependencyNode(root, "", this); parseTree.autoPopulate(); } return parseTree; } @Override public DependencyNode getDetachedDependencyNode(PosTaggedToken posTaggedToken) { DependencyArc arc = this.getGoverningDependency(posTaggedToken); DependencyNode node = this.parserServiceInternal.getDependencyNode(posTaggedToken, arc.getLabel(), this); return node; } @Override public List<Decision<Transition>> getDecisions() { return decisions; } @Override public List<Solution> getUnderlyingSolutions() { return underlyingSolutions; } @Override public void addDecision(Decision<Transition> decision) { this.decisions.add(decision); } @SuppressWarnings("rawtypes") public ScoringStrategy getScoringStrategy() { return scoringStrategy; } public void setScoringStrategy(@SuppressWarnings("rawtypes") ScoringStrategy scoringStrategy) { this.scoringStrategy = scoringStrategy; } @SuppressWarnings("unchecked") @Override public <T, Y> FeatureResult<Y> getResultFromCache(Feature<T, Y> feature, RuntimeEnvironment env) { FeatureResult<Y> result = null; String key = feature.getName() + env.getKey(); if (this.featureCache.containsKey(key)) { result = (FeatureResult<Y>) this.featureCache.get(key); } return result; } @Override public <T, Y> void putResultInCache(Feature<T, Y> feature, FeatureResult<Y> featureResult, RuntimeEnvironment env) { String key = feature.getName() + env.getKey(); this.featureCache.put(key, featureResult); } @Override public ParseConfiguration getParseConfiguration() { return this; } public void clearMemory() { this.governingDependencyMap = null; this.rightDependentMap = null; this.leftDependentMap = null; } @Override public Sentence getSentence() { return this.getPosTagSequence().getTokenSequence().getSentence(); } @Override public List<List<FeatureResult<?>>> getIncrementalFeatureResults() { return incrementalFeatureResults; } @Override public boolean canReach(RankingSolution correctSolution) { if (correctSolution instanceof ParseConfiguration) { ParseConfiguration configuration = (ParseConfiguration) correctSolution; if (configuration.getTransitions().size() < this.getTransitions().size()) { return false; } for (int i = 0; i < this.getTransitions().size(); i++) { Transition myTransition = this.getTransitions().get(i); Transition hisTransition = configuration.getTransitions().get(i); if (!myTransition.getCode().equals(hisTransition.getCode())) { return false; } } return true; } else { return false; } } public long getCreateDate() { return createDate; } @Override public List<String> getIncrementalOutcomes() { List<String> outcomes = new ArrayList<String>(); for (Transition transition : this.transitions) { outcomes.add(transition.getCode()); } return outcomes; } /** * True: use a geometric mean when calculating individual arc probabilities * (which multiply the probabilities for the transitions since the last arc was added). * False: use the simple product. Default is true. * @return */ public boolean isUseGeometricMeanForProbs() { return useGeometricMeanForProbs; } public void setUseGeometricMeanForProbs(boolean useGeometricMeanForProbs) { this.useGeometricMeanForProbs = useGeometricMeanForProbs; } public int getLastProbApplied() { return lastProbApplied; } }