Example usage for weka.core Instances classAttribute

List of usage examples for weka.core Instances classAttribute

Introduction

In this page you can find the example usage for weka.core Instances classAttribute.

Prototype


public Attribute classAttribute() 

Source Link

Document

Returns the class attribute.

Usage

From source file:de.unidue.langtech.grading.tc.ClusteringTask.java

License:Open Source License

/**
 * Clusters the training data and prints per-cluster contents and quality
 * statistics (size, purity, RMSE) to stdout.
 *
 * @param aContext the task context providing access to the stored training data
 * @throws Exception if the ARFF data cannot be read or the clusterer fails
 */
@Override
public void execute(TaskContext aContext) throws Exception {
    if (learningMode.equals(Constants.LM_MULTI_LABEL)) {
        throw new IllegalArgumentException("Cannot use multi-label setup in clustering.");
    }
    boolean multiLabel = false;

    // Location of the ARFF training data produced by an upstream task.
    File arffFileTrain = new File(
            aContext.getStorageLocation(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY).getPath(),
            TRAINING_DATA_FILENAME);

    Instances trainData = TaskUtils.getInstances(arffFileTrain, multiLabel);

    // get number of outcomes
    List<String> trainOutcomeValues = TaskUtils.getClassLabels(trainData, multiLabel);

    // First argument is the clusterer class name; the rest are its options.
    Clusterer clusterer = AbstractClusterer.forName(clusteringArguments.get(0),
            clusteringArguments.subList(1, clusteringArguments.size()).toArray(new String[0]));

    // Keep a copy with outcome/ID attributes intact for the label lookup below.
    Instances copyTrainData = new Instances(trainData);
    trainData = WekaUtils.removeOutcomeId(trainData, multiLabel);

    // generate data for clusterer (w/o class attribute)
    Remove filter = new Remove();
    filter.setAttributeIndices("" + (trainData.classIndex() + 1));
    filter.setInputFormat(trainData);
    Instances clusterTrainData = Filter.useFilter(trainData, filter);

    clusterer.buildClusterer(clusterTrainData);

    // get a mapping from clusterIDs to instance offsets in the ARFF
    Map<Integer, Set<Integer>> clusterMap = getClusterMap(clusterTrainData, clusterer);

    Map<String, String> instanceId2TextMap = getInstanceId2TextMap(aContext);

    ConditionalFrequencyDistribution<Integer, String> clusterAssignments = new ConditionalFrequencyDistribution<Integer, String>();
    for (Integer clusterId : clusterMap.keySet()) {
        System.out.println("CLUSTER: " + clusterId);
        for (Integer offset : clusterMap.get(clusterId)) {

            // Map the instance's nominal class value back to its outcome label.
            Instance instance = copyTrainData.get(offset);
            int classOffset = (int) instance.value(copyTrainData.classAttribute());
            String label = trainOutcomeValues.get(classOffset);

            clusterAssignments.addSample(clusterId, label);

            String instanceId = instance
                    .stringValue(copyTrainData.attribute(AddIdFeatureExtractor.ID_FEATURE_NAME).index());
            System.out.println(label + "\t" + instanceId2TextMap.get(instanceId));
        }
        System.out.println();
    }

    // Per-cluster summary: size, purity (majority-label fraction) and RMSE.
    System.out.println("ID\tSIZE\tPURITY\tRMSE");
    for (Integer clusterId : clusterMap.keySet()) {
        FrequencyDistribution<String> fd = clusterAssignments.getFrequencyDistribution(clusterId);
        double purity = (double) fd.getCount(fd.getSampleWithMaxFreq()) / fd.getN();
        String purityString = String.format("%.2f", purity);
        double rmse = getRMSE(fd, trainOutcomeValues);
        String rmseString = String.format("%.2f", rmse);
        System.out.println(
                clusterId + "\t" + clusterMap.get(clusterId).size() + "\t" + purityString + "\t" + rmseString);
    }
    System.out.println();
}

