pt.ua.ieeta.nero.crf.CRFModel.java Source code

Java tutorial

Introduction

Here is the source code for pt.ua.ieeta.nero.crf.CRFModel.java

Source

/*
 * Gimli - High-performance and multi-corpus recognition of biomedical
 * entity names
 *
 * Copyright (C) 2011 David Campos, Universidade de Aveiro, Instituto de
 * Engenharia Electrnica e Telemtica de Aveiro
 *
 * Gimli is licensed under the Creative Commons
 * Attribution-NonCommercial-ShareAlike 3.0 Unported License. To view a copy of
 * this license, visit http://creativecommons.org/licenses/by-nc-sa/3.0/.
 *
 * Gimli is a free software, you are free to copy, distribute, change and
 * transmit it. However, you may not use Gimli for commercial purposes.
 *
 * Gimli is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.
 *
 */
package pt.ua.ieeta.nero.crf;

import pt.ua.ieeta.nero.feaure.targets.BIOFeature;
import cc.mallet.fst.*;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.OneLabelKLGEConstraints;
import cc.mallet.pipe.*;
import pt.ua.tm.gimli.config.ModelConfig;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Maths;
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.logging.Level;
import java.util.regex.Pattern;
import org.apache.commons.lang.NotImplementedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import pt.ua.ieeta.nero.FSOptimisation;
import pt.ua.ieeta.nero.external.evaluator.Performance;
import pt.ua.ieeta.nero.feature.metrics.InfoGainUtil;
import pt.ua.ieeta.nero.sa.EvolvingSolution;
import pt.ua.ieeta.nero.feature.pipe.PipeBuilder;
import pt.ua.ieeta.nero.feaure.targets.*;
import pt.ua.tm.gimli.config.Constants;
import pt.ua.tm.gimli.corpus.Corpus;
import pt.ua.tm.gimli.exception.GimliException;
import pt.ua.tm.gimli.config.Constants.Parsing;
import pt.ua.tm.gimli.model.CRFBase;
import pt.ua.tm.gimli.util.FileUtil;

/**
 * The CRF model used by Gimli, providing features to train and test the models.
 *
 * @author David Campos (<a
 * href="mailto:david.campos@ua.pt">david.campos@ua.pt</a>)
 * @version 1.0
 * @since 1.0
 */
public class CRFModel extends CRFBase {

    /**
     * {@link Logger} to be used in the class.
     */
    private static Logger logger = LoggerFactory.getLogger(CRFModel.class);
    private CRF supervisedCRF;
    private static final Pattern forbiddenPattern = Pattern
            .compile(Constants.LabelTag.O + "," + Constants.LabelTag.I);

    /**
     * Constructor.
     *
     * @param config Model configuration.
     * @param parsing Parsing direction.
     */
    public CRFModel(final ModelConfig config, final Parsing parsing) {
        super(config, parsing);
        supervisedCRF = null;
    }

    /**
     * Constructor that loads the model from an input file.
     *
     * @param config Model configuration.
     * @param parsing Parsing direction.
     * @param file File that contains the model.
     * @throws GimliException Problem reading the model from file.
     */
    public CRFModel(final ModelConfig config, final Parsing parsing, final InputStream input)
            throws GimliException {
        super(config, parsing, input);
        supervisedCRF = null;
    }

    public CRF getSupervisedCRF() {
        return supervisedCRF;
    }

    public void setSupervisedCRF(CRF supervisedCRF) {
        this.supervisedCRF = supervisedCRF;
    }

