Java tutorial
/** * Copyright (c) 2014, the LESK-WSD-DSM AUTHORS. * * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * Neither the name of the University of Bari nor the names of its contributors * may be used to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. * * GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 * */ package di.uniba.it.wsd; import di.uniba.it.wsd.data.ExecuteStatistics; import di.uniba.it.wsd.data.POSenum; import di.uniba.it.wsd.data.RelatedSynset; import di.uniba.it.wsd.data.SynsetOut; import di.uniba.it.wsd.data.Token; import di.uniba.it.wsd.dsm.ObjectVector; import di.uniba.it.wsd.dsm.VectorStore; import di.uniba.it.wsd.dsm.VectorUtils; import edu.mit.jwi.item.IPointer; import edu.mit.jwi.item.POS; import it.uniroma1.lcl.babelnet.BabelGloss; import it.uniroma1.lcl.babelnet.BabelNet; import it.uniroma1.lcl.babelnet.BabelSense; import it.uniroma1.lcl.babelnet.BabelSenseSource; import it.uniroma1.lcl.babelnet.BabelSynset; import it.uniroma1.lcl.jlt.util.Language; import java.io.IOException; import java.io.StringReader; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import org.apache.commons.lang.StringUtils; import org.apache.commons.math.stat.StatUtils; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.analysis.tokenattributes.TermAttribute; import org.apache.lucene.util.Version; import org.tartarus.snowball.SnowballStemmer; import org.tartarus.snowball.ext.frenchStemmer; import org.tartarus.snowball.ext.germanStemmer; import org.tartarus.snowball.ext.italianStemmer; import org.tartarus.snowball.ext.porterStemmer; import org.tartarus.snowball.ext.spanishStemmer; /** * This class implements the Word Sense Disambiguation algorithm * * @author Pierpaolo Basile pierpaolo.basile@gmail.com */ public class RevisedLesk { /** * Constant for WordNet output format */ public static final int OUT_WORDNET = 1000; /** * Constant for BabelNet output format */ public static final int OUT_BABELNET = 1100; /** * Constant for Synset Distribution function */ public static final int SD_PROB = 2000; /** * Constant for Synset Distribution function */ public static final int SD_PROB_CROSS = 2100; /** * Constant for Synset Distribution function */ public static final int SD_OCC = 2200; /** * Constant for Wiki scoring function */ public static final int WIKI_LEV = 3000; /** * Constant for Wiki scoring function */ public static final int WIKI_UNI = 3100; private int outType = OUT_BABELNET; private int sdType = SD_PROB; private int wikiType = WIKI_LEV; private double weightWsd = 0.5; private double weightSd = 0.5; private BabelNet babelNet; private int contextSize = 5; private Language language; private boolean stemming = false; private VectorStore dsm; private int maxDepth = 1; private SenseFreqAPI senseFreq; private boolean scoreGloss = true; private ExecuteStatistics execStats = new ExecuteStatistics(); private static final Logger logger = Logger.getLogger(RevisedLesk.class.getName()); /** * * @param language */ public RevisedLesk(Language language) { this.language = language; } /** * * @param language * @param dsm */ public RevisedLesk(Language language, VectorStore dsm) { this.language = language; this.dsm = dsm; } /** * * @return */ public int getSdType() { return sdType; } /** * * @param sdType */ public void setSdType(int sdType) { this.sdType = sdType; } /** * * @return */ public int getWikiType() { return wikiType; } /** * * @param wikiType */ public void setWikiType(int wikiType) { this.wikiType = wikiType; } /** * * @return */ public double getWeightWsd() { return weightWsd; } /** * * @param weightWsd */ public void setWeightWsd(double weightWsd) { this.weightWsd = weightWsd; } /** * * @return */ public double getWeightSd() { return weightSd; } /** * * @param weightSd */ public void setWeightSd(double weightSd) { this.weightSd = weightSd; } /** * * @return */ public SenseFreqAPI getSenseFreq() { return senseFreq; } /** * * @param senseFreq */ public void setSenseFreq(SenseFreqAPI senseFreq) { this.senseFreq = senseFreq; } /** * * @return */ public Language getLanguage() { return language; } /** * * @param language */ public void setLanguage(Language language) { this.language = language; } /** * * @return */ public int getContextSize() { return contextSize; } /** * * @param contextSize */ public void setContextSize(int contextSize) { this.contextSize = contextSize; } /** * * @return */ public VectorStore getDsm() { return dsm; } /** * */ public void init() { babelNet = BabelNet.getInstance(); logger.log(Level.INFO, "Language: {0}", this.language); logger.log(Level.INFO, "Context size: {0}", this.contextSize); logger.log(Level.INFO, "Depth: {0}", this.maxDepth); logger.log(Level.INFO, "Stemmig: {0}", this.stemming); logger.log(Level.INFO, "Score gloss: {0}", this.scoreGloss); logger.log(Level.INFO, "Weigths, wsd: {0}, synset distr. {1}", new Object[] { weightWsd, weightSd }); if (this.senseFreq != null) { logger.info("Sense distribution ENABLED"); if (sdType == SD_OCC) { logger.info("Sense distribution type=occurrences"); } else if (sdType == SD_PROB) { logger.info("Sense distribution type=probability"); } else if (sdType == SD_PROB_CROSS) { logger.info("Sense distribution type=cross probability"); } if (wikiType == WIKI_LEV) { logger.info("Wiki synset scoring: LEV"); } else if (wikiType == WIKI_UNI) { logger.info("Wiki synset scoring: Uniform"); } } else { logger.info("Sense distribution DISABLED"); } if (this.dsm != null) { logger.info("Distributional Semantic Model ENABLED"); } else { logger.info("Distributional Semantic Model DISABLED"); } } /** * */ public void close() { } private SnowballStemmer getStemmer(Language language) { if (language.equals(Language.EN)) { return new porterStemmer(); } else if (language.equals(Language.ES)) { return new spanishStemmer(); } else if (language.equals(Language.FR)) { return new frenchStemmer(); } else if (language.equals(Language.DE)) { return new germanStemmer(); } else if (language.equals(Language.IT)) { return new italianStemmer(); } else { return null; } } /** * * @param text * @return * @throws IOException */ public Map<String, Float> buildBag(String text) throws IOException { Map<String, Float> bag = new HashMap<>(); Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_CURRENT); SnowballStemmer stemmer = null; if (stemming) { stemmer = getStemmer(language); if (stemmer == null) { Logger.getLogger(RevisedLesk.class.getName()).log(Level.WARNING, "No stemmer for language {0}", language); } } TokenStream tokenStream = analyzer.tokenStream("gloss", new StringReader(text)); while (tokenStream.incrementToken()) { TermAttribute token = (TermAttribute) tokenStream.getAttribute(TermAttribute.class); String term = token.term(); if (stemmer != null) { stemmer.setCurrent(term); if (stemmer.stem()) { term = stemmer.getCurrent(); } } Float c = bag.get(term); if (c == null) { bag.put(term, 1f); } else { bag.put(term, c + 1f); } } return bag; } private Map<String, Float> buildContext(List<Token> sentence, int pivot) throws Exception { int i = pivot - 1; int c = 0; StringBuilder sb = new StringBuilder(); while (i >= 0 && c < contextSize) { if (sentence.get(i).getPos() != POSenum.OTHER) { sb.append(sentence.get(i).getToken()); sb.append(" "); c++; } i--; } i = pivot + 1; c = 0; while (i < sentence.size() && c < contextSize) { if (sentence.get(i).getPos() != POSenum.OTHER) { sb.append(sentence.get(i).getToken()); sb.append(" "); c++; } i++; } return buildBag(sb.toString()); } private void getRelatedSynsets(Map<BabelSynset, RelatedSynset> map, int distance) throws IOException { List<BabelSynset> listKey = new ArrayList<>(map.keySet()); for (BabelSynset synset : listKey) { RelatedSynset get = map.get(synset); if (!get.isVisited()) { get.setVisited(true); Map<IPointer, List<BabelSynset>> relatedMap = synset.getRelatedMap(); Iterator<IPointer> itRel = relatedMap.keySet().iterator(); while (itRel.hasNext()) { IPointer pointer = itRel.next(); if (!pointer.getName().equalsIgnoreCase("antonym")) { List<BabelSynset> list = relatedMap.get(pointer); for (BabelSynset relSynset : list) { RelatedSynset rs = map.get(relSynset); if (rs == null) { map.put(relSynset, new RelatedSynset(relSynset, distance)); } } } } } } } private Map<String, Float> buildGlossBag(BabelSynset synset) throws IOException { Map<BabelSynset, RelatedSynset> relatedMap = new HashMap<>(); relatedMap.put(synset, new RelatedSynset(synset, 0)); for (int i = 0; i < maxDepth; i++) { getRelatedSynsets(relatedMap, i + 1); } Iterator<BabelSynset> itRel = relatedMap.keySet().iterator(); Map<String, Float> bag = new HashMap<>(); while (itRel.hasNext()) { BabelSynset relSynset = itRel.next(); RelatedSynset rs = relatedMap.get(relSynset); List<BabelGloss> glosses = relSynset.getGlosses(language); List<String> glossesToProcess = new ArrayList<>(); execStats.incrementTotalGloss(); if (glosses.isEmpty()) { logger.log(Level.FINEST, "No gloss for synset: {0}", relSynset); execStats.incrementNoGloss(); List<BabelSense> senses = relSynset.getSenses(this.language); StringBuilder sb = new StringBuilder(); for (BabelSense bs : senses) { sb.append(bs.getLemma().replace("_", " ")).append(" "); } glossesToProcess.add(sb.toString()); } else { for (BabelGloss gloss : glosses) { glossesToProcess.add(gloss.getGloss()); } } float df = maxDepth + 1 - rs.getDistance(); for (String gloss : glossesToProcess) { Map<String, Float> gbag = buildBag(gloss); Iterator<String> iterator = gbag.keySet().iterator(); while (iterator.hasNext()) { String term = iterator.next(); Float c = bag.get(term); if (c == null) { bag.put(term, df * gbag.get(term)); } else { bag.put(term, c + df * gbag.get(term)); } } } } return bag; } private float gf(List<Map<String, Float>> mapList, String key) { float gf = 0; for (Map<String, Float> map : mapList) { if (map.containsKey(key)) { gf++; } } return gf; } private List<Map<String, Float>> buildGlossBag(List<BabelSense> senses) throws IOException { List<Map<String, Float>> mapList = new ArrayList<>(); for (BabelSense sense : senses) { mapList.add(buildGlossBag(sense.getSynset())); } if (scoreGloss) { for (Map<String, Float> map : mapList) { Iterator<String> iterator = map.keySet().iterator(); while (iterator.hasNext()) { String key = iterator.next(); float igf = 1 + (float) (Math.log(senses.size() / gf(mapList, key)) / Math.log(2)); float gf = map.get(key); map.put(key, gf * igf); } } } return mapList; } private double simBag(Map<String, Float> bag1, Map<String, Float> bag2) { double n1 = 0; double n2 = 0; double ip = 0; Iterator<String> it1 = bag1.keySet().iterator(); while (it1.hasNext()) { String t1 = it1.next(); Float v1 = bag1.get(t1); if (bag2.containsKey(t1)) { ip += v1.doubleValue() * bag2.get(t1).doubleValue(); } n1 += Math.pow(v1.doubleValue(), 2); } Iterator<Float> it2 = bag2.values().iterator(); while (it2.hasNext()) { n2 += Math.pow(it2.next().doubleValue(), 2); } if (n1 == 0 || n2 == 0) { return 0; } else { return ip / (n1 * n2); } } private float[] buildVector(Map<String, Float> bag, boolean normalize) { Iterator<String> it = bag.keySet().iterator(); float[] bagv = new float[ObjectVector.vecLength]; while (it.hasNext()) { String t1 = it.next(); float w = bag.get(t1); float[] v = dsm.getVector(t1); execStats.incrementTotalSVhit(); if (v != null) { for (int k = 0; k < v.length; k++) { bagv[k] += w * v[k]; } } else { execStats.incrementMissSV(); } } if (normalize && !VectorUtils.isZeroVector(bagv)) { bagv = VectorUtils.getNormalizedVector(bagv); } return bagv; } private double simVector(Map<String, Float> bag1, Map<String, Float> bag2) { float[] gv1 = buildVector(bag1, true); float[] gv2 = buildVector(bag2, true); return VectorUtils.scalarProduct(gv1, gv2); } private double sim(Map<String, Float> bag1, Map<String, Float> bag2) { if (dsm != null) { return simVector(bag1, bag2); } else { return simBag(bag1, bag2); } } private List<BabelSense> lookupSense(Language language, String lemma, POS postag) throws IOException { List<BabelSense> senses = babelNet.getSenses(language, lemma, postag, BabelSenseSource.WN); if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma.replace(" ", "_"), postag, BabelSenseSource.WN); } if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma, postag, BabelSenseSource.WNTR); } if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma.replace(" ", "_"), postag, BabelSenseSource.WNTR); } if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma, postag); } if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma.replace(" ", "_"), postag); } /* if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma); } if (senses == null || senses.isEmpty()) { senses = babelNet.getSenses(language, lemma.replace(" ", "_")); } */ if (senses == null || senses.isEmpty()) { Logger.getLogger(RevisedLesk.class.getName()).log(Level.WARNING, "No senses for {0}, pos-tag {1}", new Object[] { lemma, postag }); } //remove duplicate senses if (senses != null && !senses.isEmpty()) { Set<String> ids = new HashSet<>(); for (int i = senses.size() - 1; i >= 0; i--) { if (!ids.add(senses.get(i).getSynset().getId())) { senses.remove(i); } } } return senses; } private String convertPosEnum(POSenum pos) { if (pos == POSenum.NOUN) { return "n"; } else if (pos == POSenum.VERB) { return "v"; } else if (pos == POSenum.ADJ) { return "a"; } else if (pos == POSenum.ADV) { return "r"; } else { return "o"; } } /** * * @param sentence * @throws Exception */ public void disambiguate(List<Token> sentence) throws Exception { execStats = new ExecuteStatistics(); System.out.println(); for (int i = 0; i < sentence.size(); i++) { System.out.print("."); Token token = sentence.get(i); if (token.isToDisambiguate() && token.getSynsetList().isEmpty()) { Map<String, Float> contextBag = buildContext(sentence, i); if (token.getPos() != POSenum.OTHER) { List<BabelSense> senses = null; if (token.getPos() == POSenum.NOUN) { senses = lookupSense(language, token.getLemma(), POS.NOUN); } else if (token.getPos() == POSenum.VERB) { senses = lookupSense(language, token.getLemma(), POS.VERB); } else if (token.getPos() == POSenum.ADJ) { senses = lookupSense(language, token.getLemma(), POS.ADJECTIVE); } else if (token.getPos() == POSenum.ADV) { senses = lookupSense(language, token.getLemma(), POS.ADVERB); } if (senses != null) { float[] as = null; if (senseFreq != null && sdType == SD_OCC) { as = senseFreq.getOccurrencesArray(senses); } List<Map<String, Float>> buildGlossBag = buildGlossBag(senses); for (int j = 0; j < senses.size(); j++) { double sim = 0; Map<String, Float> bag = buildGlossBag.get(j); //assign WSD algorithm score sim = sim(contextBag, bag); if (senseFreq != null) { if (sdType == SD_PROB) { //sense distribution based on conditional probability p(si|w) if (language.equals(Language.EN)) { //English is the WordNet native language if (senses.get(j).getWordNetOffset() != null && senses.get(j).getWordNetOffset().length() > 0) { String lemmakey = token.getLemma() + "#" + convertPosEnum(token.getPos()); float freq = senseFreq.getSynsetProbability(lemmakey, senses.get(j).getWordNetOffset(), senses.size()); sim = weightWsd * sim + weightSd * freq; } else { if (wikiType == WIKI_LEV) { sim = weightWsd * sim + weightSd * computeLDscore(token.getToken(), senses.get(j).getLemma()); } else { sim = weightWsd * sim + weightSd * (1 / (double) senses.size()); } } } else { //Translate WordNet synset from other language String mainSense = senses.get(j).getSynset().getMainSense(); if (mainSense != null && !mainSense.startsWith("WIKI:") && mainSense.length() > 0) { String lemmakey = token.getLemma() + "#" + convertPosEnum(token.getPos()); float maxFreq = senseFreq.getMaxSenseProbability(lemmakey, senses.get(j), senses.size()); sim = weightWsd * sim + weightSd * maxFreq; } else { if (wikiType == WIKI_LEV) { sim = weightWsd * sim + weightSd * computeLDscore(token.getToken(), senses.get(j).getLemma()); } else { sim = weightWsd * sim + weightSd * (1 / (double) senses.size()); } } } } else if (sdType == SD_PROB_CROSS) { //Use english sense distribution for other langauge String mainSense = senses.get(j).getSynset().getMainSense(); if (mainSense != null && !mainSense.startsWith("WIKI:") && mainSense.length() > 0) { int si = mainSense.lastIndexOf("#"); String lemmakey = mainSense.substring(0, si); float maxFreq = senseFreq.getMaxSenseProbability(lemmakey, senses.get(j), senses.size()); sim = weightWsd * sim + weightSd * maxFreq; } else { if (wikiType == WIKI_LEV) { sim = weightWsd * sim + weightSd * computeLDscore(token.getToken(), senses.get(j).getLemma()); } else { sim = weightWsd * sim + weightSd * (1 / (double) senses.size()); } } } else if (sdType == SD_OCC) { //language independent based on synset offset sim = weightWsd * sim + weightSd * as[j]; } } if (outType == OUT_BABELNET) { token.getSynsetList().add(new SynsetOut(senses.get(j).getSynset().getId(), sim)); } else if (outType == OUT_WORDNET) { token.getSynsetList().add(new SynsetOut(senses.get(j).getSensekey(), sim)); } else { throw new Exception("Output type not valid: " + outType); } } Collections.sort(token.getSynsetList()); //logger.log(Level.FINEST, "{0}\tmean: {1}\tvariance: {2}", new Object[]{token.toString(), getMean(token.getSynsetList()), getVariance(token.getSynsetList())}); } } } } } private float computeLDscore(String s1, String s2) { float maxLength = (float) Math.max(s1.length(), s2.length()); float ld = (float) StringUtils.getLevenshteinDistance(s1, s2); return 1 - ld / maxLength; } /** * * @param list * @return */ public double getMean(List<SynsetOut> list) { double[] scores = new double[list.size()]; int l = 0; for (SynsetOut out : list) { scores[l] = out.getScore(); l++; } return StatUtils.mean(scores); } /** * * @param list * @return */ public double getVariance(List<SynsetOut> list) { double[] scores = new double[list.size()]; int l = 0; for (SynsetOut out : list) { scores[l] = out.getScore(); l++; } return StatUtils.variance(scores); } /** * * @return */ public BabelNet getBabelNet() { return babelNet; } /** * @return the stemming */ public boolean isStemming() { return stemming; } /** * @param stemming the stemming to set */ public void setStemming(boolean stemming) { this.stemming = stemming; } /** * @return the outType */ public int getOutType() { return outType; } /** * @param outType the outType to set */ public void setOutType(int outType) { this.outType = outType; } /** * * @return */ public int getMaxDepth() { return maxDepth; } /** * * @param maxDepth */ public void setMaxDepth(int maxDepth) { this.maxDepth = maxDepth; } /** * * @return */ public boolean isScoreGloss() { return scoreGloss; } /** * * @param scoreGloss */ public void setScoreGloss(boolean scoreGloss) { this.scoreGloss = scoreGloss; } /** * * @return */ public ExecuteStatistics getExecStats() { return execStats; } }