From source file:de.unidue.langtech.grading.tc.ClusterTrainTask.java

License:Open Source License

/**
 * Counts, for every cluster, how often each gold-standard label occurs among
 * the instances assigned to that cluster.
 *
 * @param clusterMap    mapping from cluster ID to the instance offsets assigned to it
 * @param data          the instances the offsets refer to (class attribute must be set)
 * @param outcomeValues ordered outcome labels; index equals the nominal class value
 * @return a conditional frequency distribution of labels per cluster ID
 */
private ConditionalFrequencyDistribution<Integer, String> getClusterCfd(Map<Integer, Set<Integer>> clusterMap,
        Instances data, List<String> outcomeValues) {
    ConditionalFrequencyDistribution<Integer, String> clusterAssignments = new ConditionalFrequencyDistribution<Integer, String>();

    for (Integer clusterId : clusterMap.keySet()) {
        for (Integer offset : clusterMap.get(clusterId)) {

            // Map the instance's nominal class value back to its outcome label.
            Instance instance = data.get(offset);
            int classOffset = (int) instance.value(data.classAttribute());
            String label = outcomeValues.get(classOffset);

            clusterAssignments.addSample(clusterId, label);
        }
    }

    return clusterAssignments;
}

From source file:de.uni_potsdam.hpi.bpt.promnicat.analysisModules.clustering.ProcessInstances.java

License:Open Source License

/**
 * Method for testing this class. Exercises construction from scratch and
 * from an ARFF file, class-attribute accessors, weights, attribute
 * insertion/deletion, subsets, cross-validation folds, randomization and
 * sorting, printing results to stdout.
 *
 * @param argv
 *            should contain one element: the name of an ARFF file
 */
