com.googlecode.clearnlp.component.AbstractStatisticalComponent.java Source code

Java tutorial

Introduction

Here is the source code for com.googlecode.clearnlp.component.AbstractStatisticalComponent.java

Source

/**
 * Copyright (c) 2009/09-2012/08, Regents of the University of Colorado
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
/**
 * Copyright 2012/09-2013/04, University of Massachusetts Amherst
 * Copyright 2013/05-Present, IPSoft Inc.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. 
 */
package com.googlecode.clearnlp.component;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

import org.apache.commons.compress.utils.IOUtils;

import com.googlecode.clearnlp.classification.feature.FtrTemplate;
import com.googlecode.clearnlp.classification.feature.FtrToken;
import com.googlecode.clearnlp.classification.feature.JointFtrXml;
import com.googlecode.clearnlp.classification.model.ONStringModel;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.dependency.DEPArc;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.reader.AbstractColumnReader;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.pair.Pair;

/**
 * @since 1.3.0
 * @author Jinho D. Choi ({@code jdchoi77@gmail.com})
 */
abstract public class AbstractStatisticalComponent extends AbstractComponent {
    protected StringTrainSpace[] s_spaces;
    protected StringModel[] s_models;
    protected JointFtrXml[] f_xmls;

    protected DEPTree d_tree;
    protected int t_size; // size of d_tree

    protected DEPNode[] lm_deps, rm_deps;
    protected DEPNode[] ln_sibs, rn_sibs;

    //   ====================================== CONSTRUCTORS ======================================

    public AbstractStatisticalComponent() {
    }

    /** Constructs a component for collecting lexica. */
    public AbstractStatisticalComponent(JointFtrXml[] xmls) {
        i_flag = FLAG_LEXICA;
        f_xmls = xmls;
    }

    /** Constructs a component for training. */
    public AbstractStatisticalComponent(JointFtrXml[] xmls, StringTrainSpace[] spaces, Object[] lexica) {
        i_flag = FLAG_TRAIN;
        f_xmls = xmls;
        s_spaces = spaces;

        initLexia(lexica);
    }

    /** Constructs a component for developing. */
    public AbstractStatisticalComponent(JointFtrXml[] xmls, StringModel[] models, Object[] lexica) {
        i_flag = FLAG_DEVELOP;
        f_xmls = xmls;
        s_models = models;

        initLexia(lexica);
    }

    /** Constructs a component for decoding. */
    public AbstractStatisticalComponent(ZipInputStream zin) {
        i_flag = FLAG_DECODE;

        loadModels(zin);
    }

    /** Constructs a component for bootstrapping. */
    public AbstractStatisticalComponent(JointFtrXml[] xmls, StringTrainSpace[] spaces, StringModel[] models,
            Object[] lexica) {
        i_flag = FLAG_BOOTSTRAP;
        f_xmls = xmls;
        s_spaces = spaces;
        s_models = models;

        initLexia(lexica);
    }

    /** Initializes lexica used for this component. */
    abstract protected void initLexia(Object[] lexica);

    //   ====================================== LOAD/SAVE MODELS ======================================

    /** Loads all models of this joint-component. */
    abstract public void loadModels(ZipInputStream zin);

    protected void loadDefaultConfiguration(ZipInputStream zin) throws Exception {
        BufferedReader fin = UTInput.createBufferedReader(zin);
        int i, mSize = Integer.parseInt(fin.readLine());

        LOG.info("Loading configuration.\n");
        s_models = new StringModel[mSize];

        for (i = 0; i < mSize; i++)
            s_models[i] = new StringModel();
    }

    /** Called by {@link AbstractStatisticalComponent#loadModels(ZipInputStream)}}. */
    protected ByteArrayInputStream loadFeatureTemplates(ZipInputStream zin, int index) throws Exception {
        LOG.info("Loading feature templates.\n");

        BufferedReader fin = UTInput.createBufferedReader(zin);
        ByteArrayInputStream template = getFeatureTemplates(fin);

        f_xmls[index] = new JointFtrXml(template);
        return template;
    }

    protected ByteArrayInputStream getFeatureTemplates(BufferedReader fin) throws IOException {
        StringBuilder build = new StringBuilder();
        String line;

        while ((line = fin.readLine()) != null) {
            build.append(line);
            build.append("\n");
        }

        return new ByteArrayInputStream(build.toString().getBytes());
    }

    /** Called by {@link AbstractStatisticalComponent#loadModels(ZipInputStream)}}. */
    protected void loadStatisticalModels(ZipInputStream zin, int index) throws Exception {
        BufferedReader fin = UTInput.createBufferedReader(zin);
        s_models[index].load(fin);
        //   s_models[index] = new StringModel(fin);
    }

