com.joliciel.talismane.machineLearning.linearsvm.LinearSVMOneVsRestDecisionMaker.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.machineLearning.linearsvm.LinearSVMOneVsRestDecisionMaker.java

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2014 Joliciel Informatique
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.linearsvm;

import gnu.trove.map.TObjectIntMap;

import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.joliciel.talismane.machineLearning.ClassificationSolution;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionFactory;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy;
import com.joliciel.talismane.machineLearning.Outcome;
import com.joliciel.talismane.machineLearning.ScoringStrategy;
import com.joliciel.talismane.machineLearning.features.FeatureResult;
import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;

/**
 * A {@link DecisionMaker} backed by a list of liblinear models used one-vs-rest: the model at
 * position {@code i} in {@code models} supplies the probability of the outcome at position
 * {@code i} in {@code outcomes}.
 *
 * @param <T> the outcome type created by the configured {@link DecisionFactory}
 */
class LinearSVMOneVsRestDecisionMaker<T extends Outcome> implements DecisionMaker<T> {
    private static final Log LOG = LogFactory.getLog(LinearSVMOneVsRestDecisionMaker.class);
    private DecisionFactory<T> decisionFactory;
    private final List<Model> models;
    TObjectIntMap<String> featureIndexMap = null;
    List<String> outcomes = null;
    private transient ScoringStrategy<ClassificationSolution<T>> scoringStrategy = null;

    /**
     * @param models
     *            one liblinear model per outcome, in the same order as {@code outcomes}
     * @param featureIndexMap
     *            passed to {@link LinearSVMUtils#prepareData} to convert feature results into
     *            liblinear features
     * @param outcomes
     *            the outcome codes, in the same order as {@code models}
     */
    public LinearSVMOneVsRestDecisionMaker(List<Model> models, TObjectIntMap<String> featureIndexMap,
            List<String> outcomes) {
        this.models = models;
        this.featureIndexMap = featureIndexMap;
        this.outcomes = outcomes;
    }

    /**
     * Scores every outcome for the given feature results.
     * <p>
     * If {@link LinearSVMUtils#prepareData} yields no features, every outcome receives the same
     * probability, {@code 1.0 / outcomes.size()}. Otherwise each model's probability estimate
     * for its label {@code 1} becomes the probability of the corresponding outcome.
     * <p>
     * Requires {@link #setDecisionFactory} to have been called beforehand.
     *
     * @param featureResults
     *            the feature results describing the current context
     * @return the decisions, ordered by the natural ordering of {@link Decision} (they are
     *         collected in a {@link TreeSet})
     */
    @Override
    public List<Decision<T>> decide(List<FeatureResult<?>> featureResults) {
        List<Feature> featureList = LinearSVMUtils.prepareData(featureResults, featureIndexMap);

        TreeSet<Decision<T>> outcomeSet = new TreeSet<Decision<T>>();

        if (featureList.isEmpty()) {
            LOG.info("No features for current context.");
            // Floating-point division: integer division would assign 0 to every outcome
            // as soon as there is more than one outcome.
            double uniformProb = 1.0 / outcomes.size();
            for (String outcome : outcomes) {
                outcomeSet.add(decisionFactory.createDecision(outcome, uniformProb));
            }
        } else {
            Feature[] instance = featureList.toArray(new Feature[0]);

            int i = 0;
            for (Model model : models) {
                int[] labels = model.getLabels();
                int positiveIndex = positiveLabelIndex(labels);

                // Sized from the model's label count, but never below the two entries that
                // a binary one-vs-rest model fills in.
                double[] probabilities = new double[Math.max(2, labels.length)];
                Linear.predictProbability(model, instance, probabilities);

                outcomeSet.add(decisionFactory.createDecision(outcomes.get(i), probabilities[positiveIndex]));
                i++;
            }
        }
        return new ArrayList<Decision<T>>(outcomeSet);
    }

    /**
     * Returns the position of label {@code 1} in {@code labels}; this position is used to index
     * the estimates filled in by {@link Linear#predictProbability}. Returns {@code 0} if label
     * {@code 1} is absent, and the last matching position if it occurs more than once.
     */
    private static int positiveLabelIndex(int[] labels) {
        int index = 0;
        for (int j = 0; j < labels.length; j++) {
            if (labels[j] == 1) {
                index = j;
            }
        }
        return index;
    }

    /** @return the factory used to build decisions, or {@code null} if none has been set */
    public DecisionFactory<T> getDecisionFactory() {
        return decisionFactory;
    }

    /** @param decisionFactory the factory used by {@link #decide} to build decisions */
    public void setDecisionFactory(DecisionFactory<T> decisionFactory) {
        this.decisionFactory = decisionFactory;
    }

    /**
     * Returns the default scoring strategy, lazily creating a
     * {@link GeometricMeanScoringStrategy} on first call and reusing it afterwards.
     */
    @Override
    public ScoringStrategy<ClassificationSolution<T>> getDefaultScoringStrategy() {
        if (scoringStrategy == null) {
            scoringStrategy = new GeometricMeanScoringStrategy<T>();
        }
        return scoringStrategy;
    }
}