com.joliciel.talismane.machineLearning.linearsvm.LinearSVMDecisionMaker.java Source code

Java tutorial

Introduction

Here is the source code for com.joliciel.talismane.machineLearning.linearsvm.LinearSVMDecisionMaker.java

Source

///////////////////////////////////////////////////////////////////////////////
//Copyright (C) 2012 Assaf Urieli
//
//This file is part of Talismane.
//
//Talismane is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//Talismane is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with Talismane.  If not, see <http://www.gnu.org/licenses/>.
//////////////////////////////////////////////////////////////////////////////
package com.joliciel.talismane.machineLearning.linearsvm;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.joliciel.talismane.machineLearning.ClassificationSolution;
import com.joliciel.talismane.machineLearning.Decision;
import com.joliciel.talismane.machineLearning.DecisionFactory;
import com.joliciel.talismane.machineLearning.DecisionMaker;
import com.joliciel.talismane.machineLearning.GeometricMeanScoringStrategy;
import com.joliciel.talismane.machineLearning.Outcome;
import com.joliciel.talismane.machineLearning.ScoringStrategy;
import com.joliciel.talismane.machineLearning.features.FeatureResult;
import com.joliciel.talismane.utils.WeightedOutcome;

import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;

/**
 * A {@link DecisionMaker} backed by a liblinear {@link Model}.
 * <p>
 * Feature results are translated into liblinear {@link Feature} nodes through
 * {@link #featureIndexMap}; features whose names were not registered in that map
 * are silently skipped. Each (outcome, probability) pair is turned into a
 * {@link Decision} by the configured {@link DecisionFactory}, which must be set via
 * {@link #setDecisionFactory(DecisionFactory)} before {@link #decide(List)} is called.
 */
class LinearSVMDecisionMaker<T extends Outcome> implements DecisionMaker<T> {
    private static final Log LOG = LogFactory.getLog(LinearSVMDecisionMaker.class);
    private DecisionFactory<T> decisionFactory;
    private Model model;
    /**
     * Maps a feature's training name (or {@code trainingName + "|" + trainingOutcome}
     * for list-valued features) to its liblinear feature index.
     */
    Map<String, Integer> featureIndexMap = null;
    /** Outcome names; entry {@code i} is paired with the model's {@code i}-th label. */
    List<String> outcomes = null;
    private transient ScoringStrategy<ClassificationSolution<T>> scoringStrategy = null;

    /**
     * @param model
     *            the trained liblinear model used for prediction
     * @param featureIndexMap
     *            feature name to liblinear index mapping built at training time
     * @param outcomes
     *            outcome names, positionally aligned with the model's labels
     */
    public LinearSVMDecisionMaker(Model model, Map<String, Integer> featureIndexMap, List<String> outcomes) {
        super();
        this.model = model;
        this.featureIndexMap = featureIndexMap;
        this.outcomes = outcomes;
    }

    /**
     * Scores every outcome for the given feature results.
     * <p>
     * If none of the features is known to the model, every outcome receives the same
     * probability {@code 1 / outcomes.size()}. Otherwise probabilities come from
     * {@link Linear#predictProbability}.
     *
     * @param featureResults
     *            the feature results describing the current context
     * @return one decision per outcome, in the natural ordering of {@link Decision}
     *         (they are collected through a {@link TreeSet})
     */
    @Override
    public List<Decision<T>> decide(List<FeatureResult<?>> featureResults) {
        List<Feature> featureList = new ArrayList<Feature>(featureResults.size());
        this.prepareData(featureResults, featureList);

        if (featureList.isEmpty()) {
            LOG.info("No features for current context.");
            return this.uniformDecisions();
        }
        return this.predictDecisions(featureList);
    }

    /** Builds one decision per outcome, all sharing the probability 1/|outcomes|. */
    private List<Decision<T>> uniformDecisions() {
        TreeSet<Decision<T>> outcomeSet = new TreeSet<Decision<T>>();
        // Floating-point division: integer division (1 / n) yields 0 for n > 1 and
        // throws ArithmeticException for n == 0 (where the loop below simply runs zero times).
        double uniformProb = 1.0 / outcomes.size();
        for (String outcome : outcomes) {
            // NOTE(review): TreeSet drops elements that compare equal; this relies on
            // Decision's compareTo breaking probability ties (e.g. by outcome code) — confirm.
            outcomeSet.add(decisionFactory.createDecision(outcome, uniformProb));
        }
        return new ArrayList<Decision<T>>(outcomeSet);
    }