    private HashMap<Integer, double[][]> loadGEConstraints(List<BIOFeature> features, InstanceList data) {

        HashMap<Integer, double[][]> constraints = new HashMap<Integer, double[][]>();

        for (IOptimizationTarget bioFeature : features) {
            BIOFeature f = (BIOFeature) bioFeature;

            // Get feature index
            int featureIndex = data.getDataAlphabet().lookupIndex(f.getName(), false);
            if (featureIndex == -1) {
                logger.error("Feature {} not found in the alphabet!", f.getName());
                continue;
                //throw new RuntimeException("BIOFeature " + f.getName() + " not found in the alphabet!");
            }

            // Initiate probabilities
            double[][] probs = new double[data.getTargetAlphabet().size()][2];
            for (int i = 0; i < probs.length; i++) {
                Arrays.fill(probs[i], Double.NEGATIVE_INFINITY);
            }

            // Get B
            int li = data.getTargetAlphabet().lookupIndex("B", false);
            assert (li != -1) : "B";
            probs[li][0] = probs[li][1] = f.getB();

            //Get I
            li = data.getTargetAlphabet().lookupIndex("I", false);
            assert (li != -1) : "I";
            probs[li][0] = probs[li][1] = f.getI();

            //Get O
            li = data.getTargetAlphabet().lookupIndex("O", false);
            assert (li != -1) : "O";
            probs[li][0] = probs[li][1] = f.getO();

            // Add constraint
            constraints.put(featureIndex, probs);
        }
        return constraints;
    }

    public int train(final InstanceList train, final InstanceList unlabeled, final EvolvingSolution solution,
            final List<BIOFeature> features, final int iterations) {
        // Get evolving solution
        List<IOptimizationTarget> targets = solution.getFeatureList();
        double ge = ((GEWeightOptimizationTarget) targets.get(0)).getGEWeight();
        double gvp = ((GPVOptimizationTarget) targets.get(1)).getGPVWeight();
        int numFeatures = ((NrFeaturesOptimizationTarget) targets.get(2)).getNumFeatures();

        List<BIOFeature> featuresSubSet = features.subList(0, numFeatures);

        // Load constraints

        ArrayList constraintsList = new ArrayList();
        if (unlabeled != null && features != null) {
            HashMap<Integer, double[][]> constraints = loadGEConstraints(featuresSubSet, train);

            // Set OneLabelKL constraints OneLabelKLGEConstraints 
            OneLabelKLGEConstraints geConstraints = new OneLabelKLGEConstraints();
            for (int fi : constraints.keySet()) {
                double[][] dist = constraints.get(fi);

                boolean allSame = true;
                double sum = 0;

                double[] prob = new double[dist.length];
                for (int li = 0; li < dist.length; li++) {
                    prob[li] = dist[li][0];
                    if (!Maths.almostEquals(dist[li][0], dist[li][1])) {
                        allSame = false;
                        break;
                    } else if (Double.isInfinite(prob[li])) {
                        prob[li] = 0;
                    }
                    sum += prob[li];
                }

                if (!allSame) {
                    throw new RuntimeException("A KL divergence penalty cannot be used with target ranges!");
                }
                if (!Maths.almostEquals(sum, 1)) {
                    throw new RuntimeException("Targets must sum to 1 when using a KL divergence penalty!");
                }

                geConstraints.addConstraint(fi, prob, 1);
            }
            constraintsList.add(geConstraints);
        }

        // Set label map
        StateLabelMap map = new StateLabelMap(train.getTargetAlphabet(), true);

        CRF crf = getCRF();

        // Optimazable gradients
        /*
         * Optimizable.ByGradientValue[] opts; if (unlabeled != null && features
         * != null) { opts = new Optimizable.ByGradientValue[]{ new
         * CRFOptimizableByGE(crf, constraintsList, unlabeled, map, 8), new
         * CRFOptimizableByLabelLikelihood(crf, train),}; } else { opts = new
         * Optimizable.ByGradientValue[]{ //new CRFOptimizableByGE(crf,
         * constraintsList, unlabeled, map, numThreads, 1.0), new
         * CRFOptimizableByLabelLikelihood(crf, train) }; }
         */

        //        logger.info("TOTAL SIZE OF ALPHABET AFTER FEATURE SELECTION): {}", train.getDataAlphabet().size());

        // Train

        if (unlabeled != null && features != null) {
            //CRFTrainerByValueGradients crfTrainer = new CRFTrainerByValueGradients(crf, opts);
            //crfTrainer.train(train, iterations);

            MyCRFTrainerByLikelihoodAndGE2 crfTrainer = new MyCRFTrainerByLikelihoodAndGE2(crf, constraintsList,
                    map);
            crfTrainer.setNumThreads(8);
            crfTrainer.setGEWeight(ge);
            crfTrainer.setGaussianPriorVariance(gvp);
            crfTrainer.setInitSupervised(true);
            crfTrainer.train(train, unlabeled, iterations);

            //this.supervisedCRF = crfTrainer.getSupervisedCRF();

            // Set CRF
            setCRF(crf);
            return crfTrainer.getIteration() - 1;

        } else {
            CRFTrainerByThreadedLabelLikelihood crfTrainer = new CRFTrainerByThreadedLabelLikelihood(crf, 8);
            crfTrainer.train(train, iterations);
            crfTrainer.shutdown();

            // Set CRF
            setCRF(crf);
            return crfTrainer.getIteration() - 1;
        }

    }

