eu.project.ttc.models.ContextVector.java Source code

Java tutorial

Introduction

Here is the source code for eu.project.ttc.models.ContextVector.java

Source

/*******************************************************************************
 * Copyright 2015-2016 - CNRS (Centre National de Recherche Scientifique)
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *******************************************************************************/
package eu.project.ttc.models;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.mutable.MutableInt;

import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import eu.project.ttc.metrics.AssociationRate;

/**
 * A term/co-occurrence map used as a context vector during alignment operations.
 *
 * <p>Each co-term is mapped to an {@link Entry} holding its number of
 * co-occurrences and its association rate. {@link #getEntries()} exposes the
 * entries sorted by decreasing association rate, then by decreasing number of
 * co-occurrences. The sorted list is cached and rebuilt after any mutation
 * performed through the methods of this class.
 *
 * @author Damien Cram
 *
 */
public class ContextVector {
    private static final String MSG_TERM_CLASS_NULL = "Flag useTermClasses is set to true but term classes are not built.";
    private static final String STRING_FORMAT = "<%s> {%s}";

    /** The entries of this vector, keyed by co-term (or by term class head when term classes are used). */
    private final Map<Term, Entry> entries = Maps.newHashMap();

    /** Lazily built sorted view of {@link #entries}; <code>null</code> means "must be rebuilt". */
    private List<Entry> _sortedEntries = null;

    /** Sum of the co-occurrence counts of all entries. */
    private int totalCooccurrences = 0;

    private boolean useTermClasses = false;

    /** The owner term of this vector, <code>null</code> when built with the default constructor. */
    private Term term;

    /**
     * Default constructor for {@link ContextVector}. Must be used if no normalization is required.
     * The owner term is left <code>null</code>.
     */
    public ContextVector() {
    }

    /**
     * 
     * Construct a context vector with a back reference to its owner term.
     * The parent term is useful for normalization.
     * 
     * @see ContextVector#toAssocRateVector(CrossTable, AssociationRate, boolean)
     * @param term
     *          the owner term
     * @param useTermClasses
     *          if <code>true</code>, builds this context vector with term 
     *          classes instead of terms 
     */
    public ContextVector(Term term, boolean useTermClasses) {
        this(term);
        this.useTermClasses = useTermClasses;
    }

    /**
     * 
     * Construct a context vector with a back reference to its owner term.
     * The parent term is useful for normalization.
     * 
     * @see ContextVector#toAssocRateVector(CrossTable, AssociationRate, boolean)
     * @param term
     *          the owner term
     */
    public ContextVector(Term term) {
        this.term = term;
    }

    /**
     * Adds every occurrence produced by the iterator as a co-occurrence.
     * 
     * @param it
     *          the occurrences to add, consumed until exhausted
     * @see #addCooccurrence(TermOccurrence)
     */
    public void addAllCooccurrences(Iterator<TermOccurrence> it) {
        while (it.hasNext()) {
            addCooccurrence(it.next());
        }
    }

    /**
     * Increments by one the co-occurrence count of the term of the given
     * occurrence, creating its entry (with an association rate of 0) if needed.
     * 
     * @param occ
     *          the occurrence whose term co-occurs with the owner term
     * @throws IllegalStateException
     *          if term classes are used and the term has no term class
     */
    public void addCooccurrence(TermOccurrence occ) {
        Term coTerm = getTermToAdd(occ.getTerm());
        Entry entry = entries.get(coTerm);
        if (entry == null) {
            entry = new Entry(coTerm);
            entries.put(coTerm, entry);
        }
        entry.increment();
        this.totalCooccurrences++;
        setDirty();
    }

    /**
     * Resolves the term actually used as a key in this vector: the term
     * itself, or the head of its term class when term classes are used.
     * 
     * @throws IllegalStateException
     *          if term classes are used and the term has no term class
     */
    private Term getTermToAdd(Term t) {
        if (!this.useTermClasses) {
            return t;
        }
        Preconditions.checkState(t.getTermClass() != null, MSG_TERM_CLASS_NULL);
        return t.getTermClass().getHead();
    }