    protected void loadWeightVector(ZipInputStream zin, int index) throws Exception {
        ObjectInputStream oin = new ObjectInputStream(new BufferedInputStream(zin));
        s_models[index].loadWeightVector(oin);
    }

    /** For online decoders. */
    protected void loadOnlineModels(ZipInputStream zin, int index, double alpha, double rho) throws Exception {
        BufferedReader fin = UTInput.createBufferedReader(zin);
        s_models[index] = new ONStringModel(fin, alpha, rho);
    }

    /** Saves all models of this joint-component. */
    abstract public void saveModels(ZipOutputStream zout);

    protected void saveDefaultConfiguration(ZipOutputStream zout, String entryName) throws Exception {
        zout.putNextEntry(new ZipEntry(entryName));
        PrintStream fout = UTOutput.createPrintBufferedStream(zout);
        LOG.info("Saving configuration.\n");

        fout.println(s_models.length);

        fout.flush();
        zout.closeEntry();
    }

    /** Called by {@link AbstractStatisticalComponent#saveModels(ZipOutputStream)}}. */
    protected void saveFeatureTemplates(ZipOutputStream zout, String entryName) throws Exception {
        int i, size = f_xmls.length;
        PrintStream fout;
        LOG.info("Saving feature templates.\n");

        for (i = 0; i < size; i++) {
            zout.putNextEntry(new ZipEntry(entryName + i));
            fout = UTOutput.createPrintBufferedStream(zout);
            IOUtils.copy(UTInput.toInputStream(f_xmls[i].toString()), fout);
            fout.flush();
            zout.closeEntry();
        }
    }

    /** Called by {@link AbstractStatisticalComponent#saveModels(ZipOutputStream)}}. */
    protected void saveStatisticalModels(ZipOutputStream zout, String entryName) throws Exception {
        int i, size = s_models.length;
        PrintStream fout;

        for (i = 0; i < size; i++) {
            zout.putNextEntry(new ZipEntry(entryName + i));
            fout = UTOutput.createPrintBufferedStream(zout);
            s_models[i].save(fout);
            fout.flush();
            zout.closeEntry();
        }
    }

    protected void saveWeightVector(ZipOutputStream zout, String entryName) throws Exception {
        int i, size = s_models.length;
        ObjectOutputStream oout;

        for (i = 0; i < size; i++) {
            zout.putNextEntry(new ZipEntry(entryName + i));
            oout = new ObjectOutputStream(new BufferedOutputStream(zout));
            s_models[i].saveWeightVector(oout);
            oout.flush();
            zout.closeEntry();
        }
    }

    //   ====================================== INITIALIZATION ======================================

    /** Initializes dependency arcs of all nodes. */
    protected void initArcs() {
        DEPNode curr, prev, next;
        List<DEPArc> deps;
        DEPArc lmd, rmd;
        int i, j, len;

        lm_deps = new DEPNode[t_size];
        rm_deps = new DEPNode[t_size];
        ln_sibs = new DEPNode[t_size];
        rn_sibs = new DEPNode[t_size];

        d_tree.setDependents();

        for (i = 1; i < t_size; i++) {
            deps = d_tree.get(i).getDependents();
            if (deps.isEmpty())
                continue;

            len = deps.size();
            lmd = deps.get(0);
            rmd = deps.get(len - 1);

            if (lmd.getNode().id < i)
                lm_deps[i] = lmd.getNode();
            if (rmd.getNode().id > i)
                rm_deps[i] = rmd.getNode();

            for (j = 1; j < len; j++) {
                curr = deps.get(j).getNode();
                prev = deps.get(j - 1).getNode();

                if (ln_sibs[curr.id] == null || ln_sibs[curr.id].id < prev.id)
                    ln_sibs[curr.id] = prev;
            }

            for (j = 0; j < len - 1; j++) {
                curr = deps.get(j).getNode();
                next = deps.get(j + 1).getNode();

                if (rn_sibs[curr.id] == null || rn_sibs[curr.id].id > next.id)
                    rn_sibs[curr.id] = next;
            }
        }
    }

    //   ====================================== GETTERS/SETTERS ======================================

    /** @return all training spaces of this joint-components. */
    public StringTrainSpace[] getTrainSpaces() {
        return s_spaces;
    }

    /** @return all models of this joint-components. */
    public StringModel[] getModels() {
        return s_models;
    }

    /** @return all objects containing lexica. */
    abstract public Object[] getLexica();

    //   ====================================== PROCESS ======================================

    /** Counts the number of correctly classified labels. */
    abstract public void countAccuracy(int[] counts);

    //   ====================================== FEATURE EXTRACTION ======================================

    /** @return a field of the specific feature token (e.g., lemma, pos-tag). */
    abstract protected String getField(FtrToken token);

