edu.stanford.nlp.ie.crf.CRFClassifier.java Source code


Introduction

Here is the source code for edu.stanford.nlp.ie.crf.CRFClassifier.java

Source

// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2016 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software Foundation,
// Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: java-nlp-user@lists.stanford.edu
//    Licensing: java-nlp-support@lists.stanford.edu

package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.*;
import java.util.stream.Collectors;
import java.util.zip.GZIPOutputStream;

/**
 * Class for sequence classification using a Conditional Random Field model.
 * The code has functionality for different document formats, but when
 * using the standard {@link edu.stanford.nlp.sequences.ColumnDocumentReaderAndWriter} for training
 * or testing models, input files are expected to
 * be one token per line with the columns indicating things like the word,
 * POS, chunk, and answer class.  The default for
 * {@code ColumnDocumentReaderAndWriter} training data is 3-column input,
 * with the columns containing a word, its POS, and its gold class, but
 * this can be specified via the {@code map} property.
 * <p>
 * When run on a file with {@code -textFile} or {@code -textFiles},
 * the file is assumed to be plain English text (or perhaps simple HTML/XML),
 * and a reasonable attempt is made at English tokenization by
 * {@link PlainTextDocumentReaderAndWriter}.  The class used to read
 * the text can be changed with -plainTextDocumentReaderAndWriter.
 * Extra options can be supplied to the tokenizer using the
 * -tokenizerOptions flag.
 * <p>
 * To read from stdin, use the flag -readStdin.  The same
 * reader/writer will be used as for -textFile.
 * <p>
 * <b>Typical command-line usage</b>
 * <p>
 * For running a trained model with a provided serialized classifier on a
 * text file:
 * <p>
 * {@code java -mx500m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier
 * conll.ner.gz -textFile sampleSentences.txt }
 * <p>
 * When specifying all parameters in a properties file (train, test, or
 * runtime):
 * <p>
 * {@code java -mx1g edu.stanford.nlp.ie.crf.CRFClassifier -prop propFile }
 * <p>
 * To train and test a simple NER model from the command line:
 * <p>
 * {@code java -mx1g edu.stanford.nlp.ie.crf.CRFClassifier -trainFile trainFile -testFile testFile -macro > output }
 * <p>
 * To train with multiple files:
 * <p>
 * {@code java -mx1g edu.stanford.nlp.ie.crf.CRFClassifier -trainFileList file1,file2,... -testFile testFile -macro > output }
 * <p>
 * To test on multiple files, use the -testFiles option with a
 * comma-separated list of files.
 * <p>
 * Features are defined by a {@link edu.stanford.nlp.sequences.FeatureFactory}.
 * {@link NERFeatureFactory} is used by default, and you should look
 * there for feature templates and properties or flags that will cause
 * certain features to be used when training an NER classifier. There
 * are also various feature factories for Chinese word segmentation
 * such as {@link edu.stanford.nlp.wordseg.ChineseSegmenterFeatureFactory}.
 * Features are specified either
 * by a Properties file (which is the recommended method) or by flags on the
 * command line. The flags are read into a {@link SeqClassifierFlags} object,
 * which the user need not be concerned with, unless wishing to add new
 * features.
 * <p>
 * CRFClassifier may also be used programmatically. When creating
 * a new instance, you <i>must</i> specify a Properties object. You may then
 * call train methods to train a classifier, or load a classifier. The other way
 * to get a CRFClassifier is to deserialize one via the static
 * {@link CRFClassifier#getClassifier(String)} methods, which return a
 * deserialized classifier. You may then tag (classify the items of) documents
 * using either the assorted {@code classify()} methods here or the additional
 * ones in {@link AbstractSequenceClassifier}.
 * Probabilities assigned by the CRF can be interrogated using either the
 * {@code printProbsDocument()} or {@code getCliqueTrees()} methods.
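 * <p>
 * As a minimal sketch of such programmatic use (the model path here is the
 * default classifier named by {@code DEFAULT_CLASSIFIER} below; any other
 * serialized CRF works the same way, and the sentence is just an example):
 * <pre>{@code
 * CRFClassifier<CoreLabel> crf = CRFClassifier.getClassifierNoExceptions(
 *     "edu/stanford/nlp/models/ner/english.all.3class.distsim.crf.ser.gz");
 * for (List<CoreLabel> sentence : crf.classify("Stanford University is in California.")) {
 *   for (CoreLabel token : sentence) {
 *     System.out.println(token.word() + '\t' + token.get(CoreAnnotations.AnswerAnnotation.class));
 *   }
 * }
 * }</pre>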
 *
 * @author Jenny Finkel
 * @author Sonal Gupta (made the class generic)
 * @author Mengqiu Wang (LOP implementation and non-linear CRF implementation)
 */
public class CRFClassifier<IN extends CoreMap> extends AbstractSequenceClassifier<IN> {

    /** A logger for this class */
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFClassifier.class);

    // TODO(mengqiu) need to move the embedding lookup and capitalization features into a FeatureFactory

    List<Index<CRFLabel>> labelIndices;
    Index<String> tagIndex;
    private Pair<double[][], double[][]> entityMatrices;

    CliquePotentialFunction cliquePotentialFunction;
    HasCliquePotentialFunction cliquePotentialFunctionHelper;

    /** Parameter weights of the classifier.  weights[featureIndex][labelIndex] */
    double[][] weights;

    /** Index of the features of the CRF. */
    Index<String> featureIndex;
    /** Maps each feature index to its clique-size (featureType) index, used to index labelIndices. */
    int[] map;
    Random random = new Random(2147483647L);
    Index<Integer> nodeFeatureIndicesMap;
    Index<Integer> edgeFeatureIndicesMap;

    private Map<String, double[]> embeddings; // = null;

    /**
     * Name of default serialized classifier resource to look for in a jar file.
     */
    public static final String DEFAULT_CLASSIFIER = "edu/stanford/nlp/models/ner/english.all.3class.distsim.crf.ser.gz";
    private static final boolean VERBOSE = false;

    /**
     * Fields for grouping features
     */
    private Pattern suffixPatt = Pattern.compile(".+?((?:-[A-Z]+)+)\\|.*C");
    private Index<String> templateGroupIndex;
    private Map<Integer, Integer> featureIndexToTemplateIndex;

    // Label dictionary for fast decoding
    private LabelDictionary labelDictionary;

    // List selftraindatums = new ArrayList();

    protected CRFClassifier() {
        super(new SeqClassifierFlags());
    }

    public CRFClassifier(Properties props) {
        super(props);
    }

    public CRFClassifier(SeqClassifierFlags flags) {
        super(flags);
    }

    /**
     * Makes a copy of the CRF classifier.
     */
    public CRFClassifier(CRFClassifier<IN> crf) {
        super(crf.flags);
        this.windowSize = crf.windowSize;
        this.featureFactories = crf.featureFactories;
        this.pad = crf.pad;
        if (crf.knownLCWords == null) {
            this.knownLCWords = new MaxSizeConcurrentHashSet<>(crf.flags.maxAdditionalKnownLCWords);
        } else {
            this.knownLCWords = new MaxSizeConcurrentHashSet<>(crf.knownLCWords);
            this.knownLCWords.setMaxSize(this.knownLCWords.size() + crf.flags.maxAdditionalKnownLCWords);
        }
        this.featureIndex = (crf.featureIndex != null) ? new HashIndex<>(crf.featureIndex.objectsList()) : null;
        this.classIndex = (crf.classIndex != null) ? new HashIndex<>(crf.classIndex.objectsList()) : null;
        if (crf.labelIndices != null) {
            this.labelIndices = new ArrayList<>(crf.labelIndices.size());
            for (int i = 0; i < crf.labelIndices.size(); i++) {
                this.labelIndices.add(
                        (crf.labelIndices.get(i) != null) ? new HashIndex<>(crf.labelIndices.get(i).objectsList())
                                : null);
            }
        } else {
            this.labelIndices = null;
        }
        this.cliquePotentialFunction = crf.cliquePotentialFunction;
    }

    /**
     * Returns the total number of weights associated with this classifier.
     *
     * @return number of weights
     */
    public int getNumWeights() {
        if (weights == null)
            return 0;
        int numWeights = 0;
        for (double[] wts : weights) {
            numWeights += wts.length;
        }
        return numWeights;
    }

    /**
     * Get index of featureType for feature indexed by i. (featureType index is
     * used to index labelIndices to get labels.)
     *
     * @param i Feature index
     * @return index of featureType
     */
    private int getFeatureTypeIndex(int i) {
        return getFeatureTypeIndex(featureIndex.get(i));
    }

    /**
     * Get index of featureType for feature based on the feature string
     * (featureType index used to index labelIndices to get labels)
     *
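     * For example, a node feature string ending in {@code "|C"} maps to 0, while an
     * edge feature ending in {@code "|CpC"} maps to 1, matching the suffix checks below.
     *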
     * @param feature Feature string
     * @return index of featureType
     */
    private static int getFeatureTypeIndex(String feature) {
        if (feature.endsWith("|C")) {
            return 0;
        } else if (feature.endsWith("|CpC")) {
            return 1;
        } else if (feature.endsWith("|Cp2C")) {
            return 2;
        } else if (feature.endsWith("|Cp3C")) {
            return 3;
        } else if (feature.endsWith("|Cp4C")) {
            return 4;
        } else if (feature.endsWith("|Cp5C")) {
            return 5;
        } else {
            throw new RuntimeException("Unknown feature type " + feature);
        }
    }

    /**
     * Scales the weights of this CRFClassifier by the specified weight.
     *
     * @param scale The scale to multiply by
     */
    public void scaleWeights(double scale) {
        for (int i = 0; i < weights.length; i++) {
            for (int j = 0; j < weights[i].length; j++) {
                weights[i][j] *= scale;
            }
        }
    }

    /**
     * Combines weights from another crf (scaled by weight) into this CRF's
     * weights (assumes that this CRF's indices have already been updated to
     * include features/labels from the other crf)
     *
     * @param crf Other CRF whose weights to combine into this CRF
     * @param weight Amount to scale the other CRF's weights by
     */
    private void combineWeights(CRFClassifier<IN> crf, double weight) {
        int numFeatures = featureIndex.size();
        int oldNumFeatures = weights.length;

        // Create a map of other crf labels to this crf labels
        Map<CRFLabel, CRFLabel> crfLabelMap = Generics.newHashMap();
        for (int i = 0; i < crf.labelIndices.size(); i++) {
            for (int j = 0; j < crf.labelIndices.get(i).size(); j++) {
                CRFLabel labels = crf.labelIndices.get(i).get(j);
                int[] newLabelIndices = new int[i + 1];
                for (int ci = 0; ci <= i; ci++) {
                    String classLabel = crf.classIndex.get(labels.getLabel()[ci]);
                    newLabelIndices[ci] = this.classIndex.indexOf(classLabel);
                }
                CRFLabel newLabels = new CRFLabel(newLabelIndices);
                crfLabelMap.put(labels, newLabels);
                int k = this.labelIndices.get(i).indexOf(newLabels); // IMPORTANT: the indexing is needed, even when not printed out!
                // log.info("LabelIndices " + i + " " + labels + ": " + j +
                // " mapped to " + k);
            }
        }

        // Create map of featureIndex to featureTypeIndex
        map = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            map[i] = getFeatureTypeIndex(i);
        }

        // Create new weights
        double[][] newWeights = new double[numFeatures][];
        for (int i = 0; i < numFeatures; i++) {
            int length = labelIndices.get(map[i]).size();
            newWeights[i] = new double[length];
            if (i < oldNumFeatures) {
                assert (length >= weights[i].length);
                System.arraycopy(weights[i], 0, newWeights[i], 0, weights[i].length);
            }
        }
        weights = newWeights;

        // Get the original weight indices from the other crf and fold them in,
        // scaled by weight. Depending on the feature type, a different number of
        // weights is associated with each feature.
        for (int i = 0; i < crf.weights.length; i++) {
            String feature = crf.featureIndex.get(i);
            int newIndex = featureIndex.indexOf(feature);
            // Check weights are okay dimension
            if (weights[newIndex].length < crf.weights[i].length) {
                throw new RuntimeException("Incompatible CRFClassifier: weight length mismatch for feature "
                        + newIndex + ": " + featureIndex.get(newIndex) + " (also feature " + i + ": "
                        + crf.featureIndex.get(i) + ") " + ", len1=" + weights[newIndex].length + ", len2="
                        + crf.weights[i].length);
            }
            int featureTypeIndex = map[newIndex];
            for (int j = 0; j < crf.weights[i].length; j++) {
                CRFLabel labels = crf.labelIndices.get(featureTypeIndex).get(j);
                CRFLabel newLabels = crfLabelMap.get(labels);
                int k = this.labelIndices.get(featureTypeIndex).indexOf(newLabels);
                weights[newIndex][k] += crf.weights[i][j] * weight;
            }
        }
    }

    /**
     * Combines weighted crf with this crf.
     *
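     * <p>
     * A minimal sketch of averaging two already-trained classifiers (the model
     * paths and the 0.5 weights are placeholders, and exception handling is
     * omitted):
     * <pre>{@code
     * CRFClassifier<CoreLabel> a = CRFClassifier.getClassifier("modelA.ser.gz");
     * CRFClassifier<CoreLabel> b = CRFClassifier.getClassifier("modelB.ser.gz");
     * a.scaleWeights(0.5); // halve this classifier's own weights
     * a.combine(b, 0.5);   // then fold in the other classifier at half weight
     * }</pre>
     *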
     * @param crf Other CRF whose weights to combine into this CRF
     * @param weight Amount to scale the other CRF's weights by
     */
    public void combine(CRFClassifier<IN> crf, double weight) {
        Timing timer = new Timing();

        // Check the CRFClassifiers are compatible
        if (!this.pad.equals(crf.pad)) {
            throw new RuntimeException("Incompatible CRFClassifier: pad does not match");
        }
        if (this.windowSize != crf.windowSize) {
            throw new RuntimeException("Incompatible CRFClassifier: windowSize does not match");
        }
        if (this.labelIndices.size() != crf.labelIndices.size()) {
            // Should match since this should be the same as the windowSize
            throw new RuntimeException("Incompatible CRFClassifier: labelIndices length does not match");
        }
        this.classIndex.addAll(crf.classIndex.objectsList());

        // Combine weights of the other classifier with this classifier,
        // weighing the other classifier's weights by weight
        // First merge the feature indices
        int oldNumFeatures1 = this.featureIndex.size();
        int oldNumFeatures2 = crf.featureIndex.size();
        int oldNumWeights1 = this.getNumWeights();
        int oldNumWeights2 = crf.getNumWeights();
        this.featureIndex.addAll(crf.featureIndex.objectsList());
        this.knownLCWords.addAll(crf.knownLCWords);
        assert (weights.length == oldNumFeatures1);

        // Combine weights of this classifier with other classifier
        for (int i = 0; i < labelIndices.size(); i++) {
            this.labelIndices.get(i).addAll(crf.labelIndices.get(i).objectsList());
        }
        log.info("Combining weights: will automatically match labelIndices");
        combineWeights(crf, weight);

        int numFeatures = featureIndex.size();
        int numWeights = getNumWeights();
        long elapsedMs = timer.stop();
        log.info("numFeatures: orig1=" + oldNumFeatures1 + ", orig2=" + oldNumFeatures2 + ", combined="
                + numFeatures);
        log.info("numWeights: orig1=" + oldNumWeights1 + ", orig2=" + oldNumWeights2 + ", combined=" + numWeights);
        log.info("Time to combine CRFClassifier: " + Timing.toSecondsString(elapsedMs) + " seconds");
    }

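    /**
     * Rebuilds the feature index, keeping only those features whose weight range
     * across labels (largest minus smallest weight) exceeds the given threshold.
     * Note that this updates {@code featureIndex} and {@code map}, but leaves the
     * {@code weights} array itself untouched.
     */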
    public void dropFeaturesBelowThreshold(double threshold) {
        Index<String> newFeatureIndex = new HashIndex<>();
        for (int i = 0; i < weights.length; i++) {
            double smallest = weights[i][0];
            double biggest = weights[i][0];
            for (int j = 1; j < weights[i].length; j++) {
                if (weights[i][j] > biggest) {
                    biggest = weights[i][j];
                }
                if (weights[i][j] < smallest) {
                    smallest = weights[i][j];
                }
                if (biggest - smallest > threshold) {
                    newFeatureIndex.add(featureIndex.get(i));
                    break;
                }
            }
        }

        int[] newMap = new int[newFeatureIndex.size()];
        for (int i = 0; i < newMap.length; i++) {
            int index = featureIndex.indexOf(newFeatureIndex.get(i));
            newMap[i] = map[index];
        }
        map = newMap;
        featureIndex = newFeatureIndex;
    }

    /**
     * Convert a document List into arrays storing the data features and labels.
     * This is used at test time.
     *
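     * <p>
     * The returned Triple is what the inference methods in this class consume; for
     * example, {@code getCliqueTree(documentToDataAndLabels(doc))} (as done in
     * {@code printProbsDocument} below) yields the clique tree used to read off
     * per-token label probabilities.
     *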
     * @param document Testing documents
     * @return A Triple, where the first element is an int[][][] representing the
     *         data, the second element is an int[] representing the labels, and
     *         the third element is a double[][][] representing the feature values (optionally null)
     */
    public Triple<int[][][], int[], double[][][]> documentToDataAndLabels(List<IN> document) {
        int docSize = document.size();
        // first index is the position in the document, which is also the index of the
        // clique/factor table
        // second index is the number of elements in the clique/window these
        // features are for (starting with the last element)
        // third index is the position of the feature in the array that holds them.
        // An element data[j][k][m] is the feature index of the mth feature occurring in
        // position k of the jth clique
        int[][][] data = new int[docSize][windowSize][];
        double[][][] featureVals = new double[docSize][windowSize][];
        // index is the position in the document.
        // element in labels[j] is the index of the correct label (if it exists) at
        // position j of document
        int[] labels = new int[docSize];

        if (flags.useReverse) {
            Collections.reverse(document);
        }

        // log.info("docSize:"+docSize);
        for (int j = 0; j < docSize; j++) {
            int[][] data_j = data[j];
            double[][] featureVals_j = featureVals[j];
            CRFDatum<Collection<String>, CRFLabel> d = makeDatum(document, j, featureFactories);

            List<Collection<String>> features = d.asFeatures();
            List<double[]> featureValList = d.asFeatureVals();
            for (int k = 0, fSize = features.size(); k < fSize; k++) {
                Collection<String> cliqueFeatures = features.get(k);
                int[] data_jk = data_j[k] = new int[cliqueFeatures.size()];
                if (featureValList != null && k < featureValList.size()) { // CRFBiasedClassifier.makeDatum causes null
                    featureVals_j[k] = featureValList.get(k);
                }
                int m = 0;
                for (String feature : cliqueFeatures) {
                    int index = featureIndex.indexOf(feature);
                    if (index >= 0) {
                        data_jk[m] = index;
                        m++;
                    } else {
                        // this is where we end up when we do feature threshold cutoffs
                    }
                }

                if (m < data_j[k].length) {
                    data_j[k] = Arrays.copyOf(data_j[k], m);
                    if (featureVals_j[k] != null) {
                        featureVals_j[k] = Arrays.copyOf(featureVals_j[k], m);
                    }
                }
            }

            IN wi = document.get(j);
            labels[j] = classIndex.indexOf(wi.get(CoreAnnotations.AnswerAnnotation.class));
        }

        if (flags.useReverse) {
            Collections.reverse(document);
        }

        return new Triple<>(data, labels, featureVals);
    }

    public void printLabelInformation(String testFile, DocumentReaderAndWriter<IN> readerAndWriter)
            throws Exception {
        ObjectBank<List<IN>> documents = makeObjectBankFromFile(testFile, readerAndWriter);
        for (List<IN> document : documents) {
            printLabelValue(document);
        }
    }

    public void printLabelValue(List<IN> document) {
        if (flags.useReverse) {
            Collections.reverse(document);
        }

        NumberFormat nf = new DecimalFormat();

        List<String> classes = new ArrayList<>();
        for (int i = 0; i < classIndex.size(); i++) {
            classes.add(classIndex.get(i));
        }
        String[] columnHeaders = classes.toArray(new String[classes.size()]);

        // log.info("docSize:"+docSize);
        for (int j = 0; j < document.size(); j++) {

            System.out.println("--== " + document.get(j).get(CoreAnnotations.TextAnnotation.class) + " ==--");

            List<String[]> lines = new ArrayList<>();
            List<String> rowHeaders = new ArrayList<>();
            List<String> line = new ArrayList<>();

            for (int p = 0; p < labelIndices.size(); p++) {
                if (j + p >= document.size()) {
                    continue;
                }
                CRFDatum<Collection<String>, CRFLabel> d = makeDatum(document, j + p, featureFactories);

                List<Collection<String>> features = d.asFeatures();
                for (int k = p, fSize = features.size(); k < fSize; k++) {
                    Collection<String> cliqueFeatures = features.get(k);
                    for (String feature : cliqueFeatures) {
                        int index = featureIndex.indexOf(feature);
                        if (index >= 0) {
                            // line.add(feature+"["+(-p)+"]");
                            rowHeaders.add(feature + '[' + (-p) + ']');
                            double[] values = new double[labelIndices.get(0).size()];
                            for (CRFLabel label : labelIndices.get(k)) {
                                int[] l = label.getLabel();
                                double v = weights[index][labelIndices.get(k).indexOf(label)];
                                values[l[l.length - 1 - p]] += v;
                            }
                            for (double value : values) {
                                line.add(nf.format(value));
                            }
                            lines.add(line.toArray(new String[line.size()]));
                            line = new ArrayList<>();
                        }
                    }
                }
                // lines.add(Collections.<String>emptyList());
                System.out.println(StringUtils.makeTextTable(lines.toArray(new String[lines.size()][0]),
                        rowHeaders.toArray(new String[rowHeaders.size()]), columnHeaders, 0, 1, true));
                System.out.println();
            }
            // log.info(edu.stanford.nlp.util.StringUtils.join(lines,"\n"));
        }

        if (flags.useReverse) {
            Collections.reverse(document);
        }
    }

    /**
     * Convert an ObjectBank to arrays of data features and labels.
     * This version is used at training time.
     *
     * @return A Triple, where the first element is an int[][][][] representing the
     *         data, the second element is an int[][] representing the labels, and
     *         the third element is a double[][][][] representing the feature values,
     *         which may be null.
     */
    public Triple<int[][][][], int[][], double[][][][]> documentsToDataAndLabels(Collection<List<IN>> documents) {

        // first index is the number of the document
        // second index is the position in the document, which is also the index of the
        // clique/factor table
        // third index is the number of elements in the clique/window these features
        // are for (starting with the last element)
        // fourth index is the position of the feature in the array that holds them
        // an element data[i][j][k][m] is the index of the mth feature occurring in
        // position k of the jth clique of the ith document
        // int[][][][] data = new int[documentsSize][][][];
        int numDocs = documents.size();
        List<int[][][]> data = new ArrayList<>(numDocs);
        List<double[][][]> featureVal = flags.useEmbedding ? new ArrayList<>(numDocs) : null;

        // first index is the number of the document
        // second index is the position in the document
        // element in labels[i][j] is the index of the correct label (if it exists)
        // at position j in document i
        // int[][] labels = new int[documentsSize][];
        List<int[]> labels = new ArrayList<>(numDocs);

        int numDatums = 0;

        for (List<IN> doc : documents) {
            Triple<int[][][], int[], double[][][]> docTriple = documentToDataAndLabels(doc);
            data.add(docTriple.first());
            labels.add(docTriple.second());
            if (flags.useEmbedding)
                featureVal.add(docTriple.third());
            numDatums += doc.size();
        }

        if (labels.size() != numDocs || data.size() != numDocs) {
            throw new AssertionError("Inexplicable miscalculation in the size of some arrays");
        }

        log.info("numClasses: " + classIndex.size() + ' ' + classIndex);
        log.info("numDocuments: " + data.size());
        log.info("numDatums: " + numDatums);
        log.info("numFeatures: " + featureIndex.size());
        printFeatures();

        double[][][][] featureValArr = null;
        if (flags.useEmbedding)
            featureValArr = featureVal.toArray(new double[data.size()][][][]);

        return new Triple<>(data.toArray(new int[data.size()][][][]), labels.toArray(new int[labels.size()][]),
                featureValArr);
    }

    /**
     * Convert an ObjectBank to corresponding collection of data features and
     * labels. This version is used at test time.
     *
     * @return A List of Triples, one for each document, where the first element is
     *         an int[][][] representing the data, the second element is an
     *         int[] representing the labels, and the third element is a
     *         double[][][] representing the feature values (optionally null).
     */
    public List<Triple<int[][][], int[], double[][][]>> documentsToDataAndLabelsList(
            Collection<List<IN>> documents) {
        int numDatums = 0;

        List<Triple<int[][][], int[], double[][][]>> docList = new ArrayList<>(documents.size());
        for (List<IN> doc : documents) {
            Triple<int[][][], int[], double[][][]> docTriple = documentToDataAndLabels(doc);
            docList.add(docTriple);
            numDatums += doc.size();
        }

        log.info("numClasses: " + classIndex.size() + ' ' + classIndex);
        log.info("numDocuments: " + docList.size());
        log.info("numDatums: " + numDatums);
        log.info("numFeatures: " + featureIndex.size());
        return docList;
    }

    protected void printFeatures() {
        if (flags.printFeatures == null) {
            return;
        }
        try {
            String enc = flags.inputEncoding;
            if (flags.inputEncoding == null) {
                log.info("flags.inputEncoding doesn't exist, using UTF-8 as default");
                enc = "UTF-8";
            }

            PrintWriter pw = new PrintWriter(
                    new OutputStreamWriter(new FileOutputStream("features-" + flags.printFeatures + ".txt"), enc),
                    true);
            for (String feat : featureIndex) {
                pw.println(feat);
            }
            pw.close();
        } catch (IOException ioe) {
            ioe.printStackTrace();
        }
    }

    /**
     * This routine builds the {@code labelIndices} which give the
     * empirically legal label sequences (of length (order) at most
     * {@code windowSize}) and the {@code classIndex}, which indexes
     * known answer classes.
     *
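     * <p>
     * For example, with {@code windowSize} 2 and observed classes {O, PER, LOC},
     * {@code labelIndices.get(0)} ends up holding the single labels used for node
     * cliques, while {@code labelIndices.get(1)} holds the (previous label,
     * current label) pairs used for edge cliques.
     *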
     * @param ob The training data: Read from an ObjectBank, each item in it is a
     *          {@code List<CoreLabel>}.
     */
    protected void makeAnswerArraysAndTagIndex(Collection<List<IN>> ob) {
        // TODO: slow?
        boolean useFeatureCountThresh = flags.featureCountThresh > 1;

        Set<String>[] featureIndices = new HashSet[windowSize];
        Map<String, Integer>[] featureCountIndices = null;
        for (int i = 0; i < windowSize; i++) {
            featureIndices[i] = Generics.newHashSet();
        }
        if (useFeatureCountThresh) {
            featureCountIndices = new HashMap[windowSize];
            for (int i = 0; i < windowSize; i++) {
                featureCountIndices[i] = Generics.newHashMap();
            }
        }

        labelIndices = new ArrayList<>(windowSize);
        for (int i = 0; i < windowSize; i++) {
            labelIndices.add(new HashIndex<>());
        }

        Index<CRFLabel> labelIndex = labelIndices.get(windowSize - 1);

        if (classIndex == null)
            classIndex = new HashIndex<>();
        // classIndex.add("O");
        classIndex.add(flags.backgroundSymbol);

        Set<String>[] seenBackgroundFeatures = new HashSet[2];
        seenBackgroundFeatures[0] = Generics.newHashSet();
        seenBackgroundFeatures[1] = Generics.newHashSet();

        int wordCount = 0;

        if (flags.labelDictionaryCutoff > 0) {
            this.labelDictionary = new LabelDictionary();
        }

        for (List<IN> doc : ob) {
            if (flags.useReverse) {
                Collections.reverse(doc);
            }

            // create the full set of labels in classIndex
            // note: update to use addAll later
            for (IN token : doc) {
                wordCount++;
                String ans = token.get(CoreAnnotations.AnswerAnnotation.class);
                if (ans == null || ans.isEmpty()) {
                    throw new IllegalArgumentException("Word " + wordCount + " (\""
                            + token.get(CoreAnnotations.TextAnnotation.class) + "\") has a blank answer");
                }
                classIndex.add(ans);
                if (labelDictionary != null) {
                    String observation = token.get(CoreAnnotations.TextAnnotation.class);
                    labelDictionary.increment(observation, ans);
                }
            }

            for (int j = 0, docSize = doc.size(); j < docSize; j++) {
                CRFDatum<Collection<String>, CRFLabel> d = makeDatum(doc, j, featureFactories);
                labelIndex.add(d.label());

                List<Collection<String>> features = d.asFeatures();
                for (int k = 0, fSize = features.size(); k < fSize; k++) {
                    Collection<String> cliqueFeatures = features.get(k);
                    if (k < 2 && flags.removeBackgroundSingletonFeatures) {
                        String ans = doc.get(j).get(CoreAnnotations.AnswerAnnotation.class);
                        boolean background = ans.equals(flags.backgroundSymbol);
                        if (k == 1 && j > 0 && background) {
                            ans = doc.get(j - 1).get(CoreAnnotations.AnswerAnnotation.class);
                            background = ans.equals(flags.backgroundSymbol);
                        }
                        if (background) {
                            for (String f : cliqueFeatures) {
                                if (useFeatureCountThresh) {
                                    if (!featureCountIndices[k].containsKey(f)) {
                                        if (seenBackgroundFeatures[k].contains(f)) {
                                            seenBackgroundFeatures[k].remove(f);
                                            featureCountIndices[k].put(f, 1);
                                        } else {
                                            seenBackgroundFeatures[k].add(f);
                                        }
                                    }
                                } else {
                                    if (!featureIndices[k].contains(f)) {
                                        if (seenBackgroundFeatures[k].contains(f)) {
                                            seenBackgroundFeatures[k].remove(f);
                                            featureIndices[k].add(f);
                                        } else {
                                            seenBackgroundFeatures[k].add(f);
                                        }
                                    }
                                }
                            }
                        } else {
                            seenBackgroundFeatures[k].removeAll(cliqueFeatures);
                            if (useFeatureCountThresh) {
                                Map<String, Integer> fCountIndex = featureCountIndices[k];
                                for (String f : cliqueFeatures) {
                                    if (fCountIndex.containsKey(f))
                                        fCountIndex.put(f, fCountIndex.get(f) + 1);
                                    else
                                        fCountIndex.put(f, 1);
                                }
                            } else {
                                featureIndices[k].addAll(cliqueFeatures);
                            }
                        }
                    } else {
                        if (useFeatureCountThresh) {
                            Map<String, Integer> fCountIndex = featureCountIndices[k];
                            for (String f : cliqueFeatures) {
                                if (fCountIndex.containsKey(f))
                                    fCountIndex.put(f, fCountIndex.get(f) + 1);
                                else
                                    fCountIndex.put(f, 1);
                            }
                        } else {
                            featureIndices[k].addAll(cliqueFeatures);
                        }
                    }
                }
            }

            if (flags.useReverse) {
                Collections.reverse(doc);
            }
        }

        if (useFeatureCountThresh) {
            int numFeatures = 0;
            for (int i = 0; i < windowSize; i++) {
                numFeatures += featureCountIndices[i].size();
            }
            log.info("Before feature count thresholding, numFeatures = " + numFeatures);
            for (int i = 0; i < windowSize; i++) {
                for (Iterator<Map.Entry<String, Integer>> it = featureCountIndices[i].entrySet().iterator(); it
                        .hasNext();) {
                    Map.Entry<String, Integer> entry = it.next();
                    if (entry.getValue() < flags.featureCountThresh) {
                        it.remove();
                    }
                }
                featureIndices[i].addAll(featureCountIndices[i].keySet());
                featureCountIndices[i] = null;
            }
        }

        int numFeatures = 0;
        for (int i = 0; i < windowSize; i++) {
            numFeatures += featureIndices[i].size();
        }
        log.info("numFeatures = " + numFeatures);

        featureIndex = new HashIndex<>();
        map = new int[numFeatures];

        if (flags.groupByFeatureTemplate) {
            templateGroupIndex = new HashIndex<>();
            featureIndexToTemplateIndex = new HashMap<>();
        }

        for (int i = 0; i < windowSize; i++) {
            Index<Integer> featureIndexMap = new HashIndex<>();

            featureIndex.addAll(featureIndices[i]);
            for (String str : featureIndices[i]) {
                int index = featureIndex.indexOf(str);
                map[index] = i;
                featureIndexMap.add(index);

                // grouping features by template
                if (flags.groupByFeatureTemplate) {
                    Matcher m = suffixPatt.matcher(str);
                    String groupSuffix = (m.matches() ? m.group(1) : "NoTemplate") + "-c:" + i;
                    int groupIndex = templateGroupIndex.addToIndex(groupSuffix);
                    featureIndexToTemplateIndex.put(index, groupIndex);
                }
            }
            // todo [cdm 2014]: Talk to Mengqiu about this; it seems like it only supports first order CRF
            if (i == 0) {
                nodeFeatureIndicesMap = featureIndexMap;
                // log.info("setting nodeFeatureIndicesMap, size="+nodeFeatureIndicesMap.size());
            } else {
                edgeFeatureIndicesMap = featureIndexMap;
                // log.info("setting edgeFeatureIndicesMap, size="+edgeFeatureIndicesMap.size());
            }
        }

        if (flags.numOfFeatureSlices > 0) {
            log.info("Taking " + flags.numOfFeatureSlices + " out of " + flags.totalFeatureSlice
                    + " slices of node features for training");
            pruneNodeFeatureIndices(flags.totalFeatureSlice, flags.numOfFeatureSlices);
        }

        if (flags.useObservedSequencesOnly) {
            for (int i = 0, liSize = labelIndex.size(); i < liSize; i++) {
                CRFLabel label = labelIndex.get(i);
                for (int j = windowSize - 2; j >= 0; j--) {
                    label = label.getOneSmallerLabel();
                    labelIndices.get(j).add(label);
                }
            }
        } else {
            for (int i = 0; i < labelIndices.size(); i++) {
                labelIndices.set(i, allLabels(i + 1, classIndex));
            }
        }

        if (VERBOSE) {
            for (int i = 0, fiSize = featureIndex.size(); i < fiSize; i++) {
                System.out.println(i + ": " + featureIndex.get(i));
            }
        }
        if (labelDictionary != null) {
            labelDictionary.lock(flags.labelDictionaryCutoff, classIndex);
        }
    }

    protected static Index<CRFLabel> allLabels(int window, Index<String> classIndex) {
        int[] label = new int[window];
        int numClasses = classIndex.size();
        Index<CRFLabel> labelIndex = new HashIndex<>();
        OUTER: while (true) {
            CRFLabel l = new CRFLabel(label);
            labelIndex.add(l);
            label = Arrays.copyOf(label, window);
            // Advance the label array like a mixed-radix (base numClasses) counter;
            // once every position has wrapped around, all label sequences have been enumerated.
            for (int j = 0; j < label.length; j++) {
                if (++label[j] < numClasses)
                    break;
                label[j] = 0;
                if (j == label.length - 1)
                    break OUTER;
            }
        }
        return labelIndex;
    }

    /**
     * Makes a CRFDatum by producing features and a label from input data at a
     * specific position, using the provided factory.
     *
     * @param info The input data. Particular feature factories might look for arbitrary keys in the IN items.
     * @param loc The position to build a datum at
     * @param featureFactories The FeatureFactories to use to extract features
     * @return The constructed CRFDatum
     */
    public CRFDatum<Collection<String>, CRFLabel> makeDatum(List<IN> info, int loc,
            List<FeatureFactory<IN>> featureFactories) {
        // pad.set(CoreAnnotations.AnswerAnnotation.class, flags.backgroundSymbol); // cdm: isn't this unnecessary, as this is how it's initialized in AbstractSequenceClassifier.reinit?
        PaddedList<IN> pInfo = new PaddedList<>(info, pad);

        ArrayList<Collection<String>> features = new ArrayList<>(windowSize);
        List<double[]> featureVals = flags.useEmbedding ? new ArrayList<>(1) : null;

        for (int i = 0; i < windowSize; i++) {
            List<String> featuresC = new ArrayList<>();
            if (flags.useEmbedding && i == 0) { // only activated for node features
                featureVals.add(makeDatumUsingEmbedding(info, loc, featureFactories, pInfo, featuresC));
            } else {
                FeatureFactory.eachClique(i, 0, c -> {
                    for (FeatureFactory<IN> featureFactory : featureFactories) {
                        featuresC.addAll(featureFactory.getCliqueFeatures(pInfo, loc, c)); //todo useless copy because of typing reasons
                    }
                });
            }
            features.add(featuresC);
        }

        int[] labels = new int[windowSize];

        for (int i = 0; i < windowSize; i++) {
            String answer = pInfo.get(loc + i - windowSize + 1).get(CoreAnnotations.AnswerAnnotation.class);
            labels[i] = classIndex.indexOf(answer);
        }

        printFeatureLists(pInfo.get(loc), features);

        CRFDatum<Collection<String>, CRFLabel> d = new CRFDatum<>(features, new CRFLabel(labels), featureVals);
        // log.info(d);
        return d;
    }

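    /**
     * Builds the node-feature values from word embeddings over a five-token window
     * around {@code loc} (positions loc-2 .. loc+2, using PADDING and UNKNOWN
     * lookups for out-of-range or unseen words), optionally appending
     * capitalization and bias features according to the corresponding flags.
     */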
    private double[] makeDatumUsingEmbedding(List<IN> info, int loc, List<FeatureFactory<IN>> featureFactories,
            PaddedList<IN> pInfo, Collection<String> featuresC) {
        double[] featureValArr;
        List<double[]> embeddingList = new ArrayList<>();
        int concatEmbeddingLen = 0;
        String currentWord = null;
        for (int currLoc = loc - 2; currLoc <= loc + 2; currLoc++) {
            double[] embedding; // Initialized in cases below // = null;
            if (currLoc >= 0 && currLoc < info.size()) {
                currentWord = info.get(loc).get(CoreAnnotations.TextAnnotation.class);
                String word = currentWord.toLowerCase();
                word = word.replaceAll("(-)?\\d+(\\.\\d*)?", "0");
                embedding = embeddings.get(word);
                if (embedding == null)
                    embedding = embeddings.get("UNKNOWN");
            } else {
                embedding = embeddings.get("PADDING");
            }

            for (int e = 0; e < embedding.length; e++) {
                featuresC.add("EMBEDDING-(" + (currLoc - loc) + ")-" + e);
            }

            if (flags.addCapitalFeatures) {
                int numOfCapitalFeatures = 4;
                int currLen = embedding.length;
                embedding = Arrays.copyOf(embedding, currLen + numOfCapitalFeatures);
                for (int e = 0; e < numOfCapitalFeatures; e++)
                    featuresC.add("CAPITAL-(" + (currLoc - loc) + ")-" + e);

                if (currLoc >= 0 && currLoc < info.size()) { // skip PADDING
                    // check if word is all caps
                    if (currentWord.toUpperCase().equals(currentWord))
                        embedding[currLen] = 1;
                    else {
                        currLen += 1;
                        // check if word is all lower
                        if (currentWord.toLowerCase().equals(currentWord))
                            embedding[currLen] = 1;
                        else {
                            currLen += 1;
                            // check first letter cap
                            if (Character.isUpperCase(currentWord.charAt(0)))
                                embedding[currLen] = 1;
                            else {
                                currLen += 1;
                                // check if at least one non-initial letter is cap
                                String remainder = currentWord.substring(1);
                                if (!remainder.toLowerCase().equals(remainder))
                                    embedding[currLen] = 1;
                            }
                        }
                    }
                }
            }

            embeddingList.add(embedding);
            concatEmbeddingLen += embedding.length;
        }
        double[] concatEmbedding = new double[concatEmbeddingLen];
        int currPos = 0;
        for (double[] em : embeddingList) {
            System.arraycopy(em, 0, concatEmbedding, currPos, em.length);
            currPos += em.length;
        }

        if (flags.prependEmbedding) {
            FeatureFactory.eachClique(0, 0, c -> {
                for (FeatureFactory<IN> featureFactory : featureFactories) {
                    featuresC.addAll(featureFactory.getCliqueFeatures(pInfo, loc, c)); //todo useless copy because of typing reasons
                }
            });
            featureValArr = Arrays.copyOf(concatEmbedding, featuresC.size());
            Arrays.fill(featureValArr, concatEmbedding.length, featureValArr.length, 1.0);
        } else {
            featureValArr = concatEmbedding;
        }

        if (flags.addBiasToEmbedding) {
            featuresC.add("BIAS-FEATURE");
            featureValArr = Arrays.copyOf(featureValArr, featureValArr.length + 1);
            featureValArr[featureValArr.length - 1] = 1;
        }
        return featureValArr;
    }

    @Override
    public void dumpFeatures(Collection<List<IN>> docs) {
        if (flags.exportFeatures != null) {
            Timing timer = new Timing();
            CRFFeatureExporter<IN> featureExporter = new CRFFeatureExporter<>(this);
            featureExporter.printFeatures(flags.exportFeatures, docs);
            long elapsedMs = timer.stop();
            log.info("Time to export features: " + Timing.toSecondsString(elapsedMs) + " seconds");
        }
    }

    @Override
    public List<IN> classify(List<IN> document) {
        if (flags.doGibbs) {
            try {
                return classifyGibbs(document);
            } catch (Exception e) {
                throw new RuntimeException("Error running testGibbs inference!", e);
            }
        } else if (flags.crfType.equalsIgnoreCase("maxent")) {
            return classifyMaxEnt(document);
        } else {
            throw new RuntimeException("Unsupported inference type: " + flags.crfType);
        }
    }

    private List<IN> classify(List<IN> document, Triple<int[][][], int[], double[][][]> documentDataAndLabels) {
        if (flags.doGibbs) {
            try {
                return classifyGibbs(document, documentDataAndLabels);
            } catch (Exception e) {
                throw new RuntimeException("Error running testGibbs inference!", e);
            }
        } else if (flags.crfType.equalsIgnoreCase("maxent")) {
            return classifyMaxEnt(document, documentDataAndLabels);
        } else {
            throw new RuntimeException("Unsupported inference type: " + flags.crfType);
        }
    }

    /**
     * This method is supposed to be used by CRFClassifierEvaluator only and should not have global visibility.
     * The generic {@code classifyAndWriteAnswers} omits the second argument {@code documentDataAndLabels}.
     */
    void classifyAndWriteAnswers(Collection<List<IN>> documents,
            List<Triple<int[][][], int[], double[][][]>> documentDataAndLabels, PrintWriter printWriter,
            DocumentReaderAndWriter<IN> readerAndWriter) throws IOException {
        Timing timer = new Timing();

        Counter<String> entityTP = new ClassicCounter<>();
        Counter<String> entityFP = new ClassicCounter<>();
        Counter<String> entityFN = new ClassicCounter<>();
        boolean resultsCounted = true;

        int numWords = 0;
        int numDocs = 0;
        for (List<IN> doc : documents) {
            classify(doc, documentDataAndLabels.get(numDocs));
            numWords += doc.size();
            writeAnswers(doc, printWriter, readerAndWriter);
            resultsCounted = resultsCounted && countResults(doc, entityTP, entityFP, entityFN);
            numDocs++;
        }
        long millis = timer.stop();
        double wordspersec = numWords / (((double) millis) / 1000);
        NumberFormat nf = new DecimalFormat("0.00"); // easier way!
        if (!flags.suppressTestDebug)
            log.info(StringUtils.getShortClassName(this) + " tagged " + numWords + " words in " + numDocs
                    + " documents at " + nf.format(wordspersec) + " words per second.");
        if (resultsCounted && !flags.suppressTestDebug) {
            printResults(entityTP, entityFP, entityFN);
        }
    }

    @Override
    public SequenceModel getSequenceModel(List<IN> doc) {
        Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(doc);
        return getSequenceModel(p, doc);
    }

    private SequenceModel getSequenceModel(Triple<int[][][], int[], double[][][]> documentDataAndLabels,
            List<IN> document) {
        return labelDictionary == null ? new TestSequenceModel(getCliqueTree(documentDataAndLabels))
                : new TestSequenceModel(getCliqueTree(documentDataAndLabels), labelDictionary, document);
    }

    protected CliquePotentialFunction getCliquePotentialFunctionForTest() {
        if (cliquePotentialFunction == null) {
            cliquePotentialFunction = new LinearCliquePotentialFunction(weights);
        }
        return cliquePotentialFunction;
    }

    public void updateWeightsForTest(double[] x) {
        cliquePotentialFunction = cliquePotentialFunctionHelper.getCliquePotentialFunction(x);
    }

    /**
     * Do standard sequence inference, using either Viterbi or Beam inference
     * depending on the value of {@code flags.inferenceType}.
     *
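     * <p>
     * A sketch of selecting beam inference when constructing the classifier,
     * assuming the usual convention that property keys mirror the
     * {@code SeqClassifierFlags} field names used below (the beam size of 5 is
     * just an illustrative value):
     * <pre>{@code
     * Properties props = new Properties();
     * props.setProperty("inferenceType", "Beam");
     * props.setProperty("beamSize", "5");
     * CRFClassifier<CoreLabel> crf = new CRFClassifier<>(props);
     * }</pre>
     *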
     * @param document Document to classify. Classification happens in place.
     *          This document is modified.
     * @return The classified document
     */
    public List<IN> classifyMaxEnt(List<IN> document) {
        if (document.isEmpty()) {
            return document;
        }

        SequenceModel model = getSequenceModel(document);
        return classifyMaxEnt(document, model);
    }

    private List<IN> classifyMaxEnt(List<IN> document,
            Triple<int[][][], int[], double[][][]> documentDataAndLabels) {
        if (document.isEmpty()) {
            return document;
        }
        SequenceModel model = getSequenceModel(documentDataAndLabels, document);
        return classifyMaxEnt(document, model);
    }

    private List<IN> classifyMaxEnt(List<IN> document, SequenceModel model) {
        if (document.isEmpty()) {
            return document;
        }

        if (flags.inferenceType == null) {
            flags.inferenceType = "Viterbi";
        }

        BestSequenceFinder tagInference;
        if (flags.inferenceType.equalsIgnoreCase("Viterbi")) {
            tagInference = new ExactBestSequenceFinder();
        } else if (flags.inferenceType.equalsIgnoreCase("Beam")) {
            tagInference = new BeamBestSequenceFinder(flags.beamSize);
        } else {
            throw new RuntimeException(
                    "Unknown inference type: " + flags.inferenceType + ". Your options are Viterbi|Beam.");
        }

        int[] bestSequence = tagInference.bestSequence(model);

        if (flags.useReverse) {
            Collections.reverse(document);
        }
        for (int j = 0, docSize = document.size(); j < docSize; j++) {
            IN wi = document.get(j);
            String guess = classIndex.get(bestSequence[j + windowSize - 1]);
            wi.set(CoreAnnotations.AnswerAnnotation.class, guess);
            int index = classIndex.indexOf(guess);
            double guessProb = ((TestSequenceModel) model).labelProb(j, index);
            wi.set(CoreAnnotations.AnswerProbAnnotation.class, guessProb);
        }
        if (flags.useReverse) {
            Collections.reverse(document);
        }
        return document;
    }

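    /**
     * Classifies a document by Gibbs sampling over the CRF's clique tree combined
     * with a prior model, annealing according to {@code flags.annealingType}
     * ("linear" or "exp"/"exponential"). The prior model class is loaded
     * reflectively from {@code flags.priorModelFactory}.
     */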
    public List<IN> classifyGibbs(List<IN> document)
            throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException,
            InstantiationException, IllegalAccessException, InvocationTargetException {
        Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
        return classifyGibbs(document, p);
    }

    public List<IN> classifyGibbs(List<IN> document, Triple<int[][][], int[], double[][][]> documentDataAndLabels)
            throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException,
            InstantiationException, IllegalAccessException, InvocationTargetException {
        // log.info("Testing using Gibbs sampling.");
        List<IN> newDocument = document; // reversed if necessary
        if (flags.useReverse) {
            Collections.reverse(document);
            newDocument = new ArrayList<>(document);
            Collections.reverse(document);
        }

        CRFCliqueTree<? extends CharSequence> cliqueTree = getCliqueTree(documentDataAndLabels);

        PriorModelFactory<IN> pmf = (PriorModelFactory<IN>) Class.forName(flags.priorModelFactory).newInstance();
        ListeningSequenceModel prior = pmf.getInstance(flags.backgroundSymbol, classIndex, tagIndex, newDocument,
                entityMatrices, flags);

        if (!flags.useUniformPrior) {
            throw new RuntimeException("no prior specified");
        }

        SequenceModel model = new FactoredSequenceModel(cliqueTree, prior);
        SequenceListener listener = new FactoredSequenceListener(cliqueTree, prior);

        SequenceGibbsSampler sampler = new SequenceGibbsSampler(0, 0, listener);
        int[] sequence = new int[cliqueTree.length()];

        if (flags.initViterbi) {
            TestSequenceModel testSequenceModel = new TestSequenceModel(cliqueTree);
            ExactBestSequenceFinder tagInference = new ExactBestSequenceFinder();
            int[] bestSequence = tagInference.bestSequence(testSequenceModel);
            System.arraycopy(bestSequence, windowSize - 1, sequence, 0, sequence.length);
        } else {
            int[] initialSequence = SequenceGibbsSampler.getRandomSequence(model);
            System.arraycopy(initialSequence, 0, sequence, 0, sequence.length);
        }

        sampler.verbose = 0;

        if (flags.annealingType.equalsIgnoreCase("linear")) {
            sequence = sampler.findBestUsingAnnealing(model,
                    CoolingSchedule.getLinearSchedule(1.0, flags.numSamples), sequence);
        } else if (flags.annealingType.equalsIgnoreCase("exp")
                || flags.annealingType.equalsIgnoreCase("exponential")) {
            sequence = sampler.findBestUsingAnnealing(model,
                    CoolingSchedule.getExponentialSchedule(1.0, flags.annealingRate, flags.numSamples), sequence);
        } else {
            throw new RuntimeException("No annealing type specified");
        }

        if (flags.useReverse) {
            Collections.reverse(document);
        }

        for (int j = 0, dsize = newDocument.size(); j < dsize; j++) {
            IN wi = document.get(j);
            if (wi == null)
                throw new RuntimeException("");
            if (classIndex == null)
                throw new RuntimeException("");
            wi.set(CoreAnnotations.AnswerAnnotation.class, classIndex.get(sequence[j]));
        }

        if (flags.useReverse) {
            Collections.reverse(document);
        }

        return document;
    }

    /**
     * Takes a {@link List} of something that extends {@link CoreMap} and prints
     * the likelihood of each possible label at each point.
     *
     * @param document A {@link List} of something that extends CoreMap.
     * @return If verboseMode is set, a Triple of Counters recording classification decisions, else null.
     */
    @Override
    public Triple<Counter<Integer>, Counter<Integer>, TwoDimensionalCounter<Integer, String>> printProbsDocument(
            List<IN> document) {
        // TODO: Probably this would really be better with 11 bins, with edge ones from 0-0.5 and 0.95-1.0, a bit like 11-point ave precision
        final int numBins = 10;
        boolean verbose = flags.verboseMode;

        Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
        CRFCliqueTree<String> cliqueTree = getCliqueTree(p);

        Counter<Integer> calibration = new ClassicCounter<>();
        Counter<Integer> correctByBin = new ClassicCounter<>();
        TwoDimensionalCounter<Integer, String> calibratedTokens = new TwoDimensionalCounter<>();

        // for (int i = 0; i < factorTables.length; i++) {
        for (int i = 0; i < cliqueTree.length(); i++) {
            IN wi = document.get(i);
            String token = wi.get(CoreAnnotations.TextAnnotation.class);
            String goldAnswer = wi.get(CoreAnnotations.GoldAnswerAnnotation.class);
            System.out.print(token);
            System.out.print('\t');
            System.out.print(goldAnswer);
            double maxProb = Double.NEGATIVE_INFINITY;
            String bestClass = "";
            for (String label : classIndex) {
                int index = classIndex.indexOf(label);
                // double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index));
                double prob = cliqueTree.prob(i, index);
                if (prob > maxProb) {
                    maxProb = prob;
                    bestClass = label;
                }
                System.out.print('\t');
                System.out.print(label);
                System.out.print('=');
                System.out.print(prob);
                if (verbose) {
                    int binnedProb = (int) (prob * numBins);
                    if (binnedProb > (numBins - 1)) {
                        binnedProb = numBins - 1;
                    }
                    calibration.incrementCount(binnedProb);
                    if (label.equals(goldAnswer)) {
                        if (bestClass.equals(goldAnswer)) {
                            correctByBin.incrementCount(binnedProb);
                        }
                        if (!label.equals(flags.backgroundSymbol)) {
                            calibratedTokens.incrementCount(binnedProb, token);
                        }
                    }
                }
            }
            System.out.println();
        }
        if (verbose) {
            return new Triple<>(calibration, correctByBin, calibratedTokens);
        } else {
            return null;
        }
    }
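
    // Reading the verbose output of printProbsDocument (explanatory note, not new
    // behavior): the returned Triple holds (1) how many printed label probabilities
    // fell into each of the numBins probability bins, (2) in how many cases the
    // gold label's probability fell into a bin while the argmax prediction was also
    // correct, and (3) which non-background gold tokens landed in each bin. Together
    // these support a simple calibration check of the CRF's marginals.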

    public List<Counter<String>> zeroOrderProbabilities(List<IN> document) {
        List<Counter<String>> ret = new ArrayList<>();
        Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
        CRFCliqueTree<String> cliqueTree = getCliqueTree(p);
        for (int i = 0; i < cliqueTree.length(); i++) {
            Counter<String> ctr = new ClassicCounter<>();
            for (String label : classIndex) {
                int index = classIndex.indexOf(label);
                double prob = cliqueTree.prob(i, index);
                ctr.setCount(label, prob);
            }
            ret.add(ctr);
        }
        return ret;
    }
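
    // Usage sketch (illustrative only, not part of the original source): given a
    // classifier and a tokenized document, zeroOrderProbabilities returns one
    // Counter per token mapping every label to its marginal probability. The
    // variables crf and document are assumed to exist; Counters.argmax is from
    // edu.stanford.nlp.stats.Counters.
    //
    //   List<Counter<String>> marginals = crf.zeroOrderProbabilities(document);
    //   for (int i = 0; i < document.size(); i++) {
    //       String word = document.get(i).get(CoreAnnotations.TextAnnotation.class);
    //       System.out.println(word + "\t" + Counters.argmax(marginals.get(i)));
    //   }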

    /**
     * Takes the file, reads it in, and prints out the likelihood of each possible
     * label at each point. This gives a simple way to examine the probability
     * distributions of the CRF. See {@code getCliqueTrees()} for more.
     *
     * @param filename The path to the specified file
     */
    public void printFirstOrderProbs(String filename, DocumentReaderAndWriter<IN> readerAndWriter) {
        // only for the OCR data does this matter
        // flags.ocrTrain = false;

        ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerAndWriter);
        printFirstOrderProbsDocuments(docs);
    }

    /**
     * Takes an {@link ObjectBank} of documents and prints the likelihood of each
     * possible label at each point.
     *
     * @param documents An {@link ObjectBank} of {@link List} of INs.
     */
    public void printFirstOrderProbsDocuments(ObjectBank<List<IN>> documents) {
        for (List<IN> doc : documents) {
            printFirstOrderProbsDocument(doc);
            System.out.println();
        }
    }

    /**
     * Takes the file, reads it in, and prints out the factor table at each position.
     *
     * @param filename The path to the specified file
     */
    public void printFactorTable(String filename, DocumentReaderAndWriter<IN> readerAndWriter) {
        // only for the OCR data does this matter
        // flags.ocrTrain = false;

        ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerAndWriter);
        printFactorTableDocuments(docs);
    }

    /**
     * Takes an {@link ObjectBank} of documents and prints the factor table
     * at each point.
     *
     * @param documents An {@link ObjectBank} of {@link List} of INs.
     */
    public void printFactorTableDocuments(ObjectBank<List<IN>> documents) {
        for (List<IN> doc : documents) {
            printFactorTableDocument(doc);
            System.out.println();
        }
    }

    /**
     * Want to make arbitrary probability queries? Then this is the method for
     * you. Given the filename, it reads it in and breaks it into documents, and
     * then makes a CRFCliqueTree for each document. You can then ask the clique
     * tree for marginals and conditional probabilities of almost anything you want.
     */
    public List<CRFCliqueTree<String>> getCliqueTrees(String filename,
            DocumentReaderAndWriter<IN> readerAndWriter) {
        // only for the OCR data does this matter
        // flags.ocrTrain = false;

        List<CRFCliqueTree<String>> cts = new ArrayList<>();
        ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerAndWriter);
        for (List<IN> doc : docs) {
            cts.add(getCliqueTree(doc));
        }

        return cts;
    }
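
    // Usage sketch (illustrative, not from the original source): each returned
    // CRFCliqueTree supports the same marginal queries used elsewhere in this
    // class, e.g. prob(position, labelIndex). The variables crf, filename and
    // readerAndWriter are assumed to exist, and "PERSON" is just an example label.
    //
    //   List<CRFCliqueTree<String>> trees = crf.getCliqueTrees(filename, readerAndWriter);
    //   CRFCliqueTree<String> tree = trees.get(0);
    //   int personIdx = crf.classIndex.indexOf("PERSON");       // classIndex as used throughout this class
    //   double pPersonAtToken3 = tree.prob(3, personIdx);       // marginal P(label = PERSON at position 3)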

    public CRFCliqueTree<String> getCliqueTree(Triple<int[][][], int[], double[][][]> p) {
        int[][][] data = p.first();
        double[][][] featureVal = p.third();

        return CRFCliqueTree.getCalibratedCliqueTree(data, labelIndices, classIndex.size(), classIndex,
                flags.backgroundSymbol, getCliquePotentialFunctionForTest(), featureVal);
    }

    // This method should stay public
    @SuppressWarnings("WeakerAccess")
    public CRFCliqueTree<String> getCliqueTree(List<IN> document) {
        Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
        return getCliqueTree(p);
    }

    // This method should stay public
    /**
     * Takes a {@link List} of something that extends {@link CoreMap} and prints
     * the factor table at each point.
     *
     * @param document A {@link List} of something that extends {@link CoreMap}.
     */
    @SuppressWarnings("WeakerAccess")
    public void printFactorTableDocument(List<IN> document) {

        CRFCliqueTree<String> cliqueTree = getCliqueTree(document);
        FactorTable[] factorTables = cliqueTree.getFactorTables();

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < factorTables.length; i++) {
            IN wi = document.get(i);
            sb.append(wi.get(CoreAnnotations.TextAnnotation.class));
            sb.append('\t');
            FactorTable table = factorTables[i];
            for (int j = 0; j < table.size(); j++) {
                int[] arr = table.toArray(j);
                sb.append(classIndex.get(arr[0]));
                sb.append(':');
                sb.append(classIndex.get(arr[1]));
                sb.append(':');
                sb.append(cliqueTree.logProb(i, arr));
                sb.append(' ');
            }
            sb.append('\n');
        }
        System.out.print(sb);
    }

    /**
     * Takes a {@link List} of something that extends {@link CoreMap} and prints
     * the likelihood of each possible label at each point.
     *
     * @param document A {@link List} of something that extends {@link CoreMap}.
     */
    public void printFirstOrderProbsDocument(List<IN> document) {

        CRFCliqueTree<String> cliqueTree = getCliqueTree(document);

        // for (int i = 0; i < factorTables.length; i++) {
        for (int i = 0; i < cliqueTree.length(); i++) {
            IN wi = document.get(i);
            System.out.print(wi.get(CoreAnnotations.TextAnnotation.class) + '\t');
            for (Iterator<String> iter = classIndex.iterator(); iter.hasNext();) {
                String label = iter.next();
                int index = classIndex.indexOf(label);
                if (i == 0) {
                    // double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index));
                    double prob = cliqueTree.prob(i, index);
                    System.out.print(label + '=' + prob);
                    if (iter.hasNext()) {
                        System.out.print("\t");
                    } else {
                        System.out.print("\n");
                    }
                } else {
                    for (Iterator<String> iter1 = classIndex.iterator(); iter1.hasNext();) {
                        String label1 = iter1.next();
                        int index1 = classIndex.indexOf(label1);
                        // double prob = Math.pow(Math.E, factorTables[i].logProbEnd(new
                        // int[]{index1, index}));
                        double prob = cliqueTree.prob(i, new int[] { index1, index });
                        System.out.print(label1 + '_' + label + '=' + prob);
                        if (iter.hasNext() || iter1.hasNext()) {
                            System.out.print("\t");
                        } else {
                            System.out.print("\n");
                        }
                    }
                }
            }
        }
    }

    /**
     * Load auxiliary data to be used in constructing features and labels.
     * Intended to be overridden by subclasses.
     */
    protected Collection<List<IN>> loadAuxiliaryData(Collection<List<IN>> docs,
            DocumentReaderAndWriter<IN> readerAndWriter) {
        return docs;
    }

    /** {@inheritDoc} */
    @Override
    public void train(Collection<List<IN>> objectBankWrapper, DocumentReaderAndWriter<IN> readerAndWriter) {
        Timing timer = new Timing();

        Collection<List<IN>> docs = new ArrayList<>();
        for (List<IN> doc : objectBankWrapper) {
            docs.add(doc);
        }

        if (flags.numOfSlices > 0) {
            log.info("Taking " + flags.numOfSlices + " out of " + flags.totalDataSlice
                    + " slices of data for training");
            List<List<IN>> docsToShuffle = new ArrayList<>();
            for (List<IN> doc : docs) {
                docsToShuffle.add(doc);
            }
            Collections.shuffle(docsToShuffle, random);
            int cutOff = (int) (docsToShuffle.size() / (flags.totalDataSlice + 0.0) * flags.numOfSlices);
            docs = docsToShuffle.subList(0, cutOff);
        }

        Collection<List<IN>> totalDocs = loadAuxiliaryData(docs, readerAndWriter);

        makeAnswerArraysAndTagIndex(totalDocs);

        long elapsedMs = timer.stop();
        log.info("Time to convert docs to feature indices: " + Timing.toSecondsString(elapsedMs) + " seconds");
        log.info("Current memory used: " + MemoryMonitor.getUsedMemoryString());

        if (flags.serializeClassIndexTo != null) {
            timer.start();
            serializeClassIndex(flags.serializeClassIndexTo);
            elapsedMs = timer.stop();
            log.info("Time to export class index : " + Timing.toSecondsString(elapsedMs) + " seconds");
        }

        if (flags.exportFeatures != null) {
            dumpFeatures(docs);
        }

        for (int i = 0; i <= flags.numTimesPruneFeatures; i++) {
            timer.start();
            Triple<int[][][][], int[][], double[][][][]> dataAndLabelsAndFeatureVals = documentsToDataAndLabels(
                    docs);
            elapsedMs = timer.stop();
            log.info("Time to convert docs to data/labels: " + Timing.toSecondsString(elapsedMs) + " seconds");
            log.info("Current memory used: " + MemoryMonitor.getUsedMemoryString());

            Evaluator[] evaluators = null;
            if (flags.evaluateIters > 0 || flags.terminateOnEvalImprovement) {
                List<Evaluator> evaluatorList = new ArrayList<>();
                if (flags.useMemoryEvaluator)
                    evaluatorList.add(new MemoryEvaluator());
                if (flags.evaluateTrain) {
                    CRFClassifierEvaluator<IN> crfEvaluator = new CRFClassifierEvaluator<>("Train set", this);
                    int[][][][] data = dataAndLabelsAndFeatureVals.first();
                    int[][] labels = dataAndLabelsAndFeatureVals.second();
                    double[][][][] featureVal = dataAndLabelsAndFeatureVals.third();
                    List<Triple<int[][][], int[], double[][][]>> trainDataAndLabels = new ArrayList<>(data.length);
                    for (int j = 0; j < data.length; j++) {
                        Triple<int[][][], int[], double[][][]> p = new Triple<>(data[j], labels[j], featureVal[j]);
                        trainDataAndLabels.add(p);
                    }
                    crfEvaluator.setTestData(docs, trainDataAndLabels);
                    if (!flags.evalCmd.isEmpty()) {
                        crfEvaluator.setEvalCmd(flags.evalCmd);
                    }
                    evaluatorList.add(crfEvaluator);
                }
                if (flags.testFile != null) {
                    CRFClassifierEvaluator<IN> crfEvaluator = new CRFClassifierEvaluator<>(
                            "Test set (" + flags.testFile + ")", this);
                    ObjectBank<List<IN>> testObjBank = makeObjectBankFromFile(flags.testFile, readerAndWriter);
                    List<List<IN>> testDocs = new ArrayList<>(testObjBank);
                    List<Triple<int[][][], int[], double[][][]>> testDataAndLabels = documentsToDataAndLabelsList(
                            testDocs);
                    crfEvaluator.setTestData(testDocs, testDataAndLabels);
                    if (!flags.evalCmd.isEmpty()) {
                        crfEvaluator.setEvalCmd(flags.evalCmd);
                    }
                    evaluatorList.add(crfEvaluator);
                }
                if (flags.testFiles != null) {
                    String[] testFiles = flags.testFiles.split(",");
                    for (String testFile : testFiles) {
                        CRFClassifierEvaluator<IN> crfEvaluator = new CRFClassifierEvaluator<>(
                                "Test set (" + testFile + ')', this);
                        ObjectBank<List<IN>> testObjBank = makeObjectBankFromFile(testFile, readerAndWriter);
                        List<Triple<int[][][], int[], double[][][]>> testDataAndLabels = documentsToDataAndLabelsList(
                                testObjBank);
                        crfEvaluator.setTestData(testObjBank, testDataAndLabels);
                        if (!flags.evalCmd.isEmpty()) {
                            crfEvaluator.setEvalCmd(flags.evalCmd);
                        }
                        evaluatorList.add(crfEvaluator);
                    }
                }
                evaluators = new Evaluator[evaluatorList.size()];
                evaluatorList.toArray(evaluators);
            }

            if (flags.numTimesPruneFeatures == i) {
                docs = null; // hopefully saves memory
            }
            // save feature index to disk and read in later
            File featIndexFile = null;

            // CRFLogConditionalObjectiveFunction.featureIndex = featureIndex;
            // int numFeatures = featureIndex.size();
            if (flags.saveFeatureIndexToDisk) {
                try {
                    log.info("Writing feature index to temporary file.");
                    featIndexFile = IOUtils.writeObjectToTempFile(featureIndex, "featIndex" + i + ".tmp");
                    // featureIndex = null;
                } catch (IOException e) {
                    throw new RuntimeException("Could not open temporary feature index file for writing.", e);
                }
            }

            // first index is the number of the document
            // second index is position in the document also the index of the
            // clique/factor table
            // third index is the number of elements in the clique/window these
            // features are for (starting with last element)
            // fourth index is position of the feature in the array that holds them
            // element in data[i][j][k][m] is the index of the mth feature occurring
            // in position k of the jth clique of the ith document
            int[][][][] data = dataAndLabelsAndFeatureVals.first();
            // first index is the number of the document
            // second index is the position in the document
            // element in labels[i][j] is the index of the correct label (if it
            // exists) at position j in document i
            int[][] labels = dataAndLabelsAndFeatureVals.second();
            double[][][][] featureVals = dataAndLabelsAndFeatureVals.third();

            if (flags.loadProcessedData != null) {
                List<List<CRFDatum<Collection<String>, String>>> processedData = loadProcessedData(
                        flags.loadProcessedData);
                if (processedData != null) {
                    // enlarge the data and labels array
                    int[][][][] allData = new int[data.length + processedData.size()][][][];
                    double[][][][] allFeatureVals = new double[featureVals.length + processedData.size()][][][];
                    int[][] allLabels = new int[labels.length + processedData.size()][];
                    System.arraycopy(data, 0, allData, 0, data.length);
                    System.arraycopy(labels, 0, allLabels, 0, labels.length);
                    System.arraycopy(featureVals, 0, allFeatureVals, 0, featureVals.length);
                    // add to the data and labels array
                    addProcessedData(processedData, allData, allLabels, allFeatureVals, data.length);
                    data = allData;
                    labels = allLabels;
                    featureVals = allFeatureVals;
                }
            }

            double[] oneDimWeights = trainWeights(data, labels, evaluators, i, featureVals);
            if (oneDimWeights != null) {
                this.weights = to2D(oneDimWeights, labelIndices, map);
            }

            // if (flags.useFloat) {
            //   oneDimWeights = trainWeightsUsingFloatCRF(data, labels, evaluators, i, featureVals);
            // } else if (flags.numLopExpert > 1) {
            //   oneDimWeights = trainWeightsUsingLopCRF(data, labels, evaluators, i, featureVals);
            // } else {
            //   oneDimWeights = trainWeightsUsingDoubleCRF(data, labels, evaluators, i, featureVals);
            // }

            // save feature index to disk and read in later
            if (flags.saveFeatureIndexToDisk) {
                try {
                    log.info("Reading temporary feature index file.");
                    featureIndex = IOUtils.readObjectFromFile(featIndexFile);
                } catch (Exception e) {
                    throw new RuntimeException("Could not open temporary feature index file for reading.", e);
                }
            }

            if (i != flags.numTimesPruneFeatures) {
                dropFeaturesBelowThreshold(flags.featureDiffThresh);
                log.info("Removing features with weight below " + flags.featureDiffThresh + " and retraining...");
            }
        }
    }

    public static double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
        double[][] newWeights = new double[map.length][];
        int index = 0;
        for (int i = 0; i < map.length; i++) {
            newWeights[i] = new double[labelIndices.get(map[i]).size()];
            System.arraycopy(weights, index, newWeights[i], 0, labelIndices.get(map[i]).size());
            index += labelIndices.get(map[i]).size();
        }
        return newWeights;
    }
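
    // Worked example for to2D (hypothetical sizes, explanatory only): with
    // map = {0, 0, 1}, |labelIndices.get(0)| = 3 (node cliques) and
    // |labelIndices.get(1)| = 9 (edge cliques), a flat weight vector of length
    // 3 + 3 + 9 = 15 is split into
    //   newWeights[0] = weights[0..2]    (first node feature)
    //   newWeights[1] = weights[3..5]    (second node feature)
    //   newWeights[2] = weights[6..14]   (edge feature)
    // i.e. feature i owns a contiguous block whose length is the number of label
    // configurations for its clique type map[i].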

    protected void pruneNodeFeatureIndices(int totalNumOfFeatureSlices, int numOfFeatureSlices) {
        int numOfNodeFeatures = nodeFeatureIndicesMap.size();
        int beginIndex = 0;
        int endIndex = Math.min((int) (numOfNodeFeatures / (totalNumOfFeatureSlices + 0.0) * numOfFeatureSlices),
                numOfNodeFeatures);
        List<Integer> nodeFeatureOriginalIndices = nodeFeatureIndicesMap.objectsList();
        List<Integer> edgeFeatureOriginalIndices = edgeFeatureIndicesMap.objectsList();

        Index<Integer> newNodeFeatureIndex = new HashIndex<>();
        Index<Integer> newEdgeFeatureIndex = new HashIndex<>();
        Index<String> newFeatureIndex = new HashIndex<>();

        for (int i = beginIndex; i < endIndex; i++) {
            int oldIndex = nodeFeatureOriginalIndices.get(i);
            String f = featureIndex.get(oldIndex);
            int index = newFeatureIndex.addToIndex(f);
            newNodeFeatureIndex.add(index);
        }
        for (Integer edgeFIndex : edgeFeatureOriginalIndices) {
            String f = featureIndex.get(edgeFIndex);
            int index = newFeatureIndex.addToIndex(f);
            newEdgeFeatureIndex.add(index);
        }

        nodeFeatureIndicesMap = newNodeFeatureIndex;
        edgeFeatureIndicesMap = newEdgeFeatureIndex;

        int[] newMap = new int[newFeatureIndex.size()];
        for (int i = 0; i < newMap.length; i++) {
            int index = featureIndex.indexOf(newFeatureIndex.get(i));
            newMap[i] = map[index];
        }
        map = newMap;

        featureIndex = newFeatureIndex;
    }

    protected CRFLogConditionalObjectiveFunction getObjectiveFunction(int[][][][] data, int[][] labels) {
        return new CRFLogConditionalObjectiveFunction(data, labels, windowSize, classIndex, labelIndices, map,
                flags.priorType, flags.backgroundSymbol, flags.sigma, null, flags.multiThreadGrad);
    }

    protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr,
            double[][][][] featureVals) {

        CRFLogConditionalObjectiveFunction func = getObjectiveFunction(data, labels);
        cliquePotentialFunctionHelper = func;

        // create feature grouping
        // todo [cdm 2016]: Use a CollectionValuedMap
        Map<String, Set<Integer>> featureSets = null;
        if (flags.groupByOutputClass) {
            featureSets = new HashMap<>();
            if (flags.groupByFeatureTemplate) {
                int pIndex = 0;
                for (int fIndex = 0; fIndex < map.length; fIndex++) {
                    int cliqueType = map[fIndex];
                    int numCliqueTypeOutputClass = labelIndices.get(map[fIndex]).size();
                    for (int cliqueOutClass = 0; cliqueOutClass < numCliqueTypeOutputClass; cliqueOutClass++) {
                        String name = "c:" + cliqueType + "-o:" + cliqueOutClass + "-g:"
                                + featureIndexToTemplateIndex.get(fIndex);
                        if (featureSets.containsKey(name)) {
                            featureSets.get(name).add(pIndex);
                        } else {
                            Set<Integer> newSet = new HashSet<>();
                            newSet.add(pIndex);
                            featureSets.put(name, newSet);
                        }
                        pIndex++;
                    }
                }
            } else {
                int pIndex = 0;
                for (int cliqueType : map) {
                    int numCliqueTypeOutputClass = labelIndices.get(cliqueType).size();
                    for (int cliqueOutClass = 0; cliqueOutClass < numCliqueTypeOutputClass; cliqueOutClass++) {
                        String name = "c:" + cliqueType + "-o:" + cliqueOutClass;
                        if (featureSets.containsKey(name)) {
                            featureSets.get(name).add(pIndex);
                        } else {
                            Set<Integer> newSet = new HashSet<>();
                            newSet.add(pIndex);
                            featureSets.put(name, newSet);
                        }
                        pIndex++;
                    }
                }
            }
        } else if (flags.groupByFeatureTemplate) {
            featureSets = new HashMap<>();
            int pIndex = 0;
            for (int fIndex = 0; fIndex < map.length; fIndex++) {
                int cliqueType = map[fIndex];
                int numCliqueTypeOutputClass = labelIndices.get(map[fIndex]).size();
                for (int cliqueOutClass = 0; cliqueOutClass < numCliqueTypeOutputClass; cliqueOutClass++) {
                    String name = "c:" + cliqueType + "-g:" + featureIndexToTemplateIndex.get(fIndex);
                    if (featureSets.containsKey(name)) {
                        featureSets.get(name).add(pIndex);
                    } else {
                        Set<Integer> newSet = new HashSet<>();
                        newSet.add(pIndex);
                        featureSets.put(name, newSet);
                    }
                    pIndex++;
                }
            }
        }
        if (featureSets != null) {
            int[][] fg = new int[featureSets.size()][];
            log.info("After feature grouping, total of " + fg.length + " groups");
            int count = 0;
            for (Set<Integer> aSet : featureSets.values()) {
                fg[count] = new int[aSet.size()];
                int i = 0;
                for (Integer val : aSet)
                    fg[count][i++] = val;
                count++;
            }
            func.setFeatureGrouping(fg);
        }

        Minimizer<DiffFunction> minimizer = getMinimizer(pruneFeatureItr, evaluators);

        double[] initialWeights;
        if (flags.initialWeights == null) {
            initialWeights = func.initial();
        } else {
            try {
                log.info("Reading initial weights from file " + flags.initialWeights);
                DataInputStream dis = IOUtils.getDataInputStream(flags.initialWeights);
                initialWeights = ConvertByteArray.readDoubleArr(dis);
            } catch (IOException e) {
                throw new RuntimeException(
                        "Could not read from double initial weight file " + flags.initialWeights);
            }
        }
        log.info("numWeights: " + initialWeights.length);

        if (flags.testObjFunction) {
            StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(func);
            if (tester.testSumOfBatches(initialWeights, 1e-4)) {
                log.info("Successfully tested stochastic objective function.");
            } else {
                throw new IllegalStateException("Testing of stochastic objective function failed.");
            }

        }
        //check gradient
        if (flags.checkGradient) {
            if (func.gradientCheck()) {
                log.info("gradient check passed");
            } else {
                throw new RuntimeException("gradient check failed");
            }
        }
        return minimizer.minimize(func, flags.tolerance, initialWeights);
    }

    public Minimizer<DiffFunction> getMinimizer() {
        return getMinimizer(0, null);
    }

    public Minimizer<DiffFunction> getMinimizer(int featurePruneIteration, Evaluator[] evaluators) {
        Minimizer<DiffFunction> minimizer = null;
        QNMinimizer qnMinimizer = null;

        if (flags.useQN || flags.useSGDtoQN) {
            // share code for creation of QNMinimizer
            int qnMem;
            if (featurePruneIteration == 0) {
                qnMem = flags.QNsize;
            } else {
                qnMem = flags.QNsize2;
            }

            if (flags.interimOutputFreq != 0) {
                Function monitor = new ResultStoringMonitor(flags.interimOutputFreq, flags.serializeTo);
                qnMinimizer = new QNMinimizer(monitor, qnMem, flags.useRobustQN);
            } else {
                qnMinimizer = new QNMinimizer(qnMem, flags.useRobustQN);
            }

            qnMinimizer.terminateOnMaxItr(flags.maxQNItr);
            qnMinimizer.terminateOnEvalImprovement(flags.terminateOnEvalImprovement);
            qnMinimizer.setTerminateOnEvalImprovementNumOfEpoch(flags.terminateOnEvalImprovementNumOfEpoch);
            qnMinimizer.suppressTestPrompt(flags.suppressTestDebug);
            if (flags.useOWLQN) {
                qnMinimizer.useOWLQN(flags.useOWLQN, flags.priorLambda);
            }
        }

        if (flags.useQN) {
            minimizer = qnMinimizer;
        } else if (flags.useInPlaceSGD) {
            SGDMinimizer<DiffFunction> sgdMinimizer = new SGDMinimizer<>(flags.sigma, flags.SGDPasses,
                    flags.tuneSampleSize, flags.stochasticBatchSize);
            if (flags.useSGDtoQN) {
                minimizer = new HybridMinimizer(sgdMinimizer, qnMinimizer, flags.SGDPasses);
            } else {
                minimizer = sgdMinimizer;
            }
        } else if (flags.useAdaGradFOBOS) {
            double lambda = 0.5 / (flags.sigma * flags.sigma);
            minimizer = new SGDWithAdaGradAndFOBOS<>(flags.initRate, lambda, flags.SGDPasses,
                    flags.stochasticBatchSize, flags.priorType, flags.priorAlpha, flags.useAdaDelta,
                    flags.useAdaDiff, flags.adaGradEps, flags.adaDeltaRho);
            ((SGDWithAdaGradAndFOBOS<?>) minimizer).terminateOnEvalImprovement(flags.terminateOnEvalImprovement);
            ((SGDWithAdaGradAndFOBOS<?>) minimizer).terminateOnAvgImprovement(flags.terminateOnAvgImprovement,
                    flags.tolerance);
            ((SGDWithAdaGradAndFOBOS<?>) minimizer)
                    .setTerminateOnEvalImprovementNumOfEpoch(flags.terminateOnEvalImprovementNumOfEpoch);
            ((SGDWithAdaGradAndFOBOS<?>) minimizer).suppressTestPrompt(flags.suppressTestDebug);
        } else if (flags.useSGDtoQN) {
            minimizer = new SGDToQNMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.SGDPasses,
                    flags.QNPasses, flags.SGD2QNhessSamples, flags.QNsize, flags.outputIterationsToFile);
        } else if (flags.useSMD) {
            minimizer = new SMDMinimizer<>(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod,
                    flags.SGDPasses);
        } else if (flags.useSGD) {
            minimizer = new InefficientSGDMinimizer<>(flags.initialGain, flags.stochasticBatchSize);
        } else if (flags.useScaledSGD) {
            minimizer = new ScaledSGDMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.SGDPasses,
                    flags.scaledSGDMethod);
        } else if (flags.l1reg > 0.0) {
            minimizer = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer",
                    flags.l1reg);
        } else {
            throw new RuntimeException("No minimizer assigned!");
        }

        if (minimizer instanceof HasEvaluators) {
            if (minimizer instanceof QNMinimizer) {
                ((QNMinimizer) minimizer).setEvaluators(flags.evaluateIters, flags.startEvaluateIters, evaluators);
            } else {
                ((HasEvaluators) minimizer).setEvaluators(flags.evaluateIters, evaluators);
            }
        }

        return minimizer;
    }
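
    // Flag-to-minimizer summary (descriptive of the branches above, not new
    // behavior):
    //   useQN                          -> QNMinimizer (OWL-QN variant if useOWLQN)
    //   useInPlaceSGD                  -> SGDMinimizer, wrapped in a HybridMinimizer when useSGDtoQN
    //   useAdaGradFOBOS                -> SGDWithAdaGradAndFOBOS
    //   useSGDtoQN (without in-place)  -> SGDToQNMinimizer
    //   useSMD / useSGD / useScaledSGD -> the corresponding stochastic minimizers
    //   l1reg > 0                      -> OWLQNMinimizer loaded by reflection
    //   otherwise                      -> RuntimeException("No minimizer assigned!")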

    /**
     * Creates a list of CRFDatum objects from the preprocessed allData format for
     * positions beginPosition through endPosition of one document, padding the
     * initial context with dummy datums that have no features.
     *
     * @return A new List of CRFDatum objects
     */
    protected List<CRFDatum<? extends Collection<String>, ? extends CharSequence>> extractDatumSequence(
            int[][][] allData, int beginPosition, int endPosition, List<IN> labeledWordInfos) {
        List<CRFDatum<? extends Collection<String>, ? extends CharSequence>> result = new ArrayList<>();
        int beginContext = beginPosition - windowSize + 1;
        if (beginContext < 0) {
            beginContext = 0;
        }
        // for the beginning context, add some dummy datums with no features!
        // TODO: is there any better way to do this?
        for (int position = beginContext; position < beginPosition; position++) {
            List<Collection<String>> cliqueFeatures = new ArrayList<>();
            List<double[]> featureVals = new ArrayList<>();
            for (int i = 0; i < windowSize; i++) {
                // create a feature list
                cliqueFeatures.add(Collections.emptyList());
                featureVals.add(null);
            }
            CRFDatum<Collection<String>, String> datum = new CRFDatum<>(cliqueFeatures,
                    labeledWordInfos.get(position).get(CoreAnnotations.AnswerAnnotation.class), featureVals);
            result.add(datum);
        }
        // now add the real datums
        for (int position = beginPosition; position <= endPosition; position++) {
            List<Collection<String>> cliqueFeatures = new ArrayList<>();
            List<double[]> featureVals = new ArrayList<>();
            for (int i = 0; i < windowSize; i++) {
                // create a feature list
                Collection<String> features = new ArrayList<>();
                for (int j = 0; j < allData[position][i].length; j++) {
                    features.add(featureIndex.get(allData[position][i][j]));
                }
                cliqueFeatures.add(features);
                featureVals.add(null);
            }
            CRFDatum<Collection<String>, String> datum = new CRFDatum<>(cliqueFeatures,
                    labeledWordInfos.get(position).get(CoreAnnotations.AnswerAnnotation.class), featureVals);
            result.add(datum);
        }
        return result;
    }

    /**
     * Adds the List of Lists of CRFDatums to the data and labels arrays, treating
     * each datum as if it were its own document. Adds context labels in addition
     * to the target label for each datum, meaning that for a particular document,
     * the number of labels will be windowSize-1 greater than the number of
     * datums.
     *
     * @param processedData A List of Lists of CRFDatums
     */
    protected void addProcessedData(List<List<CRFDatum<Collection<String>, String>>> processedData,
            int[][][][] data, int[][] labels, double[][][][] featureVals, int offset) {
        for (int i = 0, pdSize = processedData.size(); i < pdSize; i++) {
            int dataIndex = i + offset;
            List<CRFDatum<Collection<String>, String>> document = processedData.get(i);
            int dsize = document.size();
            labels[dataIndex] = new int[dsize];
            data[dataIndex] = new int[dsize][][];
            if (featureVals != null)
                featureVals[dataIndex] = new double[dsize][][];
            for (int j = 0; j < dsize; j++) {
                CRFDatum<Collection<String>, String> crfDatum = document.get(j);
                // add label, they are offset by extra context
                labels[dataIndex][j] = classIndex.indexOf(crfDatum.label());
                // add featureVals
                List<double[]> featureValList = featureVals != null ? crfDatum.asFeatureVals() : null;
                // add features
                List<Collection<String>> cliques = crfDatum.asFeatures();
                int csize = cliques.size();
                data[dataIndex][j] = new int[csize][];
                if (featureVals != null)
                    featureVals[dataIndex][j] = new double[csize][];
                for (int k = 0; k < csize; k++) {
                    Collection<String> features = cliques.get(k);

                    data[dataIndex][j][k] = new int[features.size()];
                    if (featureVals != null && k < featureValList.size())
                        featureVals[dataIndex][j][k] = featureValList.get(k);

                    int m = 0;
                    try {
                        for (String feature : features) {
                            // log.info("feature " + feature);
                            // if (featureIndex.indexOf(feature)) ;
                            if (featureIndex == null) {
                                log.warn("featureIndex is null!");
                            }
                            data[dataIndex][j][k][m] = featureIndex.indexOf(feature);
                            m++;
                        }
                    } catch (Exception e) {
                        log.error("Add processed data failed.", e);
                        log.info(String.format("[index=%d, j=%d, k=%d, m=%d]%n", dataIndex, j, k, m));
                        log.info("data.length                    " + data.length);
                        log.info("data[dataIndex].length         " + data[dataIndex].length);
                        log.info("data[dataIndex][j].length      " + data[dataIndex][j].length);
                        log.info("data[dataIndex][j][k].length   " + data[dataIndex][j].length);
                        log.info("data[dataIndex][j][k][m]       " + data[dataIndex][j][k][m]);
                        return;
                    }
                }
            }
        }
    }

    protected static void saveProcessedData(List<?> datums, String filename) {
        log.info("Saving processed data of size " + datums.size() + " to serialized file...");
        ObjectOutputStream oos = null;
        try {
            oos = new ObjectOutputStream(new FileOutputStream(filename));
            oos.writeObject(datums);
        } catch (IOException e) {
            log.warn("Failed to save processed data to " + filename, e);
        } finally {
            IOUtils.closeIgnoringExceptions(oos);
        }
        log.info("done.");
    }

    protected static List<List<CRFDatum<Collection<String>, String>>> loadProcessedData(String filename) {
        List<List<CRFDatum<Collection<String>, String>>> result;
        try {
            result = IOUtils.readObjectFromURLOrClasspathOrFileSystem(filename);
        } catch (Exception e) {
            log.warn(e);
            result = Collections.emptyList();
        }
        log.info("Loading processed data from serialized file ... done. Got " + result.size() + " datums.");
        return result;
    }
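
    // Round-trip sketch (illustrative only; the filename is a placeholder):
    // saveProcessedData and loadProcessedData are symmetric.
    //
    //   List<List<CRFDatum<Collection<String>, String>>> processed = ...; // built elsewhere
    //   saveProcessedData(processed, "processed.ser");
    //   List<List<CRFDatum<Collection<String>, String>>> reloaded = loadProcessedData("processed.ser");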

    protected void loadTextClassifier(BufferedReader br) throws Exception {
        String line = br.readLine();
        // first line should be this format:
        // labelIndices.length=\t%d
        String[] toks = line.split("\\t");
        if (!toks[0].equals("labelIndices.length=")) {
            throw new RuntimeException("format error");
        }
        int size = Integer.parseInt(toks[1]);
        labelIndices = new ArrayList<>(size);
        for (int labelIndicesIdx = 0; labelIndicesIdx < size; labelIndicesIdx++) {
            line = br.readLine();
            // each of these lines should be of the format:
            // labelIndices[i].size()=\t%d
            toks = line.split("\\t");
            if (!(toks[0].startsWith("labelIndices[") && toks[0].endsWith("].size()="))) {
                throw new RuntimeException("format error");
            }
            int labelIndexSize = Integer.parseInt(toks[1]);
            labelIndices.add(new HashIndex<>());
            int count = 0;
            while (count < labelIndexSize) {
                line = br.readLine();
                toks = line.split("\\t");
                int idx = Integer.parseInt(toks[0]);
                if (count != idx) {
                    throw new RuntimeException("format error");
                }

                String[] crflabelstr = toks[1].split(" ");
                int[] crflabel = new int[crflabelstr.length];
                for (int i = 0; i < crflabelstr.length; i++) {
                    crflabel[i] = Integer.parseInt(crflabelstr[i]);
                }
                CRFLabel crfL = new CRFLabel(crflabel);

                labelIndices.get(labelIndicesIdx).add(crfL);
                count++;
            }
        }

        // Note: the loop below rebuilds each CRFLabel as a List<Integer> and then
        // discards it; it appears to be leftover debugging code with no effect.
        for (Index<CRFLabel> index : labelIndices) {
            for (int j = 0; j < index.size(); j++) {
                int[] label = index.get(j).getLabel();
                List<Integer> list = new ArrayList<>();
                for (int l : label) {
                    list.add(l);
                }
            }
        }

        line = br.readLine();
        toks = line.split("\\t");
        if (!toks[0].equals("classIndex.size()=")) {
            throw new RuntimeException("format error");
        }
        int classIndexSize = Integer.parseInt(toks[1]);
        classIndex = new HashIndex<>();
        int count = 0;
        while (count < classIndexSize) {
            line = br.readLine();
            toks = line.split("\\t");
            int idx = Integer.parseInt(toks[0]);
            if (count != idx) {
                throw new RuntimeException("format error");
            }
            classIndex.add(toks[1]);
            count++;
        }

        line = br.readLine();
        toks = line.split("\\t");
        if (!toks[0].equals("featureIndex.size()=")) {
            throw new RuntimeException("format error");
        }
        int featureIndexSize = Integer.parseInt(toks[1]);
        featureIndex = new HashIndex<>();
        count = 0;
        while (count < featureIndexSize) {
            line = br.readLine();
            toks = line.split("\\t");
            int idx = Integer.parseInt(toks[0]);
            if (count != idx) {
                throw new RuntimeException("format error");
            }
            featureIndex.add(toks[1]);
            count++;
        }

        line = br.readLine();
        if (!line.equals("<flags>")) {
            throw new RuntimeException("format error");
        }
        Properties p = new Properties();
        line = br.readLine();

        while (!line.equals("</flags>")) {
            // log.info("DEBUG: flags line: "+line);
            String[] keyValue = line.split("=");
            // System.err.printf("DEBUG: p.setProperty(%s,%s)%n", keyValue[0],
            // keyValue[1]);
            p.setProperty(keyValue[0], keyValue[1]);
            line = br.readLine();
        }

        // log.info("DEBUG: out from flags");
        flags = new SeqClassifierFlags(p);

        if (flags.useEmbedding) {
            line = br.readLine();
            toks = line.split("\\t");
            if (!toks[0].equals("embeddings.size()=")) {
                throw new RuntimeException("format error in embeddings");
            }
            int embeddingSize = Integer.parseInt(toks[1]);
            embeddings = Generics.newHashMap(embeddingSize);
            count = 0;
            while (count < embeddingSize) {
                line = br.readLine().trim();
                toks = line.split("\\t");
                String word = toks[0];
                double[] arr = ArrayUtils.toDoubleArray(toks[1].split(" "));
                embeddings.put(word, arr);
                count++;
            }
        }

        // <featureFactory>
        // edu.stanford.nlp.wordseg.Gale2007ChineseSegmenterFeatureFactory
        // </featureFactory>
        line = br.readLine();

        String[] featureFactoryName = line.split(" ");
        if (featureFactoryName.length < 2 || !featureFactoryName[0].equals("<featureFactory>")
                || !featureFactoryName[featureFactoryName.length - 1].equals("</featureFactory>")) {
            throw new RuntimeException("format error unexpected featureFactory line: " + line);
        }
        featureFactories = Generics.newArrayList();
        for (int ff = 1; ff < featureFactoryName.length - 1; ++ff) {
            FeatureFactory<IN> featureFactory = (FeatureFactory<IN>) Class.forName(featureFactoryName[ff])
                    .newInstance();
            featureFactory.init(flags);
            featureFactories.add(featureFactory);
        }

        reinit();

        // <windowSize> 2 </windowSize>
        line = br.readLine();

        String[] windowSizeName = line.split(" ");
        if (!windowSizeName[0].equals("<windowSize>") || !windowSizeName[2].equals("</windowSize>")) {
            throw new RuntimeException("format error");
        }
        windowSize = Integer.parseInt(windowSizeName[1]);

        // weights.length= 2655170
        line = br.readLine();

        toks = line.split("\\t");
        if (!toks[0].equals("weights.length=")) {
            throw new RuntimeException("format error");
        }
        int weightsLength = Integer.parseInt(toks[1]);
        weights = new double[weightsLength][];
        count = 0;
        while (count < weightsLength) {
            line = br.readLine();

            toks = line.split("\\t");
            int weights2Length = Integer.parseInt(toks[0]);
            weights[count] = new double[weights2Length];
            String[] weightsValue = toks[1].split(" ");
            if (weights2Length != weightsValue.length) {
                throw new RuntimeException("weights format error");
            }

            for (int i2 = 0; i2 < weights2Length; i2++) {
                weights[count][i2] = Double.parseDouble(weightsValue[i2]);
            }
            count++;
        }
        System.err.printf("DEBUG: double[%d][] weights loaded%n", weightsLength);
        line = br.readLine();

        if (line != null) {
            throw new RuntimeException("weights format error");
        }
    }

    public void loadTextClassifier(String text, Properties props) throws ClassCastException, IOException,
            ClassNotFoundException, InstantiationException, IllegalAccessException {
        // log.info("DEBUG: in loadTextClassifier");
        log.info("Loading Text Classifier from " + text);
        try (BufferedReader br = IOUtils.readerFromString(text)) {
            loadTextClassifier(br);
        } catch (Exception ex) {
            log.info("Exception in loading text classifier from " + text, ex);
        }
    }

    protected void serializeTextClassifier(PrintWriter pw) throws Exception {
        pw.printf("labelIndices.length=\t%d%n", labelIndices.size());
        for (int i = 0; i < labelIndices.size(); i++) {
            pw.printf("labelIndices[%d].size()=\t%d%n", i, labelIndices.get(i).size());
            for (int j = 0; j < labelIndices.get(i).size(); j++) {
                int[] label = labelIndices.get(i).get(j).getLabel();
                List<Integer> list = new ArrayList<>();
                for (int l : label) {
                    list.add(l);
                }
                pw.printf("%d\t%s%n", j, StringUtils.join(list, " "));
            }
        }

        pw.printf("classIndex.size()=\t%d%n", classIndex.size());
        for (int i = 0; i < classIndex.size(); i++) {
            pw.printf("%d\t%s%n", i, classIndex.get(i));
        }
        // pw.printf("</classIndex>%n");

        pw.printf("featureIndex.size()=\t%d%n", featureIndex.size());
        for (int i = 0; i < featureIndex.size(); i++) {
            pw.printf("%d\t%s%n", i, featureIndex.get(i));
        }
        // pw.printf("</featureIndex>%n");

        pw.println("<flags>");
        pw.print(flags);
        pw.println("</flags>");

        if (flags.useEmbedding) {
            pw.printf("embeddings.size()=\t%d%n", embeddings.size());
            for (String word : embeddings.keySet()) {
                double[] arr = embeddings.get(word);
                Double[] arrUnboxed = new Double[arr.length];
                for (int i = 0; i < arr.length; i++)
                    arrUnboxed[i] = arr[i];
                pw.printf("%s\t%s%n", word, StringUtils.join(arrUnboxed, " "));
            }
        }

        pw.printf("<featureFactory>");
        for (FeatureFactory<IN> featureFactory : featureFactories) {
            pw.printf(" %s ", featureFactory.getClass().getName());
        }
        pw.printf("</featureFactory>%n");

        pw.printf("<windowSize> %d </windowSize>%n", windowSize);

        pw.printf("weights.length=\t%d%n", weights.length);
        for (double[] ws : weights) {
            ArrayList<Double> list = new ArrayList<>();
            for (double w : ws) {
                list.add(w);
            }
            pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
        }
    }

    /**
     * Serialize the model to a human-readable format. It's not yet complete, but it
     * should now work for the Chinese segmenter. TODO: check things in
     * serializeClassifier and add other necessary serialization back.
     *
     * @param serializePath File to write text format of classifier to.
     */
    public void serializeTextClassifier(String serializePath) {
        try (PrintWriter pw = new PrintWriter(new GZIPOutputStream(new FileOutputStream(serializePath)))) {
            serializeTextClassifier(pw);
            log.info("Serializing Text classifier to " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Serializing Text classifier to " + serializePath + "... FAILED.", e);
        }
    }
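
    // Round-trip sketch (illustrative only; path and generic type are
    // placeholders): the text format written here is the same format parsed by
    // loadTextClassifier above.
    //
    //   crf.serializeTextClassifier("ner-model.txt.gz");
    //   CRFClassifier<CoreLabel> restored = new CRFClassifier<>(new Properties());
    //   restored.loadTextClassifier("ner-model.txt.gz", new Properties());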

    public void serializeClassIndex(String serializePath) {

        ObjectOutputStream oos = null;
        try {
            oos = IOUtils.writeStreamFromString(serializePath);
            oos.writeObject(classIndex);
            log.info("Serializing class index to " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Serializing class index to " + serializePath + "... FAILED.", e);
        } finally {
            IOUtils.closeIgnoringExceptions(oos);
        }
    }

    public static Index<String> loadClassIndexFromFile(String serializePath) {
        ObjectInputStream ois = null;
        Index<String> c = null;
        try {
            ois = IOUtils.readStreamFromString(serializePath);
            c = (Index<String>) ois.readObject();
            log.info("Reading class index from " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Reading class index from " + serializePath + "... FAILED.", e);
        } finally {
            IOUtils.closeIgnoringExceptions(ois);
        }

        return c;
    }

    public void serializeWeights(String serializePath) {
        ObjectOutputStream oos = null;
        try {
            oos = IOUtils.writeStreamFromString(serializePath);
            oos.writeObject(weights);
            log.info("Serializing weights to " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Serializing weights to " + serializePath + "... FAILED.", e);
        } finally {
            IOUtils.closeIgnoringExceptions(oos);
        }
    }

    public static double[][] loadWeightsFromFile(String serializePath) {

        ObjectInputStream ois = null;
        double[][] w = null;
        try {
            ois = IOUtils.readStreamFromString(serializePath);
            w = (double[][]) ois.readObject();
            log.info("Reading weights from " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Reading weights from " + serializePath + "... FAILED.", e);
        } finally {
            IOUtils.closeIgnoringExceptions(ois);
        }

        return w;
    }
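
    // Sketch (illustrative only; the path is a placeholder): serializeWeights and
    // loadWeightsFromFile form a pair, as do serializeFeatureIndex /
    // loadFeatureIndexFromFile and serializeClassIndex / loadClassIndexFromFile.
    //
    //   crf.serializeWeights("weights.ser.gz");
    //   double[][] w = CRFClassifier.loadWeightsFromFile("weights.ser.gz");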

    public void serializeFeatureIndex(String serializePath) {
        ObjectOutputStream oos = null;
        try {
            oos = IOUtils.writeStreamFromString(serializePath);
            oos.writeObject(featureIndex);
            log.info("Serializing FeatureIndex to " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Failed");
            log.info("Serializing FeatureIndex to " + serializePath + "... FAILED.", e);
        } finally {
            IOUtils.closeIgnoringExceptions(oos);
        }
    }

    public static Index<String> loadFeatureIndexFromFile(String serializePath) {
        ObjectInputStream ois = null;
        Index<String> f = null;
        try {
            ois = IOUtils.readStreamFromString(serializePath);
            f = (Index<String>) ois.readObject();
            log.info("Reading FeatureIndex from " + serializePath + "... done.");
        } catch (Exception e) {
            log.info("Reading FeatureIndex from " + serializePath + "... FAILED.", e);
        } finally {
            IOUtils.closeIgnoringExceptions(ois);
        }

        return f;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void serializeClassifier(String serializePath) {
        ObjectOutputStream oos = null;
        try {
            oos = IOUtils.writeStreamFromString(serializePath);
            serializeClassifier(oos);
            log.info("Serializing classifier to " + serializePath + "... done.");

        } catch (Exception e) {
            throw new RuntimeIOException("Serializing classifier to " + serializePath + "... FAILED", e);
        } finally {
            IOUtils.closeIgnoringExceptions(oos);
        }
    }
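
    // Usage sketch (illustrative only; path and generic type are placeholders):
    // a model serialized here can be read back with loadClassifierNoExceptions,
    // which is inherited from AbstractSequenceClassifier (see its use in
    // loadDefaultClassifier below).
    //
    //   crf.serializeClassifier("my-ner-model.ser.gz");
    //   CRFClassifier<CoreLabel> loaded = new CRFClassifier<>(new Properties());
    //   loaded.loadClassifierNoExceptions("my-ner-model.ser.gz");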

    /**
     * Serialize the classifier to the given ObjectOutputStream.
     * <br>
     * (Since the classifier is a processor, we don't want to serialize the
     * whole classifier but just the data that represents a classifier model.)
     */
    @Override
    public void serializeClassifier(ObjectOutputStream oos) {
        try {
            oos.writeObject(labelIndices);
            oos.writeObject(classIndex);
            oos.writeObject(featureIndex);
            oos.writeObject(flags);
            if (flags.useEmbedding) {
                oos.writeObject(embeddings);
            }
            // For some reason, writing out the array of FeatureFactory
            // objects doesn't seem to work.  The resulting classifier
            // doesn't have the lexicon (distsim object) correctly saved.  So now custom write the list
            oos.writeObject(featureFactories.size());
            for (FeatureFactory<IN> ff : featureFactories) {
                oos.writeObject(ff);
            }
            oos.writeInt(windowSize);
            oos.writeObject(weights);
            // oos.writeObject(WordShapeClassifier.getKnownLowerCaseWords());

            oos.writeObject(knownLCWords);
            if (labelDictionary != null) {
                oos.writeObject(labelDictionary);
            }
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    }

    /**
     * Loads a classifier from the specified InputStream. This version works
     * quietly (unless VERBOSE is true). If props is non-null then any properties
     * it specifies override those in the serialized file. However, only some
     * properties are sensible to change (you shouldn't change how features are
     * defined).
     * <p>
     * <i>Note:</i> This method does not close the ObjectInputStream. (But earlier
     * versions of the code used to, so beware....)
     */
    @Override
    @SuppressWarnings({ "unchecked" })
    // can't have right types in deserialization
    public void loadClassifier(ObjectInputStream ois, Properties props)
            throws ClassCastException, IOException, ClassNotFoundException {
        Object o = ois.readObject();
        // TODO: when we next break serialization, get rid of this fork and only read the List<Index> (i.e., keep first case)
        if (o instanceof List) {
            labelIndices = (List<Index<CRFLabel>>) o;
        } else {
            Index<CRFLabel>[] indexArray = (Index<CRFLabel>[]) o;
            labelIndices = new ArrayList<>(indexArray.length);
            Collections.addAll(labelIndices, indexArray);
        }
        classIndex = (Index<String>) ois.readObject();
        featureIndex = (Index<String>) ois.readObject();
        flags = (SeqClassifierFlags) ois.readObject();
        if (flags.useEmbedding) {
            embeddings = (Map<String, double[]>) ois.readObject();
        }
        Object featureFactory = ois.readObject();
        if (featureFactory instanceof List) {
            featureFactories = ErasureUtils.uncheckedCast(featureFactory);
            //      int i = 0;
            //      for (FeatureFactory ff : featureFactories) { // XXXX
            //        System.err.println("List FF #" + i + ": " + ((NERFeatureFactory) ff).describeDistsimLexicon()); // XXXX
            //        i++;
            //      }
        } else if (featureFactory instanceof FeatureFactory) {
            featureFactories = Generics.newArrayList();
            featureFactories.add((FeatureFactory<IN>) featureFactory);
            //      System.err.println(((NERFeatureFactory) featureFactory).describeDistsimLexicon()); // XXXX
        } else if (featureFactory instanceof Integer) {
            // this is the current format (2014) since writing list didn't work (see note in serializeClassifier).
            int size = (Integer) featureFactory;
            featureFactories = Generics.newArrayList(size);
            for (int i = 0; i < size; ++i) {
                featureFactory = ois.readObject();
                if (!(featureFactory instanceof FeatureFactory)) {
                    throw new RuntimeIOException("Should have FeatureFactory but got " + featureFactory.getClass());
                }
                //        System.err.println("FF #" + i + ": " + ((NERFeatureFactory) featureFactory).describeDistsimLexicon()); // XXXX
                featureFactories.add((FeatureFactory<IN>) featureFactory);
            }
        }

        // log.info("properties passed into CRF's loadClassifier are:" + props);
        if (props != null) {
            flags.setProperties(props, false);
        }

        windowSize = ois.readInt();
        weights = (double[][]) ois.readObject();

        // WordShapeClassifier.setKnownLowerCaseWords((Set) ois.readObject());
        Set<String> lcWords = (Set<String>) ois.readObject();
        if (lcWords instanceof MaxSizeConcurrentHashSet) {
            knownLCWords = (MaxSizeConcurrentHashSet<String>) lcWords;
        } else {
            knownLCWords = new MaxSizeConcurrentHashSet<>(lcWords);
        }

        reinit();

        if (flags.labelDictionaryCutoff > 0) {
            labelDictionary = (LabelDictionary) ois.readObject();
        }

        if (VERBOSE) {
            log.info("windowSize=" + windowSize);
            log.info("flags=\n" + flags);
        }
    }
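
    // A loading sketch for the method above (the path and property name are
    // illustrative only; per the javadoc, the stream is not closed here, so the
    // caller manages it):
    //
    //   Properties props = new Properties();
    //   props.setProperty("someOverridableProperty", "value"); // hypothetical property name
    //   try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(
    //           new GZIPInputStream(new FileInputStream("my-ner-model.ser.gz"))))) {
    //       crf.loadClassifier(ois, props);
    //   }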

    /**
     * This is used to load the default supplied classifier stored within the jar
     * file. THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
     * WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
     */
    public void loadDefaultClassifier() {
        loadClassifierNoExceptions(DEFAULT_CLASSIFIER);
    }

    public void loadTagIndex() {
        if (tagIndex == null) {
            tagIndex = new HashIndex<>();
            for (String tag : classIndex.objectsList()) {
                String[] parts = tag.split("-");
                // if (parts.length > 1)
                tagIndex.add(parts[parts.length - 1]);
            }
            tagIndex.add(flags.backgroundSymbol);
        }
        if (flags.useNERPriorBIO) {
            if (entityMatrices == null)
                entityMatrices = readEntityMatrices(flags.entityMatrix, tagIndex);
        }
    }
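
    // For example (label set illustrative): if classIndex holds {O, B-PER, I-PER, B-LOC},
    // splitting on '-' and keeping the last part gives tagIndex = {O, PER, LOC},
    // plus flags.backgroundSymbol.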

    private static double[][] parseMatrix(String[] lines, Index<String> tagIndex, int matrixSize, boolean smooth) {
        return parseMatrix(lines, tagIndex, matrixSize, smooth, true);
    }

    /**
     * @return a matrix where each entry m[i][j] is log P(j|i);
     *         in other words, each row vector is a normalized log conditional likelihood.
     */
    static double[][] parseMatrix(String[] lines, Index<String> tagIndex, int matrixSize, boolean smooth,
            boolean useLogProb) {
        double[][] matrix = new double[matrixSize][matrixSize];
        for (int i = 0; i < matrix.length; i++) {
            matrix[i] = new double[matrixSize];
        }
        for (String line : lines) {
            String[] parts = line.split("\t");
            for (String part : parts) {
                String[] subparts = part.split(" ");
                String[] subsubparts = subparts[0].split(":");
                double counts = Double.parseDouble(subparts[1]);
                if (counts == 0.0 && smooth) // smoothing
                    counts = 1.0;
                int tagIndex1 = tagIndex.indexOf(subsubparts[0]);
                int tagIndex2 = tagIndex.indexOf(subsubparts[1]);
                matrix[tagIndex1][tagIndex2] = counts;
            }
        }
        for (int i = 0; i < matrix.length; i++) {
            double sum = ArrayMath.sum(matrix[i]);
            for (int j = 0; j < matrix[i].length; j++) {
                // log conditional probability
                if (useLogProb)
                    matrix[i][j] = Math.log(matrix[i][j] / sum);
                else
                    matrix[i][j] = matrix[i][j] / sum;
            }
        }
        return matrix;
    }
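
    // The input format expected here, as inferred from the parsing above: each line holds
    // tab-separated entries of the form "TAG1:TAG2 count" (tags and counts illustrative), e.g.
    //   "PER:PER 90\tPER:LOC 10"
    // which, after row normalization with useLogProb, yields
    //   matrix[PER][PER] = log(0.9) and matrix[PER][LOC] = log(0.1).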

    static Pair<double[][], double[][]> readEntityMatrices(String fileName, Index<String> tagIndex) {
        int numTags = tagIndex.size();
        int matrixSize = numTags - 1;

        String[] matrixLines = new String[matrixSize];
        String[] subMatrixLines = new String[matrixSize];
        try (BufferedReader br = IOUtils.readerFromString(fileName)) {
            int lineCount = 0;
            for (String line; (line = br.readLine()) != null;) {
                line = line.trim();
                if (lineCount < matrixSize)
                    matrixLines[lineCount] = line;
                else
                    subMatrixLines[lineCount - matrixSize] = line;
                lineCount++;
            }
        } catch (Exception ex) {
            throw new RuntimeIOException(ex);
        }

        double[][] matrix = parseMatrix(matrixLines, tagIndex, matrixSize, true);
        double[][] subMatrix = parseMatrix(subMatrixLines, tagIndex, matrixSize, true);

        // Following Jenny's paper, take the square root of the (non-log) probability for
        // matrix (i.e., halve the log probability), but not for subMatrix.
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++)
                matrix[i][j] = matrix[i][j] / 2;
        }

        log.info("Matrix: ");
        log.info(ArrayUtils.toString(matrix));
        log.info("SubMatrix: ");
        log.info(ArrayUtils.toString(subMatrix));

        return new Pair<>(matrix, subMatrix);
    }
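
    // File layout assumed by readEntityMatrices: with numTags = tagIndex.size() and
    // matrixSize = numTags - 1 (the background symbol has no row), the first matrixSize
    // lines hold the entity matrix and the next matrixSize lines hold the sub-entity matrix.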

    public void writeWeights(PrintStream p) {
        for (String feature : featureIndex) {
            int index = featureIndex.indexOf(feature);
            // line.add(feature+"["+(-p)+"]");
            // rowHeaders.add(feature + '[' + (-p) + ']');
            double[] v = weights[index];
            Index<CRFLabel> l = this.labelIndices.get(0);
            p.println(feature + "\t\t");
            for (CRFLabel label : l) {
                p.print(label.toString(classIndex) + ':' + v[l.indexOf(label)] + '\t');
            }
            p.println();

        }
    }

    public Map<String, Counter<String>> topWeights() {
        Map<String, Counter<String>> w = new HashMap<>();
        for (String feature : featureIndex) {
            int index = featureIndex.indexOf(feature);
            // line.add(feature+"["+(-p)+"]");
            // rowHeaders.add(feature + '[' + (-p) + ']');
            double[] v = weights[index];
            Index<CRFLabel> l = this.labelIndices.get(0);
            for (CRFLabel label : l) {
                if (!w.containsKey(label.toString(classIndex)))
                    w.put(label.toString(classIndex), new ClassicCounter<>());
                w.get(label.toString(classIndex)).setCount(feature, v[l.indexOf(label)]);
            }
        }
        return w;
    }
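
    // Inspection sketch for topWeights(), given a trained CRFClassifier crf (the label
    // "PER" and the feature name are hypothetical; actual keys depend on the model):
    //
    //   Counter<String> perWeights = crf.topWeights().get("PER");
    //   if (perWeights != null) {
    //       System.out.println(perWeights.getCount("someFeatureName"));
    //   }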

    /** Read real-valued vector embeddings for (lowercased) word tokens.
     *  A lexicon is contained in the file flags.embeddingWords.
     *  The word vectors are then in the same order in the file flags.embeddingVectors.
     *
     *  @throws IOException If embedding vectors cannot be loaded
     */
    private void readEmbeddingsData() throws IOException {
        System.err.printf("Reading embedding files %s and %s.%n", flags.embeddingWords, flags.embeddingVectors);
        List<String> wordList = new ArrayList<>();
        try (BufferedReader br = IOUtils.readerFromString(flags.embeddingWords)) {

            for (String line; (line = br.readLine()) != null;) {
                wordList.add(line.trim());
            }
            log.info("Found a dictionary of size " + wordList.size());
        }

        embeddings = Generics.newHashMap();
        try (BufferedReader br = IOUtils.readerFromString(flags.embeddingVectors)) {
            int count = 0;
            int vectorSize = -1;
            boolean warned = false;
            for (String line; (line = br.readLine()) != null;) {
                double[] vector = ArrayUtils.toDoubleArray(line.trim().split(" "));
                if (vectorSize < 0) {
                    vectorSize = vector.length;
                } else {
                    if (vectorSize != vector.length && !warned) {
                        log.info("Inconsistent vector lengths: " + vectorSize + " vs. " + vector.length);
                        warned = true;
                    }
                }
                embeddings.put(wordList.get(count++), vector);
            }
            log.info("Found " + count + " matching embeddings of dimension " + vectorSize);
        }
    }
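
    // Expected file layout, as read above (words and values are illustrative):
    //   flags.embeddingWords (one lowercased word per line)   flags.embeddingVectors (same order)
    //   the                                                   0.12 -0.03 0.47
    //   of                                                    0.08 0.11 -0.22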

    @Override
    public List<IN> classifyWithGlobalInformation(List<IN> tokenSeq, final CoreMap doc, final CoreMap sent) {
        return classify(tokenSeq);
    }

    /**
     * This is used to load the default supplied classifier stored within the jar
     * file. THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
     * WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
     */
    public void loadDefaultClassifier(Properties props) {
        loadClassifierNoExceptions(DEFAULT_CLASSIFIER, props);
    }

    /**
     * Used to get the default supplied classifier inside the jar file. THIS
     * FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE WHICH HAS A
     * SERIALIZED CLASSIFIER STORED INSIDE IT.
     *
     * @return The default CRFClassifier in the jar file (if there is one)
     */
    public static <INN extends CoreMap> CRFClassifier<INN> getDefaultClassifier() {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadDefaultClassifier();
        return crf;
    }

    /**
     * Used to get the default supplied classifier inside the jar file. THIS
     * FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE WHICH HAS A
     * SERIALIZED CLASSIFIER STORED INSIDE IT.
     *
     * @return The default CRFClassifier in the jar file (if there is one)
     */
    public static <INN extends CoreMap> CRFClassifier<INN> getDefaultClassifier(Properties props) {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadDefaultClassifier(props);
        return crf;
    }

    /**
     * Loads a CRF classifier from a filepath, and returns it.
     *
     * @param file File to load classifier from
     * @return The CRF classifier
     *
     * @throws IOException If there are problems accessing the input stream
     * @throws ClassCastException If there are problems interpreting the serialized data
     * @throws ClassNotFoundException If there are problems interpreting the serialized data
     */
    public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(File file)
            throws IOException, ClassCastException, ClassNotFoundException {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadClassifier(file);
        return crf;
    }

    /**
     * Loads a CRF classifier from an InputStream, and returns it. This method
     * does not buffer the InputStream, so you should have buffered it before
     * calling this method.
     *
     * @param in InputStream to load classifier from
     * @return The CRF classifier
     *
     * @throws IOException If there are problems accessing the input stream
     * @throws ClassCastException If there are problems interpreting the serialized data
     * @throws ClassNotFoundException If there are problems interpreting the serialized data
     */
    public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(InputStream in)
            throws IOException, ClassCastException, ClassNotFoundException {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadClassifier(in);
        return crf;
    }

    // Gets a CRFClassifier from an already-opened ObjectInputStream (properties left null).
    public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(ObjectInputStream ois)
            throws IOException, ClassCastException, ClassNotFoundException {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadClassifier(ois, null);
        return crf;
    }

    public static <INN extends CoreMap> CRFClassifier<INN> getClassifierNoExceptions(String loadPath) {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadClassifierNoExceptions(loadPath);
        return crf;
    }

    public static CRFClassifier<CoreLabel> getClassifier(String loadPath)
            throws IOException, ClassCastException, ClassNotFoundException {
        CRFClassifier<CoreLabel> crf = new CRFClassifier<>();
        crf.loadClassifier(loadPath);
        return crf;
    }

    public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(String loadPath, Properties props)
            throws IOException, ClassCastException, ClassNotFoundException {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadClassifier(loadPath, props);
        return crf;
    }

    public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(ObjectInputStream ois, Properties props)
            throws IOException, ClassCastException, ClassNotFoundException {
        CRFClassifier<INN> crf = new CRFClassifier<>();
        crf.loadClassifier(ois, props);
        return crf;
    }
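
    // Programmatic loading sketch using the factory methods above (the model path is a
    // placeholder for any serialized CRF model, typically a .ser.gz file):
    //
    //   CRFClassifier<CoreLabel> crf = CRFClassifier.getClassifier("my-ner-model.ser.gz");
    //   String tagged = crf.classifyToString("Stanford University is in California.");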

    private static CRFClassifier<CoreLabel> chooseCRFClassifier(SeqClassifierFlags flags) {
        CRFClassifier<CoreLabel> crf; // initialized in if/else
        if (flags.useFloat) {
            crf = new CRFClassifierFloat<>(flags);
        } else if (flags.nonLinearCRF) {
            crf = new CRFClassifierNonlinear<>(flags);
        } else if (flags.numLopExpert > 1) {
            crf = new CRFClassifierWithLOP<>(flags);
        } else if (flags.priorType.equals("DROPOUT")) {
            crf = new CRFClassifierWithDropout<>(flags);
        } else if (flags.useNoisyLabel) {
            crf = new CRFClassifierNoisyLabel<>(flags);
        } else {
            crf = new CRFClassifier<>(flags);
        }
        return crf;
    }

    /** The main method. See the class documentation. */
    public static void main(String[] args) throws Exception {
        StringUtils.logInvocationString(log, args);

        Properties props = StringUtils.argsToProperties(args);
        SeqClassifierFlags flags = new SeqClassifierFlags(props);
        CRFClassifier<CoreLabel> crf = chooseCRFClassifier(flags);
        String testFile = flags.testFile;
        String testFiles = flags.testFiles;
        String textFile = flags.textFile;
        String textFiles = flags.textFiles;
        String loadPath = flags.loadClassifier;
        String loadTextPath = flags.loadTextClassifier;
        String serializeTo = flags.serializeTo;
        String serializeToText = flags.serializeToText;

        if (crf.flags.useEmbedding && crf.flags.embeddingWords != null && crf.flags.embeddingVectors != null) {
            crf.readEmbeddingsData();
        }

        if (crf.flags.loadClassIndexFrom != null) {
            crf.classIndex = loadClassIndexFromFile(crf.flags.loadClassIndexFrom);
        }

        if (loadPath != null) {
            crf.loadClassifierNoExceptions(loadPath, props);
        } else if (loadTextPath != null) {
            log.info("Warning: this is now only tested for Chinese Segmenter");
            log.info("(Sun Dec 23 00:59:39 2007) (pichuan)");
            try {
                crf.loadTextClassifier(loadTextPath, props);
                // log.info("DEBUG: out from crf.loadTextClassifier");
            } catch (Exception e) {
                throw new RuntimeException("error loading " + loadTextPath, e);
            }
        } else if (crf.flags.loadJarClassifier != null) {
            // legacy option support
            crf.loadClassifierNoExceptions(crf.flags.loadJarClassifier, props);
        } else if (crf.flags.trainFile != null || crf.flags.trainFileList != null) {
            Timing timing = new Timing();
            // temporarily unlimited size of knownLCWords
            int knownLCWordsLimit = crf.knownLCWords.getMaxSize();
            crf.knownLCWords.setMaxSize(-1);
            crf.train();
            crf.knownLCWords.setMaxSize(knownLCWordsLimit);
            timing.done(log, "CRFClassifier training");
        } else {
            crf.loadDefaultClassifier();
        }

        crf.loadTagIndex();

        if (serializeTo != null) {
            crf.serializeClassifier(serializeTo);
        }

        if (crf.flags.serializeWeightsTo != null) {
            crf.serializeWeights(crf.flags.serializeWeightsTo);
        }

        if (crf.flags.serializeFeatureIndexTo != null) {
            crf.serializeFeatureIndex(crf.flags.serializeFeatureIndexTo);
        }

        if (serializeToText != null) {
            crf.serializeTextClassifier(serializeToText);
        }

        if (testFile != null) {
            // todo: Change testFile to call testFiles with a singleton list
            DocumentReaderAndWriter<CoreLabel> readerAndWriter = crf.defaultReaderAndWriter();
            if (crf.flags.searchGraphPrefix != null) {
                crf.classifyAndWriteViterbiSearchGraph(testFile, crf.flags.searchGraphPrefix, readerAndWriter);
            } else if (crf.flags.printFirstOrderProbs) {
                crf.printFirstOrderProbs(testFile, readerAndWriter);
            } else if (crf.flags.printFactorTable) {
                crf.printFactorTable(testFile, readerAndWriter);
            } else if (crf.flags.printProbs) {
                crf.printProbs(testFile, readerAndWriter);
            } else if (crf.flags.useKBest) {
                int k = crf.flags.kBest;
                crf.classifyAndWriteAnswersKBest(testFile, k, readerAndWriter);
            } else if (crf.flags.printLabelValue) {
                crf.printLabelInformation(testFile, readerAndWriter);
            } else {
                crf.classifyAndWriteAnswers(testFile, readerAndWriter, true);
            }
        }

        if (testFiles != null) {
            List<File> files = Arrays.stream(testFiles.split(",")).map(File::new).collect(Collectors.toList());
            if (crf.flags.printProbs) {
                crf.printProbs(files, crf.defaultReaderAndWriter());
            } else {
                crf.classifyFilesAndWriteAnswers(files, crf.defaultReaderAndWriter(), true);
            }
        }

        if (textFile != null) {
            crf.classifyAndWriteAnswers(textFile, crf.plainTextReaderAndWriter(), false);
        }

        if (textFiles != null) {
            List<File> files = Arrays.stream(textFiles.split(",")).map(File::new).collect(Collectors.toList());
            crf.classifyFilesAndWriteAnswers(files);
        }

        if (crf.flags.readStdin) {
            crf.classifyStdin();
        }
    } // end main
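
    // Illustrative invocations of the main method above (file names are placeholders;
    // each -flag value pair becomes a property via StringUtils.argsToProperties):
    //   train:  java edu.stanford.nlp.ie.crf.CRFClassifier -trainFile train.tsv -serializeTo my-ner-model.ser.gz
    //   test:   java edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier my-ner-model.ser.gz -testFile test.tsv
    //   tag:    java edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier my-ner-model.ser.gz -textFile document.txt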

} // end class CRFClassifier