    /**
     * Sets the entry of a co-term, replacing any existing entry for it.
     * The total number of co-occurrences is adjusted accordingly.
     * 
     * @param coTerm
     *          the co-term
     * @param nbCooccs
     *          the number of co-occurrences of the co-term
     * @param assocRate
     *          the association rate of the co-term
     * @throws IllegalStateException
     *          if term classes are used and the term has no term class
     */
    public void addEntry(Term coTerm, int nbCooccs, double assocRate) {
        Term termToAdd = getTermToAdd(coTerm);
        Entry previous = entries.put(termToAdd, new Entry(termToAdd, nbCooccs, assocRate));
        if (previous != null) {
            // the replaced entry no longer contributes to the total
            this.totalCooccurrences -= previous.getNbCooccs();
        }
        this.totalCooccurrences += nbCooccs;
        setDirty();
    }

    /**
     * Removes the entry of a co-term, if any, and subtracts its
     * co-occurrences from the total.
     * 
     * @param term
     *          the co-term to remove
     * @throws IllegalStateException
     *          if term classes are used and the term has no term class
     */
    public void removeCoTerm(Term term) {
        Entry removed = this.entries.remove(getTermToAdd(term));
        if (removed != null) {
            this.totalCooccurrences -= removed.getNbCooccs();
        }
        setDirty();
    }

    /** Invalidates the cached sorted list so that {@link #getEntries()} rebuilds it. */
    private void setDirty() {
        this._sortedEntries = null;
    }

    /**
     * Gives the list of entries of this context vector sorted by decreasing
     * association rate, then by decreasing number of co-occurrences
     * (see {@link Entry#compareTo(Entry)}).
     * 
     * The list is built lazily and cached until the next modification made
     * through this vector. The cached instance itself is returned, so callers
     * should not modify it.
     * 
     * @return
     *          the sorted list of entries
     */
    public List<Entry> getEntries() {
        if (_sortedEntries == null) {
            List<Entry> sorted = Lists.newArrayListWithCapacity(entries.size());
            sorted.addAll(entries.values());
            Collections.sort(sorted);
            this._sortedEntries = sorted;
        }
        return _sortedEntries;
    }

    /**
     * Normalizes all <code>assocRates</code> so that their sum is 1.
     * Does nothing to the rates when their sum is 0.
     */
    public void normalize() {
        double sum = 0;
        for (Entry e : entries.values()) {
            sum += e.getAssocRate();
        }
        if (sum != 0) {
            for (Entry e : entries.values()) {
                e.setAssocRate(e.getAssocRate() / sum);
            }
        }
        // Rates have changed (a negative sum even reverses their order):
        // the cached sorted list must not be reused.
        setDirty();
    }

    /**
     * An element of a {@link ContextVector}: a co-term together with its
     * number of co-occurrences and its association rate.
     * 
     * The natural ordering puts the highest association rate first, ties
     * being broken by the highest number of co-occurrences. It does not take
     * the co-term into account and is therefore not consistent with
     * {@link #equals(Object)}.
     */
    public static class Entry implements Comparable<Entry> {
        private static final String STRING_FORMAT = "%s: %.3f (%d)";
        private Term coTerm;
        private MutableInt nbCooccs;
        private double assocRate;

        private Entry(Term term) {
            this(term, 0, 0d);
        }

        private Entry(Term coTerm, int nbCoccurrences, double assocRate) {
            this.coTerm = coTerm;
            this.nbCooccs = new MutableInt(nbCoccurrences);
            this.assocRate = assocRate;
        }

        private void increment() {
            this.nbCooccs.increment();
        }

        /**
         * Sets the association rate of this entry.
         * 
         * Note that this does not invalidate the sorted list cached by the
         * context vector holding this entry.
         * 
         * @param assocRate
         *          the new association rate
         */
        public void setAssocRate(double assocRate) {
            this.assocRate = assocRate;
        }

        public double getAssocRate() {
            return assocRate;
        }

        public Term getCoTerm() {
            return coTerm;
        }

        public int getNbCooccs() {
            return nbCooccs.intValue();
        }

