Example usage for weka.core Instances setClassIndex

List of usage examples for weka.core Instances setClassIndex

Introduction

On this page you can find example usages of weka.core.Instances.setClassIndex.

Prototype

public void setClassIndex(int classIndex) 

Source Link

Document

Sets the class index of the set.

Usage

From source file:adams.flow.transformer.WekaPredictionsToInstances.java

License:Open Source License

/**
 * Executes the flow item: converts the predictions of the incoming
 * evaluation into a dataset and stores that dataset as the output token.
 * <p>
 * The generated dataset always starts with the actual (index 0) and the
 * predicted (index 1) class value; error, probability, distribution and
 * weight columns follow, depending on the configuration.
 *
 * @return      null if everything is fine, otherwise error message
 */
@Override
protected String doExecute() {
    // unpack the evaluation and, if present, the original indices and test data
    Object payload = m_InputToken.getPayload();
    Evaluation eval;
    int[] indices = null;
    Instances testData = null;
    if (payload instanceof WekaEvaluationContainer) {
        WekaEvaluationContainer cont = (WekaEvaluationContainer) payload;
        eval = (Evaluation) cont.getValue(WekaEvaluationContainer.VALUE_EVALUATION);
        indices = (int[]) cont.getValue(WekaEvaluationContainer.VALUE_ORIGINALINDICES);
        testData = (Instances) cont.getValue(WekaEvaluationContainer.VALUE_TESTDATA);
    } else {
        eval = (Evaluation) payload;
    }

    Attribute classAtt = eval.getHeader().classAttribute();
    boolean nominal = classAtt.isNominal();
    ArrayList<Prediction> predictions = eval.predictions();

    // NB: missing predictions are only logged; no output token is generated
    // and the method still reports success (null)
    if (predictions == null) {
        getLogger().severe("No predictions available from Evaluation object!");
        return null;
    }

    // header: actual and predicted class value
    ArrayList<Attribute> atts = new ArrayList<>();
    boolean labelIndex = nominal && m_AddLabelIndex;
    for (String name : new String[] { "Actual", "Predicted" }) {
        if (labelIndex) {
            // prefix each label with its 1-based index
            ArrayList<String> labels = new ArrayList<>();
            for (int v = 0; v < classAtt.numValues(); v++)
                labels.add((v + 1) + ":" + classAtt.value(v));
            atts.add(new Attribute(m_MeasuresPrefix + name, labels));
        } else {
            atts.add(classAtt.copy(m_MeasuresPrefix + name));
        }
    }

    // header: error (nominal class: n/y flag, otherwise numeric)
    int indexErr = -1;
    if (m_ShowError) {
        indexErr = atts.size();
        if (nominal) {
            ArrayList<String> flags = new ArrayList<>();
            flags.add("n");
            flags.add("y");
            atts.add(new Attribute(m_MeasuresPrefix + "Error", flags));
        } else {
            atts.add(new Attribute(m_MeasuresPrefix + "Error"));
        }
    }

    // header: probability (nominal class only)
    boolean withProbability = m_ShowProbability && nominal;
    int indexProb = -1;
    if (withProbability) {
        indexProb = atts.size();
        atts.add(new Attribute(m_MeasuresPrefix + "Probability"));
    }

    // header: class distribution, one column per label (nominal class only)
    boolean withDistribution = m_ShowDistribution && nominal;
    int indexDist = -1;
    if (withDistribution) {
        indexDist = atts.size();
        for (int v = 0; v < classAtt.numValues(); v++)
            atts.add(new Attribute(m_MeasuresPrefix + "Distribution (" + classAtt.value(v) + ")"));
    }

    // header: weight
    int indexWeight = -1;
    if (m_ShowWeight) {
        indexWeight = atts.size();
        atts.add(new Attribute(m_MeasuresPrefix + "Weight"));
    }

    Instances data = new Instances("Predictions", atts, predictions.size());
    data.setClassIndex(1); // predicted

    // bring the predictions into the order given by the original indices, if requested
    if ((indices != null) && m_UseOriginalIndices)
        predictions = CrossValidationHelper.alignPredictions(predictions, indices);

    // one row per prediction
    for (Prediction pred : predictions) {
        double[] vals = new double[data.numAttributes()];
        double actual = pred.actual();
        double predicted = pred.predicted();
        vals[0] = actual;
        vals[1] = predicted;

        if (m_ShowError) {
            double error;
            if (nominal)
                error = (actual != predicted) ? 1.0 : 0.0;
            else if (m_UseAbsoluteError)
                error = Math.abs(actual - predicted);
            else
                error = actual - predicted;
            vals[indexErr] = error;
        }

        if (withProbability)
            vals[indexProb] = StatUtils.max(((NominalPrediction) pred).distribution());

        if (withDistribution) {
            for (int v = 0; v < classAtt.numValues(); v++)
                vals[indexDist + v] = ((NominalPrediction) pred).distribution()[v];
        }

        if (m_ShowWeight)
            vals[indexWeight] = pred.weight();

        data.add(new DenseInstance(1.0, vals));
    }

    // append the selected test data attributes, if any
    if ((testData != null) && !m_TestAttributes.isEmpty()) {
        testData = filterTestData(testData);
        if (testData != null)
            data = Instances.mergeInstances(data, testData);
    }

    m_OutputToken = new Token(data);
    return null;
}