// @ requires argv != null;
// @ requires argv.length == 1;
// @ requires argv[0] != null;
public static void test(String[] argv) {

    ProcessInstances instances, secondInstances, train, test, empty;
    Random random = new Random(2);
    Reader reader;
    int start, num;
    FastVector testAtts, testVals;
    int i, j;

    try {
        if (argv.length > 1) {
            throw (new Exception("Usage: ProcessInstances [<filename>]"));
        }

        // Creating set of instances from scratch: one nominal and one
        // numeric attribute, three all-missing instances, class index 0.
        testVals = new FastVector(2);
        testVals.addElement("first_value");
        testVals.addElement("second_value");
        testAtts = new FastVector(2);
        testAtts.addElement(new Attribute("nominal_attribute", testVals));
        testAtts.addElement(new Attribute("numeric_attribute"));
        instances = new ProcessInstances("test_set", testAtts, new FastVector(), 10);
        instances.addInstance(new ProcessInstance(instances.numAttributes()));
        instances.addInstance(new ProcessInstance(instances.numAttributes()));
        instances.addInstance(new ProcessInstance(instances.numAttributes()));
        instances.setClassIndex(0);
        System.out.println("\nSet of instances created from scratch:\n");
        System.out.println(instances);

        if (argv.length == 1) {
            String filename = argv[0];
            reader = new FileReader(filename);

            // Read first five instances and print them
            System.out.println("\nFirst five instances from file:\n");
            instances = new ProcessInstances(reader, 1);
            instances.setClassIndex(instances.numAttributes() - 1);
            i = 0;
            while ((i < 5) && (instances.readInstance(reader))) {
                i++;
            }
            System.out.println(instances);

            // Read all the instances in the file
            reader = new FileReader(filename);
            instances = new ProcessInstances(reader);

            // Make the last attribute be the class
            instances.setClassIndex(instances.numAttributes() - 1);

            // Print header and instances.
            System.out.println("\nDataset:\n");
            System.out.println(instances);
            System.out.println("\nClass index: " + instances.classIndex());
        }

        // Test basic methods based on class index.
        System.out.println("\nClass name: " + instances.classAttribute().name());
        System.out.println("\nClass index: " + instances.classIndex());
        System.out.println("\nClass is nominal: " + instances.classAttribute().isNominal());
        System.out.println("\nClass is numeric: " + instances.classAttribute().isNumeric());
        System.out.println("\nClasses:\n");
        for (i = 0; i < instances.numClasses(); i++) {
            System.out.println(instances.classAttribute().value(i));
        }
        System.out.println("\nClass values and labels of instances:\n");
        for (i = 0; i < instances.numInstances(); i++) {
            ProcessInstance inst = instances.getInstance(i);
            System.out.print(inst.classValue() + "\t");
            System.out.print(inst.toString(inst.classIndex()));
            if (instances.getInstance(i).classIsMissing()) {
                System.out.println("\tis missing");
            } else {
                System.out.println();
            }
        }

        // Create random weights.
        System.out.println("\nCreating random weights for instances.");
        for (i = 0; i < instances.numInstances(); i++) {
            instances.getInstance(i).setWeight(random.nextDouble());
        }

        // Print all instances and their weights (and the sum of weights).
        System.out.println("\nInstances and their weights:\n");
        System.out.println(instances.instancesAndWeights());
        System.out.print("\nSum of weights: ");
        System.out.println(instances.sumOfWeights());

        // Insert an attribute
        secondInstances = new ProcessInstances(instances);
        Attribute testAtt = new Attribute("Inserted");
        secondInstances.insertAttributeAt(testAtt, 0);
        System.out.println("\nSet with inserted attribute:\n");
        System.out.println(secondInstances);
        System.out.println("\nClass name: " + secondInstances.classAttribute().name());

        // Delete the attribute
        secondInstances.deleteAttributeAt(0);
        System.out.println("\nSet with attribute deleted:\n");
        System.out.println(secondInstances);
        System.out.println("\nClass name: " + secondInstances.classAttribute().name());

        // Test if headers are equal
        System.out.println("\nHeaders equal: " + instances.equalHeaders(secondInstances) + "\n");

        // Print data in internal format.
        System.out.println("\nData (internal values):\n");
        for (i = 0; i < instances.numInstances(); i++) {
            for (j = 0; j < instances.numAttributes(); j++) {
                if (instances.getInstance(i).isMissing(j)) {
                    System.out.print("? ");
                } else {
                    System.out.print(instances.getInstance(i).value(j) + " ");
                }
            }
            System.out.println();
        }

        // Just print header
        System.out.println("\nEmpty dataset:\n");
        empty = new ProcessInstances(instances, 0);
        System.out.println(empty);
        System.out.println("\nClass name: " + empty.classAttribute().name());

        // Create copy and rename an attribute and a value (if possible)
        if (empty.classAttribute().isNominal()) {
            Instances copy = new ProcessInstances(empty, 0);
            copy.renameAttribute(copy.classAttribute(), "new_name");
            copy.renameAttributeValue(copy.classAttribute(), copy.classAttribute().value(0), "new_val_name");
            System.out.println("\nDataset with names changed:\n" + copy);
            System.out.println("\nOriginal dataset:\n" + empty);
        }

        // Create and prints subset of instances.
        start = instances.numInstances() / 4;
        num = instances.numInstances() / 2;
        System.out.print("\nSubset of dataset: ");
        System.out.println(num + " instances from " + (start + 1) + ". instance");
        secondInstances = new ProcessInstances(instances, start, num);
        System.out.println("\nClass name: " + secondInstances.classAttribute().name());

        // Print all instances and their weights (and the sum of weights).
        System.out.println("\nInstances and their weights:\n");
        System.out.println(secondInstances.instancesAndWeights());
        System.out.print("\nSum of weights: ");
        System.out.println(secondInstances.sumOfWeights());

        // Create and print training and test sets for 3-fold
        // cross-validation. Stratification only applies to nominal classes.
        System.out.println("\nTrain and test folds for 3-fold CV:");
        if (instances.classAttribute().isNominal()) {
            instances.stratify(3);
        }
        for (j = 0; j < 3; j++) {
            train = instances.trainCV(3, j, new Random(1));
            test = instances.testCV(3, j);

            // Print all instances and their weights (and the sum of
            // weights).
            System.out.println("\nTrain: ");
            System.out.println("\nInstances and their weights:\n");
            System.out.println(train.instancesAndWeights());
            System.out.print("\nSum of weights: ");
            System.out.println(train.sumOfWeights());
            System.out.println("\nClass name: " + train.classAttribute().name());
            System.out.println("\nTest: ");
            System.out.println("\nInstances and their weights:\n");
            System.out.println(test.instancesAndWeights());
            System.out.print("\nSum of weights: ");
            System.out.println(test.sumOfWeights());
            System.out.println("\nClass name: " + test.classAttribute().name());
        }

        // Randomize instances and print them.
        System.out.println("\nRandomized dataset:");
        instances.randomize(random);

        // Print all instances and their weights (and the sum of weights).
        System.out.println("\nInstances and their weights:\n");
        System.out.println(instances.instancesAndWeights());
        System.out.print("\nSum of weights: ");
        System.out.println(instances.sumOfWeights());

        // Sort instances according to first attribute and
        // print them.
        System.out.print("\nInstances sorted according to first attribute:\n ");
        instances.sort(0);

        // Print all instances and their weights (and the sum of weights).
        System.out.println("\nInstances and their weights:\n");
        System.out.println(instances.instancesAndWeights());
        System.out.print("\nSum of weights: ");
        System.out.println(instances.sumOfWeights());
    } catch (Exception e) {
        e.printStackTrace();
    }
}

