adams.opt.optimise.genetic.fitnessfunctions.AttributeSelection.java Source code

Java tutorial

Introduction

Here is the source code for adams.opt.optimise.genetic.fitnessfunctions.AttributeSelection.java

Source

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * AttributeSelection.java
 * Copyright (C) 2012-2015 University of Waikato, Hamilton, New Zealand
 */

package adams.opt.optimise.genetic.fitnessfunctions;

import adams.core.option.OptionUtils;
import adams.opt.optimise.OptData;
import adams.opt.optimise.OptVar;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.unsupervised.attribute.Remove;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.Writer;
import java.util.Random;
import java.util.logging.Level;

/**
 * Perform attribute selection using WEKA classification.
 * @author dale
 * @version $Revision$
 */
public class AttributeSelection extends AbstractWEKAFitnessFunction {

    /**suid.  */
    private static final long serialVersionUID = 1967190416117903831L;

    @Override
    public String globalInfo() {
        return "Attribute selection";
    }

    public OptData getDataDef() {
        init();
        OptData odd = new OptData();
        for (int i = 0; i < m_Instances.numAttributes() - 1; i++) {
            odd.add(new OptVar("" + i, 0, 1, true));
        }
        return (odd);
    }

    protected int[] getWeights(OptData opd) {
        int[] weights = new int[getInstances().numAttributes() - 1];
        int cnt = 0;
        for (int i = 0; i < getInstances().numAttributes(); i++) {
            if (i == getInstances().classIndex()) {
                continue;
            }
            weights[cnt] = opd.get("" + cnt).intValue();
            cnt++;
        }
        return (weights);
    }

    public double evaluate(OptData opd) {
        init();
        int cnt = 0;
        int[] weights = getWeights(opd);
        Instances newInstances = new Instances(getInstances());
        for (int i = 0; i < getInstances().numInstances(); i++) {
            Instance in = newInstances.instance(i);
            cnt = 0;
            for (int a = 0; a < getInstances().numAttributes(); a++) {
                if (a == getInstances().classIndex())
                    continue;
                if (weights[cnt++] == 0) {
                    in.setValue(a, 0);
                } else {
                    in.setValue(a, in.value(a));
                }
            }
        }
        Classifier newClassifier = null;

        try {
            newClassifier = (Classifier) OptionUtils.shallowCopy(getClassifier());
            // evaluate classifier on data
            Evaluation evaluation = new Evaluation(newInstances);
            evaluation.crossValidateModel(newClassifier, newInstances, getFolds(),
                    new Random(getCrossValidationSeed()));

            // obtain measure
            double measure = 0;
            if (getMeasure() == Measure.ACC)
                measure = evaluation.pctCorrect();
            else if (getMeasure() == Measure.CC)
                measure = evaluation.correlationCoefficient();
            else if (getMeasure() == Measure.MAE)
                measure = evaluation.meanAbsoluteError();
            else if (getMeasure() == Measure.RAE)
                measure = evaluation.relativeAbsoluteError();
            else if (getMeasure() == Measure.RMSE)
                measure = evaluation.rootMeanSquaredError();
            else if (getMeasure() == Measure.RRSE)
                measure = evaluation.rootRelativeSquaredError();
            else
                throw new IllegalStateException("Unhandled measure '" + getMeasure() + "'!");
            measure = getMeasure().adjust(measure);

            return (measure);
            // process fitness

        } catch (Exception e) {
            getLogger().log(Level.SEVERE, "Error evaluating", e);
        }

        return 0;
    }

    /**
     * Generates a range string of attributes to keep (= one has to use
     * the inverse matching sense with the Remove filter).
     *
     * @return      the range of attributes to keep
     */
    public String getRemoveAsString(int[] m_weights) {
        String ret = "";
        int pos = 0;
        int last = -1;
        boolean thefirst = true;
        for (int a = 0; a < getInstances().numAttributes() - 1; a++) {
            if (m_weights[a] == 0 && a != getInstances().classIndex()) {
                if (last == -1)
                    continue;
                if (thefirst)
                    thefirst = false;
                else
                    ret = (new StringBuilder(String.valueOf(ret))).append(",").toString();
                if (pos - last > 1)
                    ret = (new StringBuilder(String.valueOf(ret))).append(last + 1).append("-").append(pos + 1)
                            .toString();
                else if (pos - last == 1)
                    ret = (new StringBuilder(String.valueOf(ret))).append(last + 1).append(",").append(pos + 1)
                            .toString();
                else
                    ret = (new StringBuilder(String.valueOf(ret))).append(last + 1).toString();
                last = -1;
            }
            if (m_weights[a] != 0 || a == getInstances().classIndex()) {
                if (last == -1)
                    last = a;
                pos = a;
            }
        }

        if (last != -1) {
            if (!thefirst)
                ret = (new StringBuilder(String.valueOf(ret))).append(",").toString();
            if (pos - last > 1)
                ret = (new StringBuilder(String.valueOf(ret))).append(last + 1).append("-").append(pos + 1)
                        .toString();
            else if (pos - last == 1)
                ret = (new StringBuilder(String.valueOf(ret))).append(last + 1).append(",").append(pos + 1)
                        .toString();
            else
                ret = (new StringBuilder(String.valueOf(ret))).append(last + 1).toString();
        }
        return ret;

    }

    /**
     * Callback for best measure so far
     */
    @Override
    public void newBest(double val, OptData opd) {
        int cnt = 0;
        int[] weights = getWeights(opd);
        Instances newInstances = new Instances(getInstances());
        for (int i = 0; i < getInstances().numInstances(); i++) {
            Instance in = newInstances.instance(i);
            cnt = 0;
            for (int a = 0; a < getInstances().numAttributes(); a++) {
                if (a == getInstances().classIndex())
                    continue;
                if (weights[cnt++] == 0) {
                    in.setValue(a, 0);
                } else {
                    in.setValue(a, in.value(a));
                }
            }
        }
        try {
            File file = new File(getOutputDirectory().getAbsolutePath() + File.separator
                    + Double.toString(getMeasure().adjust(val)) + ".arff");
            file.createNewFile();
            Writer writer = new BufferedWriter(new FileWriter(file));
            Instances header = new Instances(newInstances, 0);

            // remove filter setup
            Remove remove = new Remove();
            remove.setAttributeIndices(getRemoveAsString(weights));
            remove.setInvertSelection(true);

            header.setRelationName(OptionUtils.getCommandLine(remove));

            writer.write(header.toString());
            writer.write("\n");
            for (int i = 0; i < newInstances.numInstances(); i++) {
                writer.write(newInstances.instance(i).toString());
                writer.write("\n");
            }
            writer.flush();
            writer.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}