From source file:adams.gui.menu.CostCurve.java

License:Open Source License

/**
 * Launches the functionality of the menu item: loads a dataset and shows
 * it as a plot with connected points in a new child frame.
 * <p>
 * The file is taken from the first parameter if one was supplied,
 * otherwise the user is prompted with the file chooser.
 */
@Override
public void launch() {
    // determine the file to load
    File file;
    if (m_Parameters.length > 0) {
        file = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        m_FileChooser.setSelectedFile(file);
    } else {
        if (m_FileChooser.showOpenDialog(null) != JFileChooser.APPROVE_OPTION)
            return;
        file = m_FileChooser.getSelectedFile();
    }

    // load the data, using the loader of the file chooser
    Instances curve;
    try {
        curve = m_FileChooser.getLoader().getDataSet();
    } catch (Exception ex) {
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + file + "':\n" + adams.core.Utils.throwableToString(ex));
        return;
    }
    // last attribute acts as class
    curve.setClassIndex(curve.numAttributes() - 1);

    // set up the plot
    ThresholdVisualizePanel panel = new ThresholdVisualizePanel();
    PlotData2D plot = new PlotData2D(curve);
    plot.setPlotName(curve.relationName());
    plot.m_displayAllPoints = true;

    // flag every point but the first one as connected
    boolean[] connected = new boolean[curve.numInstances()];
    for (int idx = connected.length - 1; idx >= 1; idx--)
        connected[idx] = true;

    try {
        plot.setConnectPoints(connected);
        panel.addPlot(plot);
    } catch (Exception ex) {
        GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(ex));
        return;
    }

    // display the panel
    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + file);
}

From source file:adams.gui.menu.InstancesPlot.java

License:Open Source License

/**
 * Launches the functionality of the menu item: loads a dataset and
 * displays it as the master plot of a visualization panel in a new child
 * frame.
 * <p>
 * The file is taken from the first parameter if one was supplied,
 * otherwise the user is prompted with the file chooser.
 */
@Override
public void launch() {
    File file;
    AbstractFileLoader loader;
    if (m_Parameters.length == 0) {
        // choose file
        int retVal = m_FileChooser.showOpenDialog(getOwner());
        if (retVal != JFileChooser.APPROVE_OPTION)
            return;
        file = m_FileChooser.getSelectedFile();
        loader = m_FileChooser.getLoader();
    } else {
        file = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        loader = ConverterUtils.getLoaderForFile(file);
    }

    // build plot
    VisualizePanel panel = new VisualizePanel();
    // routine progress message, hence INFO rather than SEVERE
    getLogger().info("Loading instances from " + file);
    try {
        loader.setFile(file);
        Instances data = loader.getDataSet();
        // last attribute acts as class
        data.setClassIndex(data.numAttributes() - 1);
        PlotData2D masterPlot = new PlotData2D(data);
        masterPlot.setPlotName("Master plot");
        panel.setMasterPlot(masterPlot);
    } catch (Exception e) {
        // any failure in here (loading or setting up the plot) is reported as load error
        getLogger().log(Level.SEVERE, "Failed to load: " + file, e);
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + file + "':\n" + Utils.throwableToString(e));
        return;
    }

    // create frame
    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + file);
}

From source file:adams.gui.menu.MarginCurve.java

License:Open Source License

/**
 * Launches the functionality of the menu item: loads a dataset and shows
 * it as a plot with connected points in a new child frame.
 * <p>
 * The file is taken from the first parameter if one was supplied,
 * otherwise the user is prompted with the file chooser.
 */
