de.tudarmstadt.ukp.dkpro.bigdata.collocations.AssocReducer.java Source code


Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.bigdata.collocations.AssocReducer.java, the pass-2 reducer of the DKPro BigData collocation discovery job. For each candidate ngram it collects the head and tail sub-ngram frequencies, builds a contingency table, and emits log-likelihood ratio, PMI, chi-square, and Dice association scores through Hadoop's MultipleOutputs.

Source

package de.tudarmstadt.ukp.dkpro.bigdata.collocations;

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.apache.mahout.math.stats.LogLikelihood;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import de.tudarmstadt.ukp.dkpro.bigdata.collocations.CollocMapper.Count;

/**
 * Reducer for pass 2 of the collocation discovery job. Collects ngram and
 * sub-ngram frequencies and performs the Log-likelihood ratio calculation.
 */
public class AssocReducer extends Reducer<Gram, Gram, Text, DoubleWritable> {

    /** Counter to track why a particular entry was skipped */
    public enum Skipped {
        EXTRA_HEAD, EXTRA_TAIL, MISSING_HEAD, MISSING_TAIL, LESS_THAN_MIN_VALUE,
        LLR_CALCULATION_ERROR, CHI_CALCULATION_ERROR, PMI_CALCULATION_ERROR, DICE_CALCULATION_ERROR
    }

    private static final Logger log = LoggerFactory.getLogger(AssocReducer.class);

    public static final String NGRAM_TOTAL = "ngramTotal";
    public static final String MIN_VALUE = "minLLR";
    public static final float DEFAULT_MIN_VALUE = 0.1f;
    public static final String ASSOC_METRIC = "metric";

    public static final String DEFAULT_ASSOC = "llr";

    private long ngramTotal;
    private float minValue;
    private boolean emitUnigrams;
    private AssocCallback assocCalculator;
    private final AssocCallback llrCalculator = new ConcreteLLCallback();
    private final AssocCallback pmiCalculator = new PMICallback();
    private final AssocCallback chiCalculator = new ChiSquareCallback();
    private final AssocCallback diceCalculator = new DiceCallback();
    private MultipleOutputs<?, ?> mos;
    private final AssociationMetrics assocMetrics = new AssociationMetrics();

    /**
     * Perform the association calculation. Input is k:ngram:ngramFreq,
     * v:(h_|t_)subgram:subgramFreq, with N = ngram total.
     *
     * Each ngram has two subgrams, a head and a tail, referred to as A and B
     * respectively below:
     *
     *  A+ B: number of times A and B appear together: ngramFreq
     *  A+!B: number of times A appears without B: hSubgramFreq - ngramFreq
     * !A+ B: number of times B appears without A: tSubgramFreq - ngramFreq
     * !A+!B: number of times neither A nor B appears (in that order):
     *        N - (subgramFreqA + subgramFreqB - ngramFreq)
     */
    @Override
    protected void reduce(Gram ngram, Iterable<Gram> values, Context context)
            throws IOException, InterruptedException {

        int[] gramFreq = { -1, -1 };

        int frequency = ngram.getFrequency();
        if (ngram.getType() == Gram.Type.UNIGRAM && emitUnigrams) {
            DoubleWritable dd = new DoubleWritable(frequency);
            Text t = new Text(ngram.getString());
            context.getCounter(Count.EMITTED_UNIGRAM).increment(1);
            context.write(t, dd);
            return;
        }
        // TODO better way to handle errors? Wouldn't an exception thrown here
        // cause hadoop to re-try the job?
        String[] gram = new String[2];
        for (Gram value : values) {

            int pos = value.getType() == Gram.Type.HEAD ? 0 : 1;

            if (gramFreq[pos] != -1) {
                log.warn("Extra {} for {}, skipping", value.getType(), ngram);
                if (value.getType() == Gram.Type.HEAD) {
                    context.getCounter(Skipped.EXTRA_HEAD).increment(1);
                } else {
                    context.getCounter(Skipped.EXTRA_TAIL).increment(1);
                }
                return;
            }

            gram[pos] = value.getString();
            gramFreq[pos] = value.getFrequency();
        }

        if (gramFreq[0] == -1) {
            log.warn("Missing head for {}, skipping.", ngram);
            context.getCounter(Skipped.MISSING_HEAD).increment(1);
            return;
        }
        if (gramFreq[1] == -1) {
            log.warn("Missing tail for {}, skipping", ngram);
            context.getCounter(Skipped.MISSING_TAIL).increment(1);
            return;
        }

        double value;
        // build contingency table
        long k11 = frequency; /* a&b */
        long k12 = gramFreq[0] - frequency; /* a&!b */
        long k21 = gramFreq[1] - frequency; /* !a&b */
        long k22 = ngramTotal - (gramFreq[0] + gramFreq[1] - frequency); /* !a&!b */
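
        // Worked example with hypothetical counts: if the bigram A B occurs 80
        // times, its head A 100 times and its tail B 90 times in a corpus of
        // ngramTotal = 10000 ngrams, then
        //   k11 = 80, k12 = 100 - 80 = 20, k21 = 90 - 80 = 10,
        //   k22 = 10000 - (100 + 90 - 80) = 9890.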

        try {

            value = assocCalculator.assoc(k11, k12, k21, k22);

        } catch (IllegalArgumentException ex) {
            context.getCounter(Skipped.LLR_CALCULATION_ERROR).increment(1);
            log.warn(
                    "Problem calculating assoc metric for ngram {}, HEAD {}:{}, TAIL {}:{}, k11/k12/k21/k22: {}/{}/{}/{}",
                    ngram, gram[0], gramFreq[0], gram[1], gramFreq[1], k11, k12, k21, k22, ex);
            return;
        }
        if (value < minValue) {
            context.getCounter(Skipped.LESS_THAN_MIN_VALUE).increment(1);
        } else {

            assocMetrics.init(k11, k12, k21, k22);

            // write the score of the configured metric to the "llr" named output
            mos.write("llr", new Text(ngram.getString()), new DoubleWritable(value));
            try {
                double pmi = assocMetrics.mutual_information();
                mos.write("pmi", new Text(ngram.getString()), new DoubleWritable(pmi));
            } catch (Exception e) {
                context.getCounter(Skipped.PMI_CALCULATION_ERROR).increment(1);
            }
            try {
                double chi = assocMetrics.chisquared();
                mos.write("chi", new Text(ngram.getString()), new DoubleWritable(chi));
            } catch (Exception e) {
                context.getCounter(Skipped.CHI_CALCULATION_ERROR).increment(1);
            }
            try {
                double dice = assocMetrics.dice();
                mos.write("dice", new Text(ngram.getString()), new DoubleWritable(dice));
            } catch (Exception e) {
                context.getCounter(Skipped.DICE_CALCULATION_ERROR).increment(1);
            }

            context.getCounter("assoctest", "EMITTED NGRAM").increment(1);

            mos.write("contingency", new Text(ngram.getString()),
                    new Text("" + k11 + "\t" + k12 + "\t" + k21 + "\t" + k22));

        }
    }

    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
        mos.close();
    }

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        super.setup(context);
        Configuration conf = context.getConfiguration();
        this.ngramTotal = conf.getLong(NGRAM_TOTAL, -1);
        this.minValue = conf.getFloat(MIN_VALUE, DEFAULT_MIN_VALUE);
        String assocType = conf.get(ASSOC_METRIC, DEFAULT_ASSOC);
        if (assocType.equalsIgnoreCase("llr"))
            assocCalculator = new ConcreteLLCallback();
        else if (assocType.equalsIgnoreCase("dice"))
            assocCalculator = new DiceCallback();
        else if (assocType.equalsIgnoreCase("pmi"))
            assocCalculator = new PMICallback();
        else if (assocType.equalsIgnoreCase("chi"))
            assocCalculator = new ChiSquareCallback();

        this.emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS);
        log.info("NGram Total: {}, Min DICE value: {}, Emit Unigrams: {}",
                new Object[] { ngramTotal, minValue, emitUnigrams });

        if (ngramTotal == -1) {
            throw new IllegalStateException("No NGRAM_TOTAL available in job config");
        }
        mos = new MultipleOutputs<Text, DoubleWritable>(context);
    }
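
    // Note for driver authors: setup() expects NGRAM_TOTAL in the job
    // configuration, and each named output written above ("llr", "pmi", "chi",
    // "dice", "contingency") must be registered on the job. A hypothetical
    // driver-side sketch:
    //   conf.setLong(AssocReducer.NGRAM_TOTAL, totalNGrams);
    //   conf.set(AssocReducer.ASSOC_METRIC, "llr");
    //   MultipleOutputs.addNamedOutput(job, "llr", TextOutputFormat.class,
    //           Text.class, DoubleWritable.class);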

    public AssocReducer() {
        this.assocCalculator = new DiceCallback();
    }

    /**
     * plug in an alternate LL implementation, used for testing
     * 
     * @param ll
     *            the LL to use.
     */
    AssocReducer(AssocCallback ll) {
        this.assocCalculator = ll;
    }

    /**
     * provide an interface so the input to the association calculation can be
     * captured for validation in unit testing
     */
    public interface AssocCallback {
        double assoc(long k11, long k12, long k21, long k22);
    }

    /** concrete implementation delegates to LogLikelihood class */
    public static final class ConcreteLLCallback implements AssocCallback {
        @Override
        public double assoc(long k11, long k12, long k21, long k22) {

            return LogLikelihood.logLikelihoodRatio(k11, k12, k21, k22);
        }
    }

    public static final class DiceCallback implements AssocCallback {

        @Override
        public double assoc(long k11, long k12, long k21, long k22) {
            return (2.0 * k11) / (2 * k11 + k12 + k21);
        }

    }

    public static final class PMICallback implements AssocCallback {

        @Override
        public double assoc(long k11, long k12, long k21, long k22) {
            double total = k11 + k12 + k21 + k22;
            // expected count of A and B occurring together under independence
            double e11 = (double) (k11 + k12) * (k11 + k21) / total;

            // PMI is log_2(P(a,b) / (P(a) * P(b))), i.e. the log_2 ratio of
            // the observed co-occurrence count to the expected one
            return Math.log(k11 / e11) / Math.log(2);
        }
    }

    public static final class ChiSquareCallback implements AssocCallback {

        @Override
        public double assoc(long k11, long k12, long k21, long k22) {
            double n = k11 + k12 + k21 + k22;
            // Pearson's chi-squared statistic for a 2x2 contingency table:
            // n * (k11*k22 - k12*k21)^2 / ((k11+k21)(k12+k22)(k11+k12)(k21+k22)),
            // computed in double to avoid long overflow on large corpora
            double diff = (double) k11 * k22 - (double) k12 * k21;
            return n * diff * diff
                    / ((double) (k11 + k21) * (k12 + k22) * (k11 + k12) * (k21 + k22));
        }
    }

}
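
Example

The four AssocCallback implementations can also be exercised outside of a Hadoop job, which is handy for sanity-checking the scores. The following is a minimal, hypothetical sketch (the demo class and the toy counts are invented for illustration; it assumes the class above and the Mahout math library are on the classpath):

package de.tudarmstadt.ukp.dkpro.bigdata.collocations;

/** Hypothetical demo: evaluate each association metric on one toy contingency table. */
public class AssocCallbackDemo {
    public static void main(String[] args) {
        // Same toy counts as in the worked example inside reduce():
        long k11 = 80;   // A and B together
        long k12 = 20;   // A without B
        long k21 = 10;   // B without A
        long k22 = 9890; // neither A nor B

        System.out.println("llr  = " + new AssocReducer.ConcreteLLCallback().assoc(k11, k12, k21, k22));
        System.out.println("pmi  = " + new AssocReducer.PMICallback().assoc(k11, k12, k21, k22));
        System.out.println("chi  = " + new AssocReducer.ChiSquareCallback().assoc(k11, k12, k21, k22));
        System.out.println("dice = " + new AssocReducer.DiceCallback().assoc(k11, k12, k21, k22));
    }
}

For these counts the Dice score, for example, is 2*80 / (2*80 + 20 + 10) ≈ 0.84.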