com.scaleunlimited.classify.model.RawFeaturesLibLinearModel.java Source code

Java tutorial

Introduction

Here is the source code for com.scaleunlimited.classify.model.RawFeaturesLibLinearModel.java

Source

/**
 * Copyright (c) 2013-2015 Scale Unlimited, Inc.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.scaleunlimited.classify.model;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.scaleunlimited.classify.datum.DocDatum;
import com.scaleunlimited.classify.datum.TermsDatum;
import com.scaleunlimited.classify.vectors.VectorUtils;

import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.Train;

/**
 * This is an improved version of LibLinearModel.
 *
 */
@SuppressWarnings("serial")
public class RawFeaturesLibLinearModel extends BaseLibLinearModel {
    private static final Logger LOGGER = LoggerFactory.getLogger(RawFeaturesLibLinearModel.class);

    // Data we need to save to recreate the model
    private List<String> _uniqueTerms;

    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);

        _uniqueTerms = readStrings(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);

        writeStrings(out, _uniqueTerms);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = super.hashCode();
        result = prime * result + ((_uniqueTerms == null) ? 0 : _uniqueTerms.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (!super.equals(obj))
            return false;
        if (getClass() != obj.getClass())
            return false;
        RawFeaturesLibLinearModel other = (RawFeaturesLibLinearModel) obj;
        if (_uniqueTerms == null) {
            if (other._uniqueTerms != null)
                return false;
        } else if (!_uniqueTerms.equals(other._uniqueTerms))
            return false;
        return true;
    }

    @Override
    public double train(boolean doCrossValidation) {

        _uniqueTerms = buildUniqueTerms(_featuresList);
        List<Vector> vectors = new ArrayList<Vector>(_featuresList.size());
        for (Map<String, Integer> termMap : _featuresList) {
            vectors.add(makeNormalizedVector(termMap));
        }

        _labelNames = new ArrayList<String>();
        for (String label : _labelList) {
            if (!(_labelNames.contains(label))) {
                _labelNames.add(label);
            }
        }
        Collections.sort(_labelNames);
        List<Double> labelIndexList = new ArrayList<Double>(_labelNames.size());
        for (String label : _labelList) {
            int index = Collections.binarySearch(_labelNames, label);
            if (index >= 0) {
                labelIndexList.add((double) index);
            } else {
                throw new RuntimeException("Index not found for label :" + label);
            }
        }

        _featuresList.clear();

        List<Feature[]> featureList = getFeaturesList(vectors);
        int vectorsSize = vectors.get(0).size();
        vectors.clear();

        if (_quietMode) {
            Linear.disableDebugOutput();
        }
        LOGGER.info("Constructing problem for training...");
        Problem problem = Train.constructProblem(labelIndexList, featureList, vectorsSize, -1.0);
        Parameter param = createParameter();
        LOGGER.info("Starting training...");
        _model = Linear.train(problem, param);
        LOGGER.info(String.format("Trained model with %d classes and %d features", _model.getNrClass(),
                _model.getNrFeature()));

        double result = 1.0;
        if (doCrossValidation) {
            double[] target = new double[problem.l];
            LOGGER.info("Cross validating...");
            Linear.crossValidation(problem, param, DEFAULT_NR_FOLD, target);
            int totalCorrect = 0;
            for (int i = 0; i < problem.l; i++) {
                if (target[i] == problem.y[i]) {
                    ++totalCorrect;
                }
            }

            LOGGER.debug(String.format("Correct: %d%n", totalCorrect));
            result = (double) totalCorrect / (double) problem.l;
            LOGGER.debug(String.format("Cross Validation Accuracy = %g%%%n", 100.0 * result));
        }

        return result;
    }

    protected Vector makeNormalizedVector(Map<String, Integer> termMap) {
        // We assume that _uniqueTerms has been set up, as a sorted list, so
        // we can use that to create an appropriate vector.
        Vector result = VectorUtils.makeVector(_uniqueTerms, termMap);
        getNormalizer().normalize(result);

        return result;
    }

    public void train() {
        train(_crossValidationRequired);
    }

    @Override
    public DocDatum classify(TermsDatum datum) {
        Vector docVector = makeNormalizedVector(datum.getTermMap());

        FeatureNode[] features = vectorToFeatureNodes(docVector);
        double[] probEstimates = new double[_labelNames.size()];

        int labelIndex = (int) Linear.predictProbability(_model, features, probEstimates);
        String labelName = _labelNames.get(labelIndex);

        if (_modelLabelIndexes == null) {
            _modelLabelIndexes = _model.getLabels();
        }

        float score = 0;
        // TODO CSc This could be made more efficient than a linear search
        for (int i = 0; i < _modelLabelIndexes.length; i++) {
            if (_modelLabelIndexes[i] == labelIndex) {
                score = (float) (probEstimates[i]);
            }
        }
        return new DocDatum(labelName, score);
    }

    @Override
    public DocDatum[] classifyNResults(TermsDatum datum, int n) {
        Vector docVector = makeNormalizedVector(datum.getTermMap());
        FeatureNode[] features = vectorToFeatureNodes(docVector);
        double[] probEstimates = new double[_labelNames.size()];

        //        int topScoreIndex = 
        Linear.predictProbability(_model, features, probEstimates);
        //        String labelName = _labelNames.get(topScoreIndex);

        if (_modelLabelIndexes == null) {
            _modelLabelIndexes = _model.getLabels();
        }

        SortedSet<LabelIndexScore> labelIndexScoreSet = new TreeSet<LabelIndexScore>();
        double lowestScore = 0.0;
        for (int i = 0; i < probEstimates.length; i++) {
            double score = probEstimates[i];

            if (labelIndexScoreSet.size() >= n) {
                if (score > lowestScore) {
                    LabelIndexScore indexScore = new LabelIndexScore(_modelLabelIndexes[i], score);
                    LabelIndexScore first = labelIndexScoreSet.first();
                    labelIndexScoreSet.remove(first);
                    labelIndexScoreSet.add(indexScore);
                    // And now get the new lowest score
                    lowestScore = labelIndexScoreSet.first().getScore();
                }
            } else {
                LabelIndexScore indexScore = new LabelIndexScore(_modelLabelIndexes[i], score);
                labelIndexScoreSet.add(indexScore);
                lowestScore = labelIndexScoreSet.first().getScore();
            }
        }

        int size = labelIndexScoreSet.size();
        DocDatum[] docDatums = new DocDatum[size];

        // Get the top terms from highest to lowest.
        int i = size - 1;
        Iterator<LabelIndexScore> iter = labelIndexScoreSet.iterator();
        while (iter.hasNext() && i >= 0) {
            LabelIndexScore next = iter.next();
            docDatums[i] = new DocDatum(_labelNames.get(next.getLabelIndex()), (float) next.getScore());
            i--;
        }

        //        assert(_labelNames.get(topScoreIndex).equals(docDatums[0].getLabel()));
        return docDatums;
    }

    @Override
    public String getDetails() {
        StringBuilder result = new StringBuilder(super.getDetails());
        double[] weights = _model.getFeatureWeights();

        for (int i = 0; i < _uniqueTerms.size(); i++) {
            result.append(String.format("\t%s: %f\n", _uniqueTerms.get(i), weights[i]));
        }

        return result.toString();
    }

    private List<Feature[]> getFeaturesList(List<Vector> vectors) {
        List<Feature[]> result = new ArrayList<Feature[]>();
        for (Vector vector : vectors) {
            FeatureNode[] x = vectorToFeatureNodes(vector);
            result.add(x);
        }
        return result;
    }

    private FeatureNode[] vectorToFeatureNodes(Vector vector) {
        int featureCount = vector.getNumNondefaultElements();
        FeatureNode[] x = new FeatureNode[featureCount];
        int arrayIndex = 0;
        int cardinality = vector.size();
        for (int i = 0; i < cardinality; i++) {
            double value = vector.getQuick(i);
            if (value != 0.0) {
                // (At least) Linear.train assumes that FeatureNode.index
                // is 1-based, and we don't really have to map back to our
                // term indexes, so just add one. YUCK!
                x[arrayIndex++] = new FeatureNode(i + 1, value);
            }
        }
        return x;
    }

    private List<String> buildUniqueTerms(List<Map<String, Integer>> featuresList) {
        Set<String> uniqueTerms = new HashSet<String>();
        for (Map<String, Integer> termMap : featuresList) {
            uniqueTerms.addAll(termMap.keySet());
        }
        List<String> sortedTerms = new ArrayList<String>(uniqueTerms);
        Collections.sort(sortedTerms);
        return sortedTerms;
    }

}