@Override
public void launch() {
    // determine the file to load
    File file;
    if (m_Parameters.length > 0) {
        file = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        m_FileChooser.setSelectedFile(file);
    } else {
        if (m_FileChooser.showOpenDialog(null) != JFileChooser.APPROVE_OPTION)
            return;
        file = m_FileChooser.getSelectedFile();
    }

    // load the data, using the loader of the file chooser
    Instances curve;
    try {
        curve = m_FileChooser.getLoader().getDataSet();
    } catch (Exception ex) {
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + file + "':\n" + adams.core.Utils.throwableToString(ex));
        return;
    }
    // last attribute acts as class
    curve.setClassIndex(curve.numAttributes() - 1);

    // set up the plot
    VisualizePanel panel = new VisualizePanel();
    PlotData2D plot = new PlotData2D(curve);
    plot.m_displayAllPoints = true;

    // flag every point but the first one as connected
    boolean[] connected = new boolean[curve.numInstances()];
    for (int idx = connected.length - 1; idx >= 1; idx--)
        connected[idx] = true;

    try {
        plot.setConnectPoints(connected);
        panel.addPlot(plot);
    } catch (Exception ex) {
        GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(ex));
        return;
    }

    // display the panel
    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + file);
}

From source file:adams.gui.menu.ROC.java

License:Open Source License

/**
 * Launches the functionality of the menu item: loads a dataset, computes
 * the area under the ROC curve from it and shows the data as a plot with
 * connected points in a new child frame.
 * <p>
 * The file is taken from the first parameter if one was supplied,
 * otherwise the user is prompted with the file chooser.
 */
@Override
public void launch() {
    // determine the file to load
    File file;
    if (m_Parameters.length > 0) {
        file = new PlaceholderFile(m_Parameters[0]).getAbsoluteFile();
        m_FileChooser.setSelectedFile(file);
    } else {
        if (m_FileChooser.showOpenDialog(null) != JFileChooser.APPROVE_OPTION)
            return;
        file = m_FileChooser.getSelectedFile();
    }

    // load the data, using the loader of the file chooser
    Instances curve;
    try {
        curve = m_FileChooser.getLoader().getDataSet();
    } catch (Exception ex) {
        GUIHelper.showErrorMessage(getOwner(),
                "Error loading file '" + file + "':\n" + adams.core.Utils.throwableToString(ex));
        return;
    }
    // last attribute acts as class
    curve.setClassIndex(curve.numAttributes() - 1);

    // set up the panel, including the area under the curve
    ThresholdVisualizePanel panel = new ThresholdVisualizePanel();
    String area = Utils.doubleToString(ThresholdCurve.getROCArea(curve), 4);
    panel.setROCString("(Area under ROC = " + area + ")");
    panel.setName(curve.relationName());

    // set up the plot
    PlotData2D plot = new PlotData2D(curve);
    plot.setPlotName(curve.relationName());
    plot.addInstanceNumberAttribute();

    // flag every point but the first one as connected
    boolean[] connected = new boolean[curve.numInstances()];
    for (int idx = connected.length - 1; idx >= 1; idx--)
        connected[idx] = true;

    try {
        plot.setConnectPoints(connected);
        panel.addPlot(plot);
    } catch (Exception ex) {
        GUIHelper.showErrorMessage(getOwner(), "Error adding plot:\n" + adams.core.Utils.throwableToString(ex));
        return;
    }

    // display the panel
    ChildFrame frame = createChildFrame(panel, GUIHelper.getDefaultDialogDimension());
    frame.setTitle(frame.getTitle() + " - " + file);
}

From source file:adams.gui.visualization.instance.LoadDatasetDialog.java

License:Open Source License

/**
 * Returns the full dataset, can be null if none loaded.
 *
 * @return      the full dataset/*w w w .ja v a  2  s.  c o  m*/
 */
public Instances getDataset() {
    int index;
    Instances result;

    result = new Instances(m_Instances);
    if (m_ComboBoxSorting.getSelectedIndex() > 0)
        result.sort(m_ComboBoxSorting.getSelectedIndex() - 1);

    index = m_ComboBoxClass.getSelectedIndex();
    if (index > -1)
        index--;
    result.setClassIndex(index);

    return result;
}

From source file:adams.ml.data.InstancesView.java

License:Open Source License