    /** Runs the liblinear model on the prepared features and wraps each label's probability in a decision. */
    private List<Decision<T>> predictDecisions(List<Feature> featureList) {
        Feature[] instance = featureList.toArray(new Feature[featureList.size()]);

        int labelCount = model.getLabels().length;
        double[] probabilities = new double[labelCount];
        Linear.predictProbability(model, instance, probabilities);

        TreeSet<Decision<T>> outcomeSet = new TreeSet<Decision<T>>();
        for (int i = 0; i < labelCount; i++) {
            // NOTE(review): probabilities[i] belongs to model.getLabels()[i]; pairing it with
            // outcomes.get(i) assumes the trainer's label order matches the outcomes list
            // (e.g. model.getLabels()[i] == i) — confirm against the trainer.
            outcomeSet.add(decisionFactory.createDecision(outcomes.get(i), probabilities[i]));
        }
        return new ArrayList<Decision<T>>(outcomeSet);
    }

    /**
     * Converts feature results into liblinear feature nodes, appending them to
     * {@code featureList}. Only features present in {@link #featureIndexMap} are added.
     * <ul>
     * <li>List-valued results contribute one node per weighted outcome, keyed by
     * {@code trainingName + "|" + trainingOutcome} and valued by the outcome's weight.</li>
     * <li>Double-valued results contribute one node valued by the double.</li>
     * <li>Any other result contributes one node with value 1.0.</li>
     * </ul>
     *
     * @param featureResults
     *            the feature results to convert
     * @param featureList
     *            output list receiving the generated nodes
     */
    void prepareData(List<FeatureResult<?>> featureResults, List<Feature> featureList) {
        for (FeatureResult<?> featureResult : featureResults) {
            if (featureResult.getOutcome() instanceof List) {
                // Unchecked: list-valued features are assumed to carry weighted string outcomes.
                @SuppressWarnings("unchecked")
                FeatureResult<List<WeightedOutcome<String>>> stringCollectionResult = (FeatureResult<List<WeightedOutcome<String>>>) featureResult;
                for (WeightedOutcome<String> stringOutcome : stringCollectionResult.getOutcome()) {
                    Integer index = featureIndexMap.get(featureResult.getTrainingName() + "|"
                            + featureResult.getTrainingOutcome(stringOutcome.getOutcome()));
                    if (index != null) {
                        double value = stringOutcome.getWeight();
                        featureList.add(new FeatureNode(index.intValue(), value));
                    }
                }
            } else {
                double value = 1.0;

                if (featureResult.getOutcome() instanceof Double) {
                    @SuppressWarnings("unchecked")
                    FeatureResult<Double> doubleResult = (FeatureResult<Double>) featureResult;
                    value = doubleResult.getOutcome().doubleValue();
                }
                Integer index = featureIndexMap.get(featureResult.getTrainingName());
                if (index != null) {
                    // we only need to bother adding features which existed in the training set
                    featureList.add(new FeatureNode(index.intValue(), value));
                }
            }
        }
    }

    /** @return the factory used to build decisions, or {@code null} if not yet set */
    public DecisionFactory<T> getDecisionFactory() {
        return decisionFactory;
    }

    /** @param decisionFactory the factory used to build decisions; required before {@link #decide(List)} */
    public void setDecisionFactory(DecisionFactory<T> decisionFactory) {
        this.decisionFactory = decisionFactory;
    }

    /**
     * Returns the default scoring strategy, lazily creating a
     * {@link GeometricMeanScoringStrategy} on first use.
     */
    @Override
    public ScoringStrategy<ClassificationSolution<T>> getDefaultScoringStrategy() {
        if (scoringStrategy == null)
            scoringStrategy = new GeometricMeanScoringStrategy<T>();
        return scoringStrategy;
    }
}