        /**
         * Two entries are equal when they have equal co-terms, equal numbers
         * of co-occurrences and equal association rates.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Entry)) {
                return false;
            }
            Entry other = (Entry) obj;
            return Objects.equal(other.coTerm, this.coTerm)
                    && Objects.equal(other.nbCooccs, this.nbCooccs)
                    && Objects.equal(other.assocRate, this.assocRate);
        }

        /**
         * The hash code is the id of the co-term (0 when there is no co-term,
         * which keeps this method as null-tolerant as {@link #equals(Object)}).
         */
        @Override
        public int hashCode() {
            return this.coTerm == null ? 0 : this.coTerm.getId();
        }

        /**
         * Orders entries by decreasing association rate, then by decreasing
         * number of co-occurrences.
         */
        @Override
        public int compareTo(Entry o) {
            // arguments are swapped (o first) to obtain a descending order
            return ComparisonChain.start()
                    .compare(o.assocRate, this.assocRate)
                    .compare(o.nbCooccs, this.nbCooccs)
                    .result();
        }

        @Override
        public String toString() {
            return String.format(STRING_FORMAT, coTerm.getLemma(), assocRate, nbCooccs.intValue());
        }

    }

    /**
     * Gives the number of co-occurrences stored for a term.
     * 
     * The term is looked up as given: it is not replaced by the head of its
     * term class, even when this vector uses term classes.
     * 
     * @param term
     *          the co-term to look up
     * @return
     *          the number of co-occurrences, or 0 if the term has no entry
     */
    public int getNbCooccs(Term term) {
        Entry entry = entries.get(term);
        return entry == null ? 0 : entry.getNbCooccs();
    }

    /**
     * Gives the association rate stored for a term.
     * 
     * The term is looked up as given: it is not replaced by the head of its
     * term class, even when this vector uses term classes.
     * 
     * @param term
     *          the co-term to look up
     * @return
     *          the association rate, or 0 if the term has no entry
     */
    public double getAssocRate(Term term) {
        Entry entry = entries.get(term);
        return entry == null ? 0d : entry.getAssocRate();
    }

    /**
     * @return
     *          the key set of the underlying entry map, i.e. a live view of
     *          the co-terms of this vector
     */
    public Set<Term> terms() {
        return entries.keySet();
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(this.term);
    }

    /**
     * Two context vectors are equal when they have equal owner terms and
     * equal entries for the same co-terms.
     * 
     * The entry maps are compared directly rather than through the sorted
     * lists: entries having the same association rate and number of
     * co-occurrences have no defined relative order in {@link #getEntries()},
     * so a positional comparison could wrongly report two identical vectors
     * as different.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ContextVector)) {
            return false;
        }
        ContextVector other = (ContextVector) obj;
        return Objects.equal(this.term, other.term) && this.entries.equals(other.entries);
    }

    public Term getTerm() {
        return term;
    }

    /**
     * Normalize this vector according to a cross table
     * and an association rate measure.
     * 
     * This method recomputes the association rate of every {@link Entry}
     * with the value given by the cross table for the owner term and
     * the co-term of the entry.
     * 
     * @param table
     *          the pre-computed co-occurrences {@link CrossTable}
     * @param assocRateFunction
     *          the {@link AssociationRate} measure implementation
     * @param normalize 
     *          if <code>true</code>, {@link #normalize()} is invoked once
     *          all association rates have been recomputed
     */
    public void toAssocRateVector(CrossTable table, AssociationRate assocRateFunction, boolean normalize) {
        for (Map.Entry<Term, Entry> e : entries.entrySet()) {
            double assocRate = table.computeRate(assocRateFunction, this.term, e.getKey());
            e.getValue().setAssocRate(assocRate);
        }
        if (normalize) {
            normalize();
        }

        setDirty();
    }

    @Override
    public String toString() {
        return String.format(STRING_FORMAT, this.term == null ? "no parent term" : this.term.getLemma(),
                Joiner.on(", ").join(getEntries()));
    }

    /**
     * @return
     *          the sum of the co-occurrence counts of all entries
     */
    public int getTotalCoccurrences() {
        return this.totalCooccurrences;
    }
}