From source file:decisiontree.MyC45.java

/**
* Method for building an C45 tree./*from  w w  w . ja va2s. c om*/
*
* @param instances the training data
* @exception Exception if decision tree can't be built successfully
*/
private void makeTree(Instances instances) throws Exception {

    // Check if no instances have reached this node.
    if (instances.numInstances() == 0) {
        m_Attribute = null;
        m_ClassValue = Instance.missingValue();
        m_Distribution = new double[instances.numClasses()];
        return;
    }

    // Compute attribute with maximum gain ratio.
    double[] gainRatios = new double[instances.numAttributes()];
    Enumeration attrEnum = instances.enumerateAttributes();
    while (attrEnum.hasMoreElements()) {
        Attribute attr = (Attribute) attrEnum.nextElement();
        if (attr.isNominal()) {
            gainRatios[attr.index()] = computeGainRatio(instances, attr);
        } else if (attr.isNumeric()) {
            gainRatios[attr.index()] = computeGainRatio(instances, attr, computeThreshold(instances, attr));
        }
    }
    m_Attribute = instances.attribute(Utils.maxIndex(gainRatios));

    // Make leaf if gain ratio is zero. 
    // Otherwise create successors.
    if (Utils.eq(gainRatios[m_Attribute.index()], 0)) {
        m_Attribute = null;
        m_Distribution = new double[instances.numClasses()];
        Enumeration instEnum = instances.enumerateInstances();
        while (instEnum.hasMoreElements()) {
            Instance inst = (Instance) instEnum.nextElement();
            m_Distribution[(int) inst.classValue()]++;
        }
        Utils.normalize(m_Distribution);
        m_ClassValue = Utils.maxIndex(m_Distribution);
        m_ClassAttribute = instances.classAttribute();
    } else {
        Instances[] splitData = null;
        int child = 0;
        if (m_Attribute.isNominal()) {
            child = m_Attribute.numValues();
            splitData = splitData(instances, m_Attribute);
        } else if (m_Attribute.isNumeric()) {
            child = 2;
            splitData = splitData(instances, m_Attribute, computeThreshold(instances, m_Attribute));
        }
        m_Successors = new MyC45[child];
        for (int j = 0; j < child; j++) {
            m_Successors[j] = new MyC45();
            m_Successors[j].makeTree(splitData[j]);
        }
    }
}