/**
 * Returns a spreadsheet containing only the input columns, not class
 * columns./*from ww  w . ja  va 2  s .c om*/
 *
 * @return      the input features, null if data conists only of class columns
 */
@Override
public SpreadSheet getInputs() {
    Instances data;

    if (m_Data.classIndex() == -1)
        return this;

    data = new Instances(m_Data);
    data.setClassIndex(-1);
    data.deleteAttributeAt(m_Data.classIndex());

    return new InstancesView(data);
}

From source file:adams.ml.data.InstancesView.java

License:Open Source License

/**
 * Returns a spreadsheet containing only output columns, i.e., the class
 * columns.
 *
 * @return      the output features, null if data has no class columns
 * @throws IllegalStateException   if the Remove filter cannot be applied
 */
@Override
public SpreadSheet getOutputs() {
    final int classIdx = m_Data.classIndex();

    // no class attribute: no outputs
    if (classIdx == -1)
        return null;

    // work on a copy without class attribute set
    Instances outputs = new Instances(m_Data);
    outputs.setClassIndex(-1);

    // inverted selection: everything but the former class attribute is removed
    Remove keepClassOnly = new Remove();
    keepClassOnly.setAttributeIndicesArray(new int[] { classIdx });
    keepClassOnly.setInvertSelection(true);

    try {
        keepClassOnly.setInputFormat(outputs);
        return new InstancesView(Filter.useFilter(outputs, keepClassOnly));
    } catch (Exception ex) {
        throw new IllegalStateException("Failed to apply Remove filter!", ex);
    }
}

From source file:adams.ml.data.WekaConverter.java

License:Open Source License

/**
 * Converts an ADAMS Dataset to Weka Instances.
 * Only sets the class attribute if the dataset has exactly one defined.
 *
 * @param data   the data to convert/*  w  ww.  java2s . c  o m*/
 * @return      the generated data
 * @throws Exception   if conversion fails
 */
public static Instances toInstances(Dataset data) throws Exception {
    Instances result;
    SpreadSheetToWekaInstances conv;
    String msg;
    int[] classes;

    conv = new SpreadSheetToWekaInstances();
    conv.setInput(data);
    msg = conv.convert();
    if (msg != null)
        throw new Exception("Failed to convert Dataset to Instances: " + msg);

    result = (Instances) conv.getOutput();
    conv.cleanUp();

    classes = data.getClassAttributeIndices();
    if (classes.length == 1)
        result.setClassIndex(classes[0]);

    return result;
}

From source file:adams.opt.cso.HermioneSimple.java

License:Open Source License

/**
 * For testing only./*from   w  ww  . j  a v a  2 s . co  m*/
 *
 * @param args   the dataset to use
 * @throws Exception   if something fails
 */
public static void main(String[] args) throws Exception {
    Environment.setEnvironmentClass(Environment.class);
    Instances data = DataSource.read(args[0]);
    if (data.classIndex() == -1)
        data.setClassIndex(data.numAttributes() - 1);
    MaxIterationsWithoutImprovement stopping = new MaxIterationsWithoutImprovement();
    stopping.setNumIterations(2);
    stopping.setMinimumImprovement(0.001);
    stopping.setLoggingLevel(LoggingLevel.INFO);
    HermioneSimple simple = new HermioneSimple();
    simple.setEvalParallel(true);
    simple.setMeasure(Measure.CC);
    simple.setStopping(stopping);
    simple.setLoggingLevel(LoggingLevel.INFO);
    simple.setInstances(data);
    /*
    simple.setClassifier(new GPD());
    simple.setHandlers(new AbstractCatSwarmOptimizationDiscoveryHandler[]{
      new GPDGamma(),
      new GPDNoise(),
    });
    */
    LinearRegressionJ cls = new LinearRegressionJ();
    cls.setEliminateColinearAttributes(false);
    cls.setAttributeSelectionMethod(
            new SelectedTag(LinearRegressionJ.SELECTION_NONE, LinearRegressionJ.TAGS_SELECTION));
    simple.setClassifier(new LinearRegressionJ());
    GenericDouble ridge = new GenericDouble();
    ridge.setClassname(new BaseClassname(cls.getClass()));
    ridge.setProperty("ridge");
    ridge.setMinimum(1e-8);
    ridge.setMaximum(1);
    simple.setHandlers(new AbstractCatSwarmOptimizationDiscoveryHandler[] { ridge, });
    DoubleMatrix best = simple.run();
    System.out.println(best);
}