etc.aloe.cscw2013.SMOFeatureWeighting.java Source code

Java tutorial

Introduction

Here is the source code for the Java class etc.aloe.cscw2013.SMOFeatureWeighting (file SMOFeatureWeighting.java).

Source

/*
 * This file is part of ALOE.
 *
 * ALOE is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
    
 * ALOE is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
    
 * You should have received a copy of the GNU General Public License
 * along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
 */
package etc.aloe.cscw2013;

import etc.aloe.data.ExampleSet;
import etc.aloe.data.Model;
import etc.aloe.processes.FeatureWeighting;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import weka.classifiers.Classifier;
import weka.classifiers.functions.SMO;
import weka.classifiers.meta.CostSensitiveClassifier;
import weka.core.Instances;

/**
 * Extracts top features and feature weights from a linear support vector
 * machine (SMO) classifier.
 *
 * Also works with a CostSensitiveClassifier wrapping an SMO.
 *
 * Only the weight vector of the first pairwise sub-classifier (class index 0
 * versus class index 1) is read, so the results describe a binary problem.
 * The SMO must expose an explicit weight vector (i.e. use a linear kernel);
 * otherwise an IllegalArgumentException is thrown.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class SMOFeatureWeighting implements FeatureWeighting {

    /**
     * Orders weight entries by descending magnitude (absolute value), so the
     * most influential features come first regardless of sign.
     */
    private static final Comparator<Map.Entry<String, Double>> BY_MAGNITUDE_DESCENDING =
            new Comparator<Map.Entry<String, Double>>() {
                @Override
                public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                    // Math.abs gives the same ordering as comparing squares, but
                    // cannot overflow to Infinity (and tie) for very large weights.
                    return Double.compare(Math.abs(o2.getValue()), Math.abs(o1.getValue()));
                }
            };

    /**
     * Returns the names of the {@code topN} features with the largest absolute
     * weight in the trained SMO.
     *
     * Ties keep the alphabetical order produced by
     * {@link #getFeatureWeights(ExampleSet, Model)}, because
     * {@link Collections#sort(List, Comparator)} is stable.
     *
     * @param trainingExamples examples whose attribute list is used to name the
     *        weights (presumably the set the model was trained on)
     * @param model a WekaModel wrapping an SMO or a CostSensitiveClassifier(SMO)
     * @param topN maximum number of names to return; non-positive values yield
     *        an empty list
     * @return feature names ordered from largest to smallest absolute weight
     * @throws IllegalArgumentException if no usable linear SMO is found
     */
    @Override
    public List<String> getTopFeatures(ExampleSet trainingExamples, Model model, int topN) {
        List<Map.Entry<String, Double>> weights = getFeatureWeights(trainingExamples, model);
        Collections.sort(weights, BY_MAGNITUDE_DESCENDING);

        int count = Math.max(0, Math.min(topN, weights.size()));
        List<String> result = new ArrayList<String>(count);
        for (int i = 0; i < count; i++) {
            result.add(weights.get(i).getKey());
        }
        return result;
    }

    /**
     * Reads the linear weight vector of the SMO inside {@code model} and pairs
     * each weight with the name of the attribute it belongs to.
     *
     * Only attributes present in the SMO's sparse weight vector are reported.
     *
     * @param trainingExamples examples whose attribute list is used to map
     *        weight indices to attribute names
     * @param model a WekaModel wrapping an SMO or a CostSensitiveClassifier(SMO)
     * @return (attribute name, weight) entries sorted alphabetically by name;
     *         the list is freshly allocated and may be reordered by the caller
     * @throws IllegalArgumentException if the classifier is not an SMO, was not
     *         trained on at least two classes, or has no explicit weight vector
     */
    @Override
    public List<Entry<String, Double>> getFeatureWeights(ExampleSet trainingExamples, Model model) {
        WekaModel wekaModel = (WekaModel) model;
        Instances dataFormat = trainingExamples.getInstances();
        SMO smo = getSMO(wekaModel.getClassifier());

        // Both arrays are indexed [classA][classB]; only the 0-vs-1 pairwise
        // sub-classifier is read (the sole one for a binary problem).
        double[][][] weightsByPair = smo.sparseWeights();
        int[][][] indicesByPair = smo.sparseIndices();
        if (weightsByPair == null || indicesByPair == null
                || weightsByPair.length < 2 || weightsByPair[0].length < 2
                || indicesByPair.length < 2 || indicesByPair[0].length < 2) {
            throw new IllegalArgumentException("SMO must be trained on at least two classes");
        }

        double[] sparseWeights = weightsByPair[0][1];
        int[] sparseIndices = indicesByPair[0][1];
        if (sparseWeights == null || sparseIndices == null) {
            // Without this check the loop below would fail with a bare NPE.
            throw new IllegalArgumentException(
                    "SMO has no explicit weight vector; a linear kernel is required");
        }

        // A TreeMap keyed by attribute name yields its entries in alphabetical order.
        Map<String, Double> weights = new TreeMap<String, Double>();
        for (int i = 0; i < sparseWeights.length; i++) {
            // NOTE(review): indices are assumed to refer to attributes of
            // trainingExamples. Weka's SMO may transform data internally (e.g.
            // NominalToBinary when nominal attributes exist), in which case the
            // indices might not line up -- confirm for such data sets.
            String name = dataFormat.attribute(sparseIndices[i]).name();
            weights.put(name, sparseWeights[i]);
        }

        return new ArrayList<Map.Entry<String, Double>>(weights.entrySet());
    }

    /**
     * Given a classifier, returns it as an SMO or returns the SMO wrapped by a
     * CostSensitiveClassifier.
     *
     * @param classifier an SMO, or a CostSensitiveClassifier wrapping an SMO
     * @return the SMO instance
     * @throws IllegalArgumentException if no SMO is found in either position
     */
    private static SMO getSMO(Classifier classifier) {
        Classifier candidate = classifier;
        if (candidate instanceof CostSensitiveClassifier) {
            candidate = ((CostSensitiveClassifier) candidate).getClassifier();
        }

        if (candidate instanceof SMO) {
            return (SMO) candidate;
        }

        String actual = (candidate == null) ? "null" : candidate.getClass().getName();
        throw new IllegalArgumentException(
                "Classifier was neither SMO or CostSensitiveClassifier(SMO): found " + actual);
    }
}