From source file:decisiontree.MyID3.java

/**
 * Recursively builds an ID3 decision tree over the given instances.
 *
 * @param data the training data reaching this node
 */
private void makeTree(Instances data) {
    // No instances reached this node: leaf with NaN class value and an
    // all-zero class distribution.
    if (data.numInstances() == 0) {
        splitAttr = null;
        leafValue = Double.NaN;
        leafDist = new double[data.numClasses()];
        return;
    }

    // All instances share one class value: pure leaf, nothing to split.
    if (data.numDistinctValues(data.classIndex()) == 1) {
        leafValue = data.firstInstance().classValue();
        return;
    }

    // Score every attribute by information gain.
    double[] gains = new double[data.numAttributes()];
    Enumeration candidates = data.enumerateAttributes();
    while (candidates.hasMoreElements()) {
        Attribute candidate = (Attribute) candidates.nextElement();
        gains[candidate.index()] = computeInfoGain(data, candidate);
    }
    splitAttr = data.attribute(maxIndex(gains));

    if (Utils.eq(gains[splitAttr.index()], 0)) {
        // No attribute helps: become a leaf holding the normalized class
        // distribution and the majority class.
        splitAttr = null;
        leafDist = new double[data.numClasses()];
        Enumeration members = data.enumerateInstances();
        while (members.hasMoreElements()) {
            Instance member = (Instance) members.nextElement();
            leafDist[(int) member.classValue()]++;
        }
        normalize(leafDist);
        leafValue = Utils.maxIndex(leafDist);
        classAttr = data.classAttribute();
    } else {
        // Split on the winning attribute: one child per nominal value.
        Instances[] partitions = splitData(data, splitAttr);
        child = new MyID3[splitAttr.numValues()];
        for (int b = 0; b < splitAttr.numValues(); b++) {
            child[b] = new MyID3();
            child[b].makeTree(partitions[b]);
        }
    }
}

From source file:dewaweebtreeclassifier.Veranda.java

/**
 * /*  w  w  w.ja  v  a  2s  . com*/
 * @param data 
 * @throws java.lang.Exception 
 */