    /**
     * Train the CRF model.
     *
     * @throws GimliException Problem training the model.
     */
    @Override
    public void train(final Corpus corpus) throws GimliException {
        throw new NotImplementedException("Use the other train method");
    }

    public static InstanceList loadUnlabeledData(InputStream input, Pipe p) throws GimliException {

        InstanceList instances = new InstanceList(p);
        try {
            BufferedReader br = new BufferedReader(new InputStreamReader(input));
            String line;
            String text = "";
            int i = 0;
            while ((line = br.readLine()) != null) {
                if (line.equals("") || line.equals("\n")) {
                    instances.addThruPipe(new Instance(text, null, i++, null));
                    text = "";
                } else {
                    text += line;
                    text += "\n";
                }
            }
            br.close();
        } catch (IOException ex) {
            throw new GimliException("Problem reading unlabeled data.", ex);
        }

        return instances;
    }

    public static void main(String[] args) {
        try {
            String trainFile = "resources/corpus/bc2gm/train/corpus.gz";
            String devFile = "resources/corpus/bc2gm/dev/corpus.gz";
            String geneFile = "resources/corpus/bc2gm/dev/annotations";
            String testFile = "resources/corpus/bc2gm/test/corpus.gz";
            String unlabeledFile = "resources" + File.separator + "corpus" + File.separator + "unlabeled_1k.gz";
            String configFile = "config" + File.separator + "bc_semi.config";

            ModelConfig mc = new ModelConfig(configFile);
            mc.print();

            Corpus train = new Corpus(Constants.LabelFormat.BIO, Constants.EntityType.protein,
                    FileUtil.getFile(new FileInputStream(trainFile)));

            // Get test data
            //PipeBuilder pb = new PipeBuilder(mc);
            //pb.initialise();
            //pb.finalise(false);
            //Pipe p = pb.getPipe();
            Corpus dev = new Corpus(Constants.LabelFormat.BIO, Constants.EntityType.protein,
                    FileUtil.getFile(new FileInputStream(devFile)));
            //Corpus test = new Corpus(Constants.LabelFormat.BIO, Constants.EntityType.protein, FileUtil.getFile(new FileInputStream(testFile)));
            //InstanceList devInstances = dev.toModelFormatTrain(p);
            //InstanceList unlabeled = loadUnlabeledData(FileUtil.getFile(new FileInputStream(unlabeledFile)), p);

            // Get top features
            PipeBuilder pb = new PipeBuilder(mc);
            pb.initialise();
            pb.finalise(true);
            Pipe p = pb.getPipe();
            InstanceList trainInstances = train.toModelFormatTrain(p);
            int totalFeatures = trainInstances.getDataAlphabet().size();
            InfoGainUtil igf = new InfoGainUtil(trainInstances);

            //double[] sizes = {0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};
            //double[] sizes = {0.8, 0.9, 1.0};
            double[] sizes = { 0.6, 0.7, 0.8, 0.9 };
            Performance[] results = new Performance[sizes.length];
            double[] numFeatures = new double[sizes.length];

            for (int j = 0; j < sizes.length; j++) {

                int feat = (int) (sizes[j] * totalFeatures) - 5;
                List<String> features = igf.getFeatures(feat);

                numFeatures[j] = feat;

                /*
                 * if (j == 0) { for (String f : features) { logger.info("{}",
                 * f); } }
                 */

                // Get data
                pb = new PipeBuilder(mc);
                pb.initialise();
                pb.addFeatureSelector(features);
                pb.finalise(false);
                p = pb.getPipe();

                Performance per = new Performance();
                CRF crf = FSOptimisation.getSupervisedCRF(train.toModelFormatTrain(p), mc, dev, geneFile, per);

                results[j] = per;

                /*
                 * InstanceList newInstList = train.toModelFormatTrain(p);
                 *
                 * // Set CRF int order = mc.getOrder() + 1; int[] orders = new
                 * int[order]; for (int i = 0; i < order; i++) { orders[i] = i;
                 * }
                 *
                 * CRF crf = new CRF(newInstList.getPipe(), (Pipe) null); String
                 * startStateName = crf.addOrderNStates( newInstList, orders,
                 * null, // "defaults" parameter; see mallet javadoc "O",
                 * forbiddenPattern, null, true); // true for a fully connected
                 * CRF
                 *
                 * for (int i = 0; i < crf.numStates(); i++) {
                 * crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
                 * } crf.getState(startStateName).setInitialWeight(0.0);
                 * crf.setWeightsDimensionAsIn(newInstList, false); CRFModel
                 * model = new CRFModel(mc, Parsing.FW); model.setCRF(crf);
                 *
                 * //model.train(trainInstances, unlabeled,
                 * Nero.createBaseSolution(), Integer.MAX_VALUE); int
                 * numIterations = model.train(newInstList, null, null,
                 * Integer.MAX_VALUE, null); model.write(new
                 * GZIPOutputStream(new
                 * FileOutputStream("resources/model/bc2gm_o1_fw_" + sizes[j] +
                 * ".gz")));
                 *
                 * iterations[j] = numIterations;
                 *
                 * //InstanceList devInstances = dev.toModelFormatTrain(p);
                 * //results[j] = model.getF1(devInstances);
                 *
                 * Annotator an = new Annotator(dev); an.annotate(model);
                 *
                 * // Pre-process annotated corpus
                 * Parentheses.processRemoving(dev); Abbreviation.process(dev);
                 *
                 * String annotations = "resources/silver/bc2gm_dev_o1_fw_" +
                 * sizes[j]; BCWriter bw = new BCWriter(); bw.write(dev, new
                 * FileOutputStream(annotations));
                 *
                 * BC2Evaluator eval = new BC2Evaluator(geneFile, geneFile,
                 * annotations); results[j] = eval.getPerformance();
                 */
                /*
                 * an = new Annotator(test); an.annotate(model); bw = new
                 * BCWriter(); bw.write(test, new
                 * FileOutputStream("resources/silver/bc2gm_test_o1_fw"));
                 */

                //InstanceList devInstances = dev.toModelFormatTrain(p);
                //results[j] = model.getF1(devInstances);

                logger.info("{} ({}) - {} in {} - {} iterations", new Object[] { sizes[j], numFeatures[j],
                        results[j], results[j].getTime(), results[j].getNumIterations() });
            }

            logger.info("");
            logger.info("");
            for (int j = 0; j < sizes.length; j++) {
                logger.info("{} ({}) - {} in {} - {} iterations", new Object[] { sizes[j], numFeatures[j],
                        results[j], results[j].getTime(), results[j].getNumIterations() });
            }

            System.out.println("DONE!");
            System.exit(0);

        } catch (IOException ex) {
            java.util.logging.Logger.getLogger(CRFModel.class.getName()).log(Level.SEVERE, null, ex);
        } catch (GimliException ex) {
            ex.printStackTrace();
        }

    }

    @Override
    public Pipe getFeaturePipe() throws GimliException {
        throw new UnsupportedOperationException("Not supported yet.");
    }
}