    /** @return multiple fields of the specific feature token (e.g., lemma, pos-tag). */
    abstract protected String[] getFields(FtrToken token);

    /** @param the dependency node that is not {@code null}. */
    protected String getDefaultField(FtrToken token, DEPNode node) {
        Matcher m;

        if (token.isField(JointFtrXml.F_FORM)) {
            return node.form;
        } else if (token.isField(JointFtrXml.F_SIMPLIFIED_FORM)) {
            return node.simplifiedForm;
        } else if (token.isField(JointFtrXml.F_LEMMA)) {
            return node.lemma;
        } else if (token.isField(JointFtrXml.F_POS)) {
            return node.pos;
        } else if (token.isField(JointFtrXml.F_DEPREL)) {
            return node.getLabel();
        } else if ((m = JointFtrXml.P_FEAT.matcher(token.field)).find()) {
            return node.getFeat(m.group(1));
        }

        return null;
    }

    protected String[] getDefaultFields(FtrToken token, DEPNode node) {
        if (token.isField(JointFtrXml.F_DEPREL_SET)) {
            return getDeprelSet(node.getDependents());
        }

        return null;
    }

    /** @return a feature vector using the specific feature template. */
    protected StringFeatureVector getFeatureVector(JointFtrXml xml) {
        StringFeatureVector vector = new StringFeatureVector();

        for (FtrTemplate template : xml.getFtrTemplates())
            addFeatures(vector, template);

        return vector;
    }

    /** Called by {@link AbstractStatisticalComponent#getFeatureVector(JointFtrXml)}. */
    private void addFeatures(StringFeatureVector vector, FtrTemplate template) {
        FtrToken[] tokens = template.tokens;
        int i, size = tokens.length;

        if (template.isSetFeature()) {
            String[][] fields = new String[size][];
            String[] tmp;

            for (i = 0; i < size; i++) {
                tmp = getFields(tokens[i]);
                if (tmp == null)
                    return;
                fields[i] = tmp;
            }

            addFeatures(vector, template.type, fields, 0, "");
        } else {
            StringBuilder build = new StringBuilder();
            String field;

            for (i = 0; i < size; i++) {
                field = getField(tokens[i]);
                if (field == null)
                    return;

                if (i > 0)
                    build.append(AbstractColumnReader.BLANK_COLUMN);
                build.append(field);
            }

            vector.addFeature(template.type, build.toString());
        }
    }

    /** Called by {@link AbstractStatisticalComponent#getFeatureVector(JointFtrXml)}. */
    private void addFeatures(StringFeatureVector vector, String type, String[][] fields, int index, String prev) {
        if (index < fields.length) {
            for (String field : fields[index]) {
                if (prev.isEmpty())
                    addFeatures(vector, type, fields, index + 1, field);
                else
                    addFeatures(vector, type, fields, index + 1, prev + AbstractColumnReader.BLANK_COLUMN + field);
            }
        } else
            vector.addFeature(type, prev);
    }

    protected List<Pair<String, StringFeatureVector>> getTrimmedInstances(
            List<Pair<String, StringFeatureVector>> insts) {
        List<Pair<String, StringFeatureVector>> nInsts = new ArrayList<Pair<String, StringFeatureVector>>();
        Set<String> set = new HashSet<String>();
        String key;

        for (Pair<String, StringFeatureVector> p : insts) {
            key = p.o1 + " " + p.o2.toString();

            if (!set.contains(key)) {
                nInsts.add(p);
                set.add(key);
            }
        }

        return nInsts;
    }

    //   ====================================== RULES ======================================

    protected List<Map<String, String[]>> getRules(BufferedReader fin) {
        Pattern space = Pattern.compile(" "), tab = Pattern.compile("\t");
        List<Map<String, String[]>> rules = null;
        String[] tmp, val;
        String line;
        int i, ngram;

        try {
            ngram = Integer.parseInt(fin.readLine());
            rules = new ArrayList<Map<String, String[]>>(ngram);

            for (i = 0; i < ngram; i++)
                rules.add(new HashMap<String, String[]>());

            while ((line = fin.readLine()) != null) {
                tmp = tab.split(line);
                val = space.split(tmp[1]);

                if (val.length <= ngram)
                    rules.get(val.length - 1).put(tmp[0].trim(), val);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return rules;
    }

    protected String[] getRules(List<Map<String, String[]>> list, int currId) {
        StringBuilder build = new StringBuilder();
        int i, j, ngram = list.size();
        String[] tmp, rules = null;

        for (i = currId, j = 0; i < t_size && j < ngram; i++, j++) {
            if (j > 0)
                build.append(" ");
            build.append(d_tree.get(i).form);

            tmp = list.get(j).get(build.toString());
            if (tmp != null)
                rules = tmp;
        }

        return rules;
    }
}