@Override
public void buildClassifier(Instances data) throws Exception {
    if (!data.classAttribute().isNominal())
        throw new Exception("The class attribute is not nominal.");

    if (!isAllNominalAttributes(data))
        throw new Exception("An attribute has non-nominal value.");

    if (isHaveMissingAttributes(data))
        throw new Exception("An instance has missing value(s). ID3 does not support missing values.");

    mRoot.buildClassifier(data);
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.EvaluateSimilarityFunction.java

License:Open Source License

/**
 * Projects every instance of {@code src} through the given coordinate
 * transform, producing a new dataset with one numeric attribute per output
 * dimension plus a copy of the source class attribute as the last attribute.
 *
 * @param src       the labeled source instances
 * @param transform the coordinate transform to apply to the features
 * @return a new Instances object containing the transformed data
 */
public static Instances transformInstances(final Instances src, final CoordinateTransform transform) {
    // Output schema: x0..x{d-1} followed by the (copied) class attribute.
    final ArrayList<Attribute> attrs = new ArrayList<Attribute>();
    for (int d = 0; d < transform.outDimension(); ++d) {
        attrs.add(new Attribute("x" + d));
    }
    attrs.add((Attribute) src.classAttribute().copy());

    final Instances result = new Instances(src.relationName() + "_" + transform.name(), attrs, 0);
    for (int i = 0; i < src.size(); ++i) {
        final Instance original = src.get(i);
        final RealVector features = new ArrayRealVector(WekaUtil.unlabeledFeatures(original));
        final RealVector encoded = transform.encode(features).x;

        // Encoded features first, class value in the final slot.
        final int dim = encoded.getDimension();
        final double[] values = new double[dim + 1];
        for (int j = 0; j < dim; ++j) {
            values[j] = encoded.getEntry(j);
        }
        values[dim] = original.classValue();

        final Instance copy = new DenseInstance(original.weight(), values);
        result.add(copy);
        copy.setDataset(result);
    }
    result.setClassIndex(result.numAttributes() - 1);
    return result;
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.EvaluateSimilarityFunction.java

License:Open Source License

/**
 * Runs the similarity-function evaluation experiments described by a CSV
 * configuration file: trains on noisy-labeled instances, partitions held-out
 * instances into plausibility groups, evaluates cluster contingency, and
 * writes aggregate results to results.csv.
 *
 * @param args args[0] is the experiment CSV file; optional args[1] is the root directory
 * @throws IOException
 * @throws FileNotFoundException
 */
public static void main(final String[] args) throws FileNotFoundException, IOException {
    final String experiment_file = args[0];
    final File root_directory;
    if (args.length > 1) {
        root_directory = new File(args[1]);
    } else {
        root_directory = new File(".");
    }
    final CsvConfigurationParser csv_config = new CsvConfigurationParser(new FileReader(experiment_file));
    final String experiment_name = FilenameUtils.getBaseName(experiment_file);

    final File expr_directory = new File(root_directory, experiment_name);
    expr_directory.mkdirs();

    // Results CSV: domain, abstraction, all parameter columns, then metrics.
    final Csv.Writer csv = new Csv.Writer(
            new PrintStream(new FileOutputStream(new File(expr_directory, "results.csv"))));
    final String[] parameter_headers = new String[] { "kpca.kernel", "kpca.rbf.sigma",
            "kpca.random_forest.Ntrees", "kpca.random_forest.max_depth", "kpca.Nbases", "multiclass.classifier",
            "multiclass.random_forest.Ntrees", "multiclass.random_forest.max_depth",
            "pairwise_classifier.max_branching", "training.label_noise" };
    csv.cell("domain").cell("abstraction");
    for (final String p : parameter_headers) {
        csv.cell(p);
    }
    csv.cell("Ntrain").cell("Ntest").cell("ami.mean").cell("ami.variance").cell("ami.confidence").newline();

    // One experiment per CSV row; failures are logged and skipped.
    for (int expr = 0; expr < csv_config.size(); ++expr) {
        try {
            final KeyValueStore expr_config = csv_config.get(expr);
            final Configuration config = new Configuration(root_directory.getPath(), expr_directory.getName(),
                    expr_config);

            System.out.println("[Loading '" + config.training_data_single + "']");
            final Instances single = WekaUtil
                    .readLabeledDataset(new File(root_directory, config.training_data_single + ".arff"));

            // Draw Ntrain shuffled instances for training, optionally
            // replacing each label with a different random one ("label noise").
            final Instances train = new Instances(single, 0);
            final int[] idx = Fn.range(0, single.size());
            int instance_counter = 0;
            Fn.shuffle(config.rng, idx);
            final int Ntrain = config.getInt("Ntrain_games"); // TODO: Rename?
            final double label_noise = config.getDouble("training.label_noise");
            final int Nlabels = train.classAttribute().numValues();
            assert (Nlabels > 0);
            for (int i = 0; i < Ntrain; ++i) {
                final Instance inst = single.get(idx[instance_counter++]);
                if (label_noise > 0 && config.rng.nextDouble() < label_noise) {
                    int noisy_label = 0;
                    do {
                        noisy_label = config.rng.nextInt(Nlabels);
                    } while (noisy_label == (int) inst.classValue());
                    System.out.println("Noisy label (" + inst.classValue() + " -> " + noisy_label + ")");
                    inst.setClassValue(noisy_label);
                }
                train.add(inst);
                inst.setDataset(train);
            }

            final Fn.Function2<Boolean, Instance, Instance> plausible_p = createPlausiblePredicate(config);

            // Group the remaining instances into test sets of mutually
            // "plausible" instances; only groups reaching size 30 count
            // toward Ntest_added (and are kept below).
            final int Ntest = config.Ntest_games;
            int Ntest_added = 0;
            final ArrayList<Instances> tests = new ArrayList<Instances>();
            while (instance_counter < single.size() && Ntest_added < Ntest) {
                final Instance inst = single.get(idx[instance_counter++]);
                boolean found = false;
                for (final Instances test : tests) {
                    // Note that 'plausible_p' should be transitive
                    if (plausible_p.apply(inst, test.get(0))) {
                        WekaUtil.addInstance(test, inst);
                        if (test.size() == 30) {
                            Ntest_added += test.size();
                        } else if (test.size() > 30) {
                            Ntest_added += 1;
                        }
                        found = true;
                        break;
                    }
                }

                if (!found) {
                    final Instances test = new Instances(single, 0);
                    WekaUtil.addInstance(test, inst);
                    tests.add(test);
                }
            }
            // Discard undersized test groups.
            final Iterator<Instances> test_itr = tests.iterator();
            while (test_itr.hasNext()) {
                if (test_itr.next().size() < 30) {
                    test_itr.remove();
                }
            }
            System.out.println("=== tests.size() = " + tests.size());
            System.out.println("=== Ntest_added = " + Ntest_added);

            System.out.println("[Training]");
            final Evaluator evaluator = createEvaluator(config, train);
            //            final Instances transformed_test = evaluator.prepareInstances( test );

            System.out.println("[Evaluating]");

            // Order-sensitive evaluators get 10 shuffled repetitions.
            final int Nxval = evaluator.isSensitiveToOrdering() ? 10 : 1;
            final MeanVarianceAccumulator ami = new MeanVarianceAccumulator();

            final MeanVarianceAccumulator errors = new MeanVarianceAccumulator();
            final MeanVarianceAccumulator relative_error = new MeanVarianceAccumulator();

            int c = 0;
            for (int xval = 0; xval < Nxval; ++xval) {
                for (final Instances test : tests) {
                    // TODO: Debugging
                    WekaUtil.writeDataset(new File(config.root_directory), "test_" + (c++), test);

                    //               transformed_test.randomize( new RandomAdaptor( config.rng ) );
                    //               final ClusterContingencyTable ct = evaluator.evaluate( transformed_test );
                    test.randomize(new RandomAdaptor(config.rng));
                    final ClusterContingencyTable ct = evaluator.evaluate(test);
                    System.out.println(ct);

                    // Error = instances outside each row's majority cluster.
                    int Nerrors = 0;
                    final MeanVarianceAccumulator mv = new MeanVarianceAccumulator();
                    for (int i = 0; i < ct.R; ++i) {
                        final int max = Fn.max(ct.n[i]);
                        Nerrors += (ct.a[i] - max);
                        mv.add(((double) ct.a[i]) / ct.N * Nerrors / ct.a[i]);
                    }
                    errors.add(Nerrors);
                    relative_error.add(mv.mean());

                    System.out.println("exemplar: " + test.get(0));
                    System.out.println("Nerrors = " + Nerrors);
                    final PrintStream ct_out = new PrintStream(
                            new FileOutputStream(new File(expr_directory, "ct_" + expr + "_" + xval + ".csv")));
                    ct.writeCsv(ct_out);
                    ct_out.close();
                    final double ct_ami = ct.adjustedMutualInformation_max();
                    if (Double.isNaN(ct_ami)) {
                        System.out.println("! ct_ami = NaN");
                    } else {
                        ami.add(ct_ami);
                    }
                    System.out.println();
                }
            }
            System.out.println("errors = " + errors.mean() + " (" + errors.confidence() + ")");
            System.out.println(
                    "relative_error = " + relative_error.mean() + " (" + relative_error.confidence() + ")");
            System.out.println("AMI_max = " + ami.mean() + " (" + ami.confidence() + ")");

            csv.cell(config.domain).cell(config.get("abstraction.discovery"));
            for (final String p : parameter_headers) {
                csv.cell(config.get(p));
            }
            csv.cell(Ntrain).cell(Ntest).cell(ami.mean()).cell(ami.variance()).cell(ami.confidence()).newline();
        } catch (final Exception ex) {
            ex.printStackTrace();
        }
    }
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.WekaUtil.java

License:Open Source License

/**
 * Separates a labeled dataset into a feature matrix and a label vector.
 * The class attribute is skipped in each feature row and recorded as an
 * integer label instead.
 *
 * @param train the labeled instances (class attribute must be set)
 * @return a pair of (feature rows, integer labels), aligned by index
 */
public static Pair<ArrayList<double[]>, int[]> splitLabels(final Instances train) {
    assert (train.classAttribute() != null);

    final int rows = train.size();
    final ArrayList<double[]> features = new ArrayList<double[]>();
    final int[] labels = new int[rows];

    for (int row = 0; row < rows; ++row) {
        final Instance inst = train.get(row);
        final double[] x = new double[train.numAttributes() - 1];
        int next = 0;
        for (int col = 0; col < train.numAttributes(); ++col) {
            if (col == train.classIndex()) {
                // The class column becomes the label, not a feature.
                labels[row] = (int) inst.classValue();
            } else {
                x[next++] = inst.value(col);
            }
        }
        features.add(x);
    }

    return Pair.makePair(features, labels);
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.WekaUtil.java

License:Open Source License

/**
 * Builds a product-feature dataset: one numeric attribute per non-empty
 * subset of the original (non-class) attributes of size at most n, whose
 * value is the product of the member attribute values. The class attribute,
 * if present, is recreated as the last attribute.
 *
 * @param D the source dataset; a nominal class attribute is optional
 * @param n the maximum subset size to include
 * @return a new Instances object with product features (and the class, if any)
 */
public static Instances powerSet(final Instances D, final int n) {
    final Attribute class_attr = D.classAttribute();

    // Indices 1..Nattr-1 of the candidate feature attributes.
    // NOTE(review): Fn.range appears to start at 1, skipping attribute 0 —
    // confirm whether that is intentional.
    final ImmutableSet.Builder<Integer> b = new ImmutableSet.Builder<Integer>();
    final int Nattr = class_attr != null ? D.numAttributes() - 1 : D.numAttributes();
    for (final int i : Fn.range(1, Nattr)) {
        b.add(i);
    }
    final Set<Set<Integer>> index = Sets.powerSet(b.build());

    // Name each product attribute "a_x_b_x_c" after its member attributes.
    // Subsets that are empty or larger than n are skipped here and in the
    // value loop below; both loops must iterate 'index' in the same order.
    final ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    for (final Set<Integer> subset : index) {
        if (subset.isEmpty() || subset.size() > n) {
            continue;
        }

        final StringBuilder attr_name = new StringBuilder();
        int count = 0;
        for (final Integer i : subset) {
            if (count++ > 0) {
                attr_name.append("_x_");
            }
            attr_name.append(D.attribute(i).name());
        }

        attributes.add(new Attribute(attr_name.toString()));
    }
    if (class_attr != null) {
        assert (class_attr.isNominal());
        attributes.add(WekaUtil.createNominalAttribute(class_attr.name(), class_attr.numValues()));
    }

    final String Pname = "P" + n + "_" + D.relationName();
    final Instances P = new Instances(Pname, attributes, 0);
    if (class_attr != null) {
        P.setClassIndex(attributes.size() - 1);
    }

    // Fill each new instance with the subset products, then the class value.
    for (final Instance inst : D) {
        final double[] xp = new double[attributes.size()];
        int idx = 0;
        for (final Set<Integer> subset : index) {
            if (subset.isEmpty() || subset.size() > n) {
                continue;
            }

            double p = 1.0;
            for (final Integer i : subset) {
                p *= inst.value(i);
            }
            xp[idx++] = p;
        }
        if (class_attr != null) {
            xp[idx++] = inst.classValue();
        }

        WekaUtil.addInstance(P, new DenseInstance(inst.weight(), xp));
    }

    return P;
}