com.scaleunlimited.classify.model.HashedFeaturesLibLinearModel.java Source code

Introduction

Here is the source code for com.scaleunlimited.classify.model.HashedFeaturesLibLinearModel.java
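
The class implements a LibLinear text classifier that applies the hashing trick: each term is hashed down into a reduced range of feature indexes (by default 10% of the unique-term count), and term counts that collide on the same index can optionally be averaged.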

Source

/**
 * Copyright (c) 2009-2015 Scale Unlimited, Inc.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.scaleunlimited.classify.model;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.scaleunlimited.classify.datum.DocDatum;
import com.scaleunlimited.classify.datum.TermsDatum;

import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.Train;

@SuppressWarnings("serial")
public class HashedFeaturesLibLinearModel extends BaseLibLinearModel {
    private static final Logger LOGGER = LoggerFactory.getLogger(HashedFeaturesLibLinearModel.class);

    // If num unique features * percent reduction is less than this, fall back to
    // (num features - 1), so that some hashing still occurs (useful for testing).
    private static final int MIN_FEATURE_SIZE = 10;

    // We generate this during training
    private int _maxFeatureIndex;

    // Values we need during training only, thus not saved
    private transient float _percentReduction = 0.10f;
    private transient boolean _averageCollisions = true;

    public HashedFeaturesLibLinearModel() {
        super();
    }

    public HashedFeaturesLibLinearModel setPercentReduction(float percentReduction) {
        _percentReduction = percentReduction;
        return this;
    }

    public HashedFeaturesLibLinearModel setAverageCollisions(boolean averageCollisions) {
        _averageCollisions = averageCollisions;
        return this;
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _maxFeatureIndex = in.readInt();
    }

    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        out.writeInt(_maxFeatureIndex);
    }

    @Override
    public void train() {
        train(_crossValidationRequired);
    }

    @Override
    public double train(boolean doCrossValidation) {
        // First generate list of unique labels, so we can map a label to an index.
        _labelNames = new ArrayList<String>();
        for (String label : _labelList) {
            if (!(_labelNames.contains(label))) {
                _labelNames.add(label);
            }
        }
        Collections.sort(_labelNames);

        // Create list that maps from training data set index to label index
        List<Double> labelIndexList = new ArrayList<Double>(_labelList.size());
        for (String label : _labelList) {
            int index = Collections.binarySearch(_labelNames, label);
            if (index >= 0) {
                labelIndexList.add((double) index);
            } else {
                throw new RuntimeException("Index not found for label :" + label);
            }
        }

        // Figure out the max index, by counting # of unique features, and reducing
        // down to some percentage of this count. But we want at least MIN_FEATURE_SIZE, so if
        // we're below that, just set it to the # of features - 1 (so some hashing
        // will occur, for testing).
        Set<String> uniqueFeatures = new HashSet<String>();
        for (Map<String, Integer> termsMap : _featuresList) {
            uniqueFeatures.addAll(termsMap.keySet());
        }

        _maxFeatureIndex = Math.round(uniqueFeatures.size() * _percentReduction);
        LOGGER.debug(String.format("Setting max feature index to be %d", _maxFeatureIndex));
        if (_maxFeatureIndex < MIN_FEATURE_SIZE) {
            _maxFeatureIndex = uniqueFeatures.size() - 1;
            LOGGER.debug(String.format("Resetting max feature index to be %d", _maxFeatureIndex));
        }

        List<Feature[]> features = new ArrayList<Feature[]>(_featuresList.size());
        for (Map<String, Integer> termsMap : _featuresList) {
            features.add(getFeatures(termsMap));
        }

        _featuresList.clear();

        if (_quietMode) {
            Linear.disableDebugOutput();
        }

        LOGGER.debug("Constructing problem for training...");
        Problem problem = Train.constructProblem(labelIndexList, features, _maxFeatureIndex + 1, -1.0);
        Parameter param = createParameter();

        LOGGER.debug("Starting training...");
        _model = Linear.train(problem, param);
        LOGGER.debug(String.format("Trained model with %d classes and %d features", _model.getNrClass(),
                _model.getNrFeature()));

        double crossValidationAccuracy = 0.0;
        if (doCrossValidation) {
            double[] target = new double[problem.l];
            LOGGER.debug("Cross validating...");
            Linear.crossValidation(problem, param, DEFAULT_NR_FOLD, target);
            int totalCorrect = 0;
            for (int i = 0; i < problem.l; i++) {
                if (target[i] == problem.y[i]) {
                    ++totalCorrect;
                }
            }

            crossValidationAccuracy = (double) totalCorrect / (double) problem.l;
            LOGGER.debug(String.format("Correct: %d%n", totalCorrect));
            LOGGER.debug(String.format("Cross Validation Accuracy = %g%%%n", 100.0 * crossValidationAccuracy));
        }

        return crossValidationAccuracy;
    }

    @Override
    public DocDatum classify(TermsDatum datum) {
        Feature[] features = getFeatures(datum.getTermMap());
        double[] probEstimates = new double[_labelNames.size()];

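        // predictProbability() fills in probEstimates and returns the winning
        // class value. Training used each label's position in the sorted
        // _labelNames list as its class value, so the return value is also an
        // index into _labelNames.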
        int labelIndex = (int) Linear.predictProbability(_model, features, probEstimates);
        String labelName = _labelNames.get(labelIndex);

        if (_modelLabelIndexes == null) {
            _modelLabelIndexes = _model.getLabels();
        }

        float score = 0;
        // FUTURE: this could be made more efficient than a linear search.
        for (int i = 0; i < _modelLabelIndexes.length; i++) {
            if (_modelLabelIndexes[i] == labelIndex) {
                score = (float) probEstimates[i];
                break;
            }
        }

        return new DocDatum(labelName, score);
    }

    public DocDatum[] classifyNResults(TermsDatum datum, int n) {
        Feature[] features = getFeatures(datum.getTermMap());
        double[] probEstimates = new double[_labelNames.size()];

        // The return value (the top-scoring class) isn't needed here; the call
        // is made only to fill in the per-class probability estimates.
        Linear.predictProbability(_model, features, probEstimates);

        if (_modelLabelIndexes == null) {
            _modelLabelIndexes = _model.getLabels();
        }

        SortedSet<LabelIndexScore> labelIndexScoreSet = new TreeSet<LabelIndexScore>();
        double lowestScore = 0.0;
        for (int i = 0; i < probEstimates.length; i++) {
            double score = probEstimates[i];

            if (labelIndexScoreSet.size() >= n) {
                if (score > lowestScore) {
                    LabelIndexScore indexScore = new LabelIndexScore(_modelLabelIndexes[i], score);
                    LabelIndexScore first = labelIndexScoreSet.first();
                    labelIndexScoreSet.remove(first);
                    labelIndexScoreSet.add(indexScore);
                    // And now get the new lowest score
                    lowestScore = labelIndexScoreSet.first().getScore();
                }
            } else {
                LabelIndexScore indexScore = new LabelIndexScore(_modelLabelIndexes[i], score);
                labelIndexScoreSet.add(indexScore);
                lowestScore = labelIndexScoreSet.first().getScore();
            }
        }

        int size = labelIndexScoreSet.size();
        DocDatum[] docDatums = new DocDatum[size];

        // Fill in the results from highest to lowest score.
        int i = size - 1;
        Iterator<LabelIndexScore> iter = labelIndexScoreSet.iterator();
        while (iter.hasNext() && i >= 0) {
            LabelIndexScore next = iter.next();
            docDatums[i] = new DocDatum(_labelNames.get(next.getLabelIndex()), (float) next.getScore());
            i--;
        }

        return docDatums;
    }

    @Override
    public String getDetails() {
        StringBuilder result = new StringBuilder(super.getDetails());
        // TODO output extra info about reduction amount?

        return result.toString();
    }

    public static int calcHashBuiltin(String term, int modulo) {
        // Mask off the sign bit so the modulo result is always non-negative.
        return (term.hashCode() & 0x7FFFFFFF) % modulo;
    }

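    // 256-entry table of pseudo-random 64-bit values (one per byte value),
    // generated by repeated xorshift steps. calcHashLCG() mixes each input
    // byte by multiplying the running hash and XORing in that byte's entry.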
    private static final long[] lookupTable = createLookupTable();
    private static final long HSTART = 0xBB40E64DA205B064L;
    private static final long HMULT = 7664345821815920749L;

    private static final long[] createLookupTable() {
        long[] byteTable = new long[256];
        long h = 0x544B2FBACAAF1684L;
        for (int i = 0; i < 256; i++) {
            for (int j = 0; j < 31; j++) {
                h = (h >>> 7) ^ h;
                h = (h << 11) ^ h;
                h = (h >>> 10) ^ h;
            }
            byteTable[i] = h;
        }
        return byteTable;
    }

    public static int calcHashLCG(String term, int modulo) {
        byte[] data = term.getBytes(StandardCharsets.UTF_8);

        long h = HSTART;
        final long hmult = HMULT;
        final long[] ht = lookupTable;
        for (int len = data.length, i = 0; i < len; i++) {
            h = (h * hmult) ^ ht[data[i] & 0xff];
        }

        return (int) ((h & 0x07FFFFFFF) % modulo);
    }

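    // Bob Jenkins' one-at-a-time hash ("joaat"), computed over the term's UTF-8 bytes.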
    public static int calcHashJoaat(String term, int modulo) {
        byte[] key = term.getBytes(StandardCharsets.UTF_8);

        int hash = 0;

        for (byte b : key) {
            hash += (b & 0xFF);
            hash += (hash << 10);
            hash ^= (hash >> 6);
        }

        hash += (hash << 3);
        hash ^= (hash >> 11);
        hash += (hash << 15);

        // Mask off the sign bit rather than using Math.abs(), which returns a
        // negative value for Integer.MIN_VALUE and would yield a negative index.
        return (hash & 0x7FFFFFFF) % modulo;
    }

    /**
     * Given a map from term to count, generate a LibLinear feature array,
     * hashing each term to an index in [0, _maxFeatureIndex).
     * 
     * @param terms map from term to term count
     * @return array of LibLinear features, sorted by increasing index
     */
    private Feature[] getFeatures(Map<String, Integer> terms) {

        // First create the vector, where each term's index is the hash
        // of the term, and the value is the term count.
        Map<Integer, Integer> collisionCount = new HashMap<>();
        Vector v = new RandomAccessSparseVector(_maxFeatureIndex);
        for (Map.Entry<String, Integer> entry : terms.entrySet()) {
            int index = calcHashJoaat(entry.getKey(), _maxFeatureIndex);
            double curValue = v.getQuick(index);
            if (_averageCollisions && (curValue != 0.0)) {
                Integer curCollisionCount = collisionCount.get(index);
                if (curCollisionCount == null) {
                    // Two values have now collided on this index; we'll later
                    // divide by this count.
                    collisionCount.put(index, 2);
                } else {
                    collisionCount.put(index, curCollisionCount + 1);
                }

                v.setQuick(index, curValue + entry.getValue());
            } else {
                v.setQuick(index, entry.getValue());
            }
        }

        // Now adjust the vector for collisions, if needed.
        if (_averageCollisions && !collisionCount.isEmpty()) {
            for (Integer index : collisionCount.keySet()) {
                double curValue = v.getQuick(index);
                v.setQuick(index, curValue / collisionCount.get(index));
            }
        }

        // Apply the term vector normalizer.
        getNormalizer().normalize(v);

        List<FeatureNode> features = new ArrayList<FeatureNode>(terms.size());
        for (Element e : v.nonZeroes()) {
            features.add(new FeatureNode(e.index() + 1, e.get()));
        }

        // We need to sort by increasing index.
        Collections.sort(features, new Comparator<FeatureNode>() {

            @Override
            public int compare(FeatureNode o1, FeatureNode o2) {
                return o1.index - o2.index;
            }
        });

        return features.toArray(new FeatureNode[features.size()]);
    }

}
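
Example

The three hash functions are public static methods, so they can be exercised on their own. Below is a minimal, hypothetical demo class (HashDemo is not part of the library) that maps a handful of terms into a deliberately small index space; with a modulo this small, the collisions that getFeatures() optionally averages out are easy to provoke.

import com.scaleunlimited.classify.model.HashedFeaturesLibLinearModel;

public class HashDemo {

    public static void main(String[] args) {
        // Use a small modulo so several terms land on the same index,
        // mirroring what getFeatures() does with _maxFeatureIndex as the modulo.
        int modulo = 16;
        String[] terms = { "apple", "banana", "cherry", "date", "elderberry" };

        for (String term : terms) {
            System.out.printf("%-12s builtin=%2d  lcg=%2d  joaat=%2d%n",
                    term,
                    HashedFeaturesLibLinearModel.calcHashBuiltin(term, modulo),
                    HashedFeaturesLibLinearModel.calcHashLCG(term, modulo),
                    HashedFeaturesLibLinearModel.calcHashJoaat(term, modulo));
        }
    }
}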