Example usage for weka.core.Instances.add

List of usage examples for the weka.core.Instances.add method

Introduction

On this page you can find example usage for weka.core.Instances.add.

Prototype

@Override
public boolean add(Instance instance) 

Document

Adds one instance to the end of the set. The instance is shallow-copied before it is added, so the dataset holds its own copy; note that string or relational values are not transferred.
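
Below is a minimal, self-contained sketch of the call. The class name AddExample and the attribute names x and y are illustrative assumptions, not taken from the examples that follow.

import java.util.ArrayList;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

public class AddExample {
    public static void main(String[] args) {
        // Minimal schema: one numeric feature and one numeric class attribute.
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        attributes.add(new Attribute("x"));
        attributes.add(new Attribute("y"));

        Instances data = new Instances("demo", attributes, 0);
        data.setClassIndex(data.numAttributes() - 1);

        // Build a row and append it. add(...) shallow-copies the instance and
        // associates the copy with the dataset, so 'inst' itself stays detached.
        Instance inst = new DenseInstance(data.numAttributes());
        inst.setValue(0, 1.5);
        inst.setValue(1, 3.0);
        boolean added = data.add(inst);

        System.out.println(added); // true
        System.out.println(data); // prints the dataset in ARFF form
    }
}

The examples below use this same call to copy instances between datasets, to append newly constructed DenseInstance and SparseInstance objects, and to accumulate cross-validation predictions.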

Usage

From source file:dkpro.similarity.experiments.sts2013baseline.util.Evaluator.java

License:Open Source License

public static void runLinearRegressionCV(Mode mode, Dataset... datasets) throws Exception {
    for (Dataset dataset : datasets) {
        // Set parameters
        int folds = 10;
        Classifier baseClassifier = new LinearRegression();

        // Set up the random number generator
        long seed = new Date().getTime();
        Random random = new Random(seed);

        // Add IDs to the instances
        AddID.main(new String[] { "-i",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff", "-o",
                MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString()
                        + "-plusIDs.arff" });

        String location = MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString()
                + "-plusIDs.arff";

        Instances data = DataSource.read(location);

        if (data == null) {
            throw new IOException("Could not load data from: " + location);
        }

        data.setClassIndex(data.numAttributes() - 1);

        // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
        removeIDFilter.setAttributeIndices("first");

        // Randomize the data
        data.randomize(random);

        // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);

        for (int n = 0; n < folds; n++) {
            Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);

            // Apply log filter
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);

            // Build the classifier
            filteredClassifier.buildClassifier(train);

            // Evaluate
            eval.evaluateModel(classifier, test);

            // Add predictions
            AddClassification filter = new AddClassification();
            filter.setClassifier(classifier);
            filter.setOutputClassification(true);
            filter.setOutputDistribution(false);
            filter.setOutputErrorFlag(true);
            filter.setInputFormat(train);
            Filter.useFilter(train, filter); // trains the classifier

            Instances pred = Filter.useFilter(test, filter); // performs predictions on test set
            if (predictedData == null) {
                predictedData = new Instances(pred, 0);
            }
            for (int j = 0; j < pred.numInstances(); j++) {
                predictedData.add(pred.instance(j));
            }
        }

        // Prepare output scores
        double[] scores = new double[predictedData.numInstances()];

        for (Instance predInst : predictedData) {
            int id = (int) predInst.value(predInst.attribute(0)) - 1;

            int valueIdx = predictedData.numAttributes() - 2;

            double value = predInst.value(predInst.attribute(valueIdx));

            scores[id] = value;

            // Limit to interval [0;5]
            if (scores[id] > 5.0) {
                scores[id] = 5.0;
            }
            if (scores[id] < 0.0) {
                scores[id] = 0.0;
            }
        }

        // Output
        StringBuilder sb = new StringBuilder();
        for (Double score : scores) {
            sb.append(score.toString() + LF);
        }

        FileUtils.writeStringToFile(
                new File(OUTPUT_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".csv"),
                sb.toString());
    }
}

From source file:edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.MekaProvider.java

License:Apache License

@Override
public void train(List<Map<String, Double>> X, List<String> Y, boolean crossValidation)
        throws AnalysisEngineProcessException {
    // create attribute (including label) info
    ArrayList<Attribute> attributes = new ArrayList<>();
    List<String> labelNames = ClassifierProvider.labelNames(Y);
    labelNames.stream().map(attr -> new Attribute(attr, Arrays.asList("y", "n")))
            .forEachOrdered(attributes::add);
    List<String> featureNames = ClassifierProvider.featureNames(X);
    featureNames.stream().map(Attribute::new).forEachOrdered(attributes::add);
    String name = Files.getNameWithoutExtension(modelFile.getName());
    datasetSchema = new Instances(name, attributes, 0);
    datasetSchema.setClassIndex(labelNames.size());
    // add instances
    // due to the limitation of the interface definition, X, Y should be reorganized
    SetMultimap<Map<String, Double>, String> XY = HashMultimap.create();
    IntStream.range(0, X.size()).forEach(i -> XY.put(X.get(i), Y.get(i)));
    Instances trainingInstances = new Instances(datasetSchema, XY.size());
    for (Map.Entry<Map<String, Double>, Collection<String>> entry : XY.asMap().entrySet()) {
        Set<String> y = ImmutableSet.copyOf(entry.getValue());
        Map<String, Double> x = entry.getKey();
        SparseInstance instance = new SparseInstance(labelNames.size() + x.size());
        for (String labelName : labelNames) {
            instance.setValue(datasetSchema.attribute(labelName), y.contains(labelName) ? "y" : "n");
        }
        for (Map.Entry<String, Double> e : x.entrySet()) {
            instance.setValue(datasetSchema.attribute(e.getKey()), e.getValue());
        }
        trainingInstances.add(instance);
    }
    // training
    try {
        classifier = (MultiLabelClassifier) AbstractClassifier.forName(classifierName, options);
        classifier.buildClassifier(trainingInstances);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }
    try {
        SerializationHelper.write(modelFile.getAbsolutePath(), classifier);
        SerializationHelper.write(datasetSchemaFile.getAbsolutePath(), datasetSchema);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }
    if (crossValidation) {
        try {
            Evaluation eval = new Evaluation(trainingInstances);
            Random rand = new Random();
            eval.crossValidateModel(classifier, trainingInstances, 10, rand);
            LOG.debug(eval.toSummaryString());
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
}

From source file:edu.cmu.lti.oaqa.baseqa.providers.ml.classifiers.WekaProvider.java

License:Apache License

@Override
public void train(List<Map<String, Double>> X, List<String> Y, boolean crossValidation)
        throws AnalysisEngineProcessException {
    // create attribute (including label) info
    ArrayList<Attribute> attributes = new ArrayList<>();
    ClassifierProvider.featureNames(X).stream().map(Attribute::new).forEachOrdered(attributes::add);
    Attribute label = new Attribute("__label__", ClassifierProvider.labelNames(Y));
    attributes.add(label);
    String name = Files.getNameWithoutExtension(modelFile.getName());
    datasetSchema = new Instances(name, attributes, X.size());
    datasetSchema.setClass(label);
    // add instances
    Instances trainingInstances = new Instances(datasetSchema, X.size());
    if (balanceWeight) {
        Multiset<String> labelCounts = HashMultiset.create(Y);
        double maxCount = labelCounts.entrySet().stream().mapToInt(Multiset.Entry::getCount).max()
                .orElseThrow(AnalysisEngineProcessException::new);
        for (int i = 0; i < X.size(); i++) {
            String y = Y.get(i);
            double weight = maxCount / labelCounts.count(y);
            trainingInstances.add(newInstance(X.get(i), y, weight, trainingInstances));
        }
    } else {
        for (int i = 0; i < X.size(); i++) {
            trainingInstances.add(newInstance(X.get(i), Y.get(i), 1.0, trainingInstances));
        }
    }
    // training
    try {
        classifier = AbstractClassifier.forName(classifierName, options);
        classifier.buildClassifier(trainingInstances);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }
    // write model and dataset schema
    try {
        SerializationHelper.write(modelFile.getAbsolutePath(), classifier);
        SerializationHelper.write(datasetSchemaFile.getAbsolutePath(), datasetSchema);
    } catch (Exception e) {
        throw new AnalysisEngineProcessException(e);
    }
    // backup training dataset as arff file
    if (datasetExportFile != null) {
        try {
            ArffSaver saver = new ArffSaver();
            saver.setInstances(trainingInstances);
            saver.setFile(datasetExportFile);
            saver.writeBatch();
        } catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
    if (crossValidation) {
        try {
            Evaluation eval = new Evaluation(trainingInstances);
            Random rand = new Random();
            eval.crossValidateModel(classifier, trainingInstances, 10, rand);
            LOG.debug(eval.toSummaryString());
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
}

From source file:edu.cuny.qc.speech.AuToBI.util.ClassifierUtils.java

License:Open Source License

/**
 * Converts a feature set object to a weka Instances object
 * <p/>
 * The class is set to the last attribute.
 *
 * @param feature_set the feature set to convert
 * @return a weka instances object
 * @throws Exception If the arff file can't be written or read.
 */
public static Instances convertFeatureSetToWekaInstances(FeatureSet feature_set) throws Exception {
    ArrayList<Attribute> attributes = generateWekaAttributes(feature_set.getFeatures());
    Instances instances = new Instances("AuToBI_feature_set", attributes, feature_set.getDataPoints().size());
    for (Word w : feature_set.getDataPoints()) {
        Instance inst = ClassifierUtils.assignWekaAttributes(instances, w);
        instances.add(inst);
    }

    ClassifierUtils.setWekaClassAttribute(instances, feature_set.getClassAttribute());
    return instances;
}

From source file:edu.cuny.qc.speech.AuToBI.util.ClassifierUtils.java

License:Open Source License

/**
 * Converts a feature set object to a weka Instances object.
 * <p/>
 * Uses Weka's instance weighting capability to assign weights to each data point.
 *
 * @param feature_set the feature set to convert
 * @param fn          a weight function
 * @return a weka instances object
 */
public static Instances convertFeatureSetToWeightedWekaInstances(FeatureSet feature_set, WeightFunction fn) {
    ArrayList<Attribute> attributes = generateWekaAttributes(feature_set.getFeatures());
    Instances instances = new Instances("AuToBI_feature_set", attributes, feature_set.getDataPoints().size());
    for (Word w : feature_set.getDataPoints()) {
        Instance inst = ClassifierUtils.assignWekaAttributes(instances, w);
        inst.setWeight(fn.weight(w));
        instances.add(inst);
    }

    ClassifierUtils.setWekaClassAttribute(instances, feature_set.getClassAttribute());
    return instances;
}

From source file:edu.insight.finlaw.multilabel.rough.CreateInstances.java

License:Open Source License

/**
 * Generates the Instances object and outputs it in ARFF format to stdout.
 *
 * @param args ignored
 * @throws Exception   if generation of instances fails
 */
public static void main(String[] args) throws Exception {
    ArrayList<Attribute> atts;
    ArrayList<Attribute> attsRel;
    ArrayList<String> attVals;
    ArrayList<String> attValsRel;
    Instances data;
    Instances dataRel;
    double[] vals;
    double[] valsRel;
    int i;

    // 1. set up attributes
    atts = new ArrayList<Attribute>();
    // - numeric
    atts.add(new Attribute("att1"));
    // - nominal
    attVals = new ArrayList<String>();
    for (i = 0; i < 5; i++)
        attVals.add("val" + (i + 1));
    atts.add(new Attribute("att2", attVals));
    // - string   
    atts.add(new Attribute("att3", (ArrayList<String>) null));
    // - date
    atts.add(new Attribute("att4", "yyyy-MM-dd"));
    // - relational
    attsRel = new ArrayList<Attribute>();
    // -- numeric
    attsRel.add(new Attribute("att5.1"));
    // -- nominal
    attValsRel = new ArrayList<String>();
    for (i = 0; i < 5; i++)
        attValsRel.add("val5." + (i + 1));
    attsRel.add(new Attribute("att5.2", attValsRel));
    dataRel = new Instances("att5", attsRel, 0);
    atts.add(new Attribute("att5", dataRel, 0));

    // 2. create Instances object
    data = new Instances("MyRelation", atts, 0);

    // 3. fill with data
    // first instance
    vals = new double[data.numAttributes()];
    // - numeric
    vals[0] = Math.PI;
    // - nominal
    vals[1] = attVals.indexOf("val3");
    // - string
    vals[2] = data.attribute(2).addStringValue("This is a string!");
    // - date
    vals[3] = data.attribute(3).parseDate("2001-11-09");
    // - relational
    dataRel = new Instances(data.attribute(4).relation(), 0);
    // -- first instance
    valsRel = new double[2];
    valsRel[0] = Math.PI + 1;
    valsRel[1] = attValsRel.indexOf("val5.3");
    dataRel.add(new DenseInstance(1.0, valsRel));
    // -- second instance
    valsRel = new double[2];
    valsRel[0] = Math.PI + 2;
    valsRel[1] = attValsRel.indexOf("val5.2");
    dataRel.add(new DenseInstance(1.0, valsRel));
    vals[4] = data.attribute(4).addRelation(dataRel);
    // add
    data.add(new DenseInstance(1.0, vals));

    // second instance
    vals = new double[data.numAttributes()]; // important: needs NEW array!
    // - numeric
    vals[0] = Math.E;
    // - nominal
    vals[1] = attVals.indexOf("val1");
    // - string
    vals[2] = data.attribute(2).addStringValue("And another one!");
    // - date
    vals[3] = data.attribute(3).parseDate("2000-12-01");
    // - relational
    dataRel = new Instances(data.attribute(4).relation(), 0);
    // -- first instance
    valsRel = new double[2];
    valsRel[0] = Math.E + 1;
    valsRel[1] = attValsRel.indexOf("val5.4");
    dataRel.add(new DenseInstance(1.0, valsRel));
    // -- second instance
    valsRel = new double[2];
    valsRel[0] = Math.E + 2;
    valsRel[1] = attValsRel.indexOf("val5.1");
    dataRel.add(new DenseInstance(1.0, valsRel));
    vals[4] = data.attribute(4).addRelation(dataRel);
    // add
    data.add(new DenseInstance(1.0, vals));

    // 4. output data
    System.out.println(data);
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.EvaluateSimilarityFunction.java

License:Open Source License

public static Instances transformInstances(final Instances src, final CoordinateTransform transform) {
    final ArrayList<Attribute> out_attributes = new ArrayList<Attribute>();
    for (int i = 0; i < transform.outDimension(); ++i) {
        out_attributes.add(new Attribute("x" + i));
    }
    out_attributes.add((Attribute) src.classAttribute().copy());
    final Instances out = new Instances(src.relationName() + "_" + transform.name(), out_attributes, 0);
    for (int i = 0; i < src.size(); ++i) {
        final Instance inst = src.get(i);
        final RealVector flat = new ArrayRealVector(WekaUtil.unlabeledFeatures(inst));
        final RealVector transformed_vector = transform.encode(flat).x;
        final double[] transformed = new double[transformed_vector.getDimension() + 1];
        for (int j = 0; j < transformed_vector.getDimension(); ++j) {
            transformed[j] = transformed_vector.getEntry(j);
        }
        transformed[transformed.length - 1] = inst.classValue();
        final Instance transformed_instance = new DenseInstance(inst.weight(), transformed);
        out.add(transformed_instance);
        transformed_instance.setDataset(out);
    }
    out.setClassIndex(out.numAttributes() - 1);
    return out;
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.EvaluateSimilarityFunction.java

License:Open Source License

/**
 * @param args
 * @throws IOException
 * @throws FileNotFoundException
 */
public static void main(final String[] args) throws FileNotFoundException, IOException {
    final String experiment_file = args[0];
    final File root_directory;
    if (args.length > 1) {
        root_directory = new File(args[1]);
    } else {
        root_directory = new File(".");
    }
    final CsvConfigurationParser csv_config = new CsvConfigurationParser(new FileReader(experiment_file));
    final String experiment_name = FilenameUtils.getBaseName(experiment_file);

    final File expr_directory = new File(root_directory, experiment_name);
    expr_directory.mkdirs();

    final Csv.Writer csv = new Csv.Writer(
            new PrintStream(new FileOutputStream(new File(expr_directory, "results.csv"))));
    final String[] parameter_headers = new String[] { "kpca.kernel", "kpca.rbf.sigma",
            "kpca.random_forest.Ntrees", "kpca.random_forest.max_depth", "kpca.Nbases", "multiclass.classifier",
            "multiclass.random_forest.Ntrees", "multiclass.random_forest.max_depth",
            "pairwise_classifier.max_branching", "training.label_noise" };
    csv.cell("domain").cell("abstraction");
    for (final String p : parameter_headers) {
        csv.cell(p);
    }
    csv.cell("Ntrain").cell("Ntest").cell("ami.mean").cell("ami.variance").cell("ami.confidence").newline();

    for (int expr = 0; expr < csv_config.size(); ++expr) {
        try {
            final KeyValueStore expr_config = csv_config.get(expr);
            final Configuration config = new Configuration(root_directory.getPath(), expr_directory.getName(),
                    expr_config);

            System.out.println("[Loading '" + config.training_data_single + "']");
            final Instances single = WekaUtil
                    .readLabeledDataset(new File(root_directory, config.training_data_single + ".arff"));

            final Instances train = new Instances(single, 0);
            final int[] idx = Fn.range(0, single.size());
            int instance_counter = 0;
            Fn.shuffle(config.rng, idx);
            final int Ntrain = config.getInt("Ntrain_games"); // TODO: Rename?
            final double label_noise = config.getDouble("training.label_noise");
            final int Nlabels = train.classAttribute().numValues();
            assert (Nlabels > 0);
            for (int i = 0; i < Ntrain; ++i) {
                final Instance inst = single.get(idx[instance_counter++]);
                if (label_noise > 0 && config.rng.nextDouble() < label_noise) {
                    int noisy_label = 0;
                    do {
                        noisy_label = config.rng.nextInt(Nlabels);
                    } while (noisy_label == (int) inst.classValue());
                    System.out.println("Noisy label (" + inst.classValue() + " -> " + noisy_label + ")");
                    inst.setClassValue(noisy_label);
                }
                train.add(inst);
                inst.setDataset(train);
            }

            final Fn.Function2<Boolean, Instance, Instance> plausible_p = createPlausiblePredicate(config);

            final int Ntest = config.Ntest_games;
            int Ntest_added = 0;
            final ArrayList<Instances> tests = new ArrayList<Instances>();
            while (instance_counter < single.size() && Ntest_added < Ntest) {
                final Instance inst = single.get(idx[instance_counter++]);
                boolean found = false;
                for (final Instances test : tests) {
                    // Note that 'plausible_p' should be transitive
                    if (plausible_p.apply(inst, test.get(0))) {
                        WekaUtil.addInstance(test, inst);
                        if (test.size() == 30) {
                            Ntest_added += test.size();
                        } else if (test.size() > 30) {
                            Ntest_added += 1;
                        }
                        found = true;
                        break;
                    }
                }

                if (!found) {
                    final Instances test = new Instances(single, 0);
                    WekaUtil.addInstance(test, inst);
                    tests.add(test);
                }
            }
            final Iterator<Instances> test_itr = tests.iterator();
            while (test_itr.hasNext()) {
                if (test_itr.next().size() < 30) {
                    test_itr.remove();
                }
            }
            System.out.println("=== tests.size() = " + tests.size());
            System.out.println("=== Ntest_added = " + Ntest_added);

            System.out.println("[Training]");
            final Evaluator evaluator = createEvaluator(config, train);
            //            final Instances transformed_test = evaluator.prepareInstances( test );

            System.out.println("[Evaluating]");

            final int Nxval = evaluator.isSensitiveToOrdering() ? 10 : 1;
            final MeanVarianceAccumulator ami = new MeanVarianceAccumulator();

            final MeanVarianceAccumulator errors = new MeanVarianceAccumulator();
            final MeanVarianceAccumulator relative_error = new MeanVarianceAccumulator();

            int c = 0;
            for (int xval = 0; xval < Nxval; ++xval) {
                for (final Instances test : tests) {
                    // TODO: Debugging
                    WekaUtil.writeDataset(new File(config.root_directory), "test_" + (c++), test);

                    //               transformed_test.randomize( new RandomAdaptor( config.rng ) );
                    //               final ClusterContingencyTable ct = evaluator.evaluate( transformed_test );
                    test.randomize(new RandomAdaptor(config.rng));
                    final ClusterContingencyTable ct = evaluator.evaluate(test);
                    System.out.println(ct);

                    int Nerrors = 0;
                    final MeanVarianceAccumulator mv = new MeanVarianceAccumulator();
                    for (int i = 0; i < ct.R; ++i) {
                        final int max = Fn.max(ct.n[i]);
                        Nerrors += (ct.a[i] - max);
                        mv.add(((double) ct.a[i]) / ct.N * Nerrors / ct.a[i]);
                    }
                    errors.add(Nerrors);
                    relative_error.add(mv.mean());

                    System.out.println("exemplar: " + test.get(0));
                    System.out.println("Nerrors = " + Nerrors);
                    final PrintStream ct_out = new PrintStream(
                            new FileOutputStream(new File(expr_directory, "ct_" + expr + "_" + xval + ".csv")));
                    ct.writeCsv(ct_out);
                    ct_out.close();
                    final double ct_ami = ct.adjustedMutualInformation_max();
                    if (Double.isNaN(ct_ami)) {
                        System.out.println("! ct_ami = NaN");
                    } else {
                        ami.add(ct_ami);
                    }
                    System.out.println();
                }
            }
            System.out.println("errors = " + errors.mean() + " (" + errors.confidence() + ")");
            System.out.println(
                    "relative_error = " + relative_error.mean() + " (" + relative_error.confidence() + ")");
            System.out.println("AMI_max = " + ami.mean() + " (" + ami.confidence() + ")");

            csv.cell(config.domain).cell(config.get("abstraction.discovery"));
            for (final String p : parameter_headers) {
                csv.cell(config.get(p));
            }
            csv.cell(Ntrain).cell(Ntest).cell(ami.mean()).cell(ami.variance()).cell(ami.confidence()).newline();
        } catch (final Exception ex) {
            ex.printStackTrace();
        }
    }
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.Experiments.java

License:Open Source License

/**
 * Creates a labeled dataset of states pair with optimal actions. Action
 * labels are represented as indexes into an array list. Mappings in both
 * directions are also returned.
 * @param config
 * @param attributes
 * @param data
 * @param labels
 * @param iter
 * @return
 */
private static <A extends VirtualConstructor<A>> SingleInstanceDataset<A> makeSingleInstanceDataset(
        final Configuration config, final ArrayList<Attribute> attributes, final ArrayList<double[]> data,
        final ArrayList<A> labels, final ArrayList<Pair<ArrayList<A>, TDoubleList>> qtable, final int iter) {
    //      System.out.println( "data.size() = " + data.size() );
    final int[] ii = Fn.range(0, data.size());
    Fn.shuffle(config.rng, ii);

    final HashMap<A, Integer> action_to_int = new HashMap<A, Integer>();
    final ArrayList<A> int_to_action = new ArrayList<A>();
    final ArrayList<Pair<ArrayList<A>, TDoubleList>> abridged_qtable = (qtable != null
            ? new ArrayList<Pair<ArrayList<A>, TDoubleList>>()
            : null);

    final TIntArrayList counts = new TIntArrayList();
    final int max_per_label = config.getInt("training.max_per_label");
    final int max_instances = config.getInt("training.max_single");

    final ArrayList<DenseInstance> instance_list = new ArrayList<DenseInstance>();
    for (int i = 0; i < Math.min(data.size(), max_instances); ++i) {
        final int idx = ii[i];
        final A a = labels.get(idx);
        final Integer idx_obj = action_to_int.get(a);
        final int label;
        if (idx_obj == null) {
            //            System.out.println( "\tNew action: " + a );
            label = int_to_action.size();
            int_to_action.add(a);
            action_to_int.put(a, label);
            counts.add(0);
        } else {
            //            System.out.println( "\tRepeat action: " + a );
            label = idx_obj;
        }

        final int c = counts.get(label);
        if (max_per_label <= 0 || c < max_per_label) {
            //            System.out.println( "Adding " + label );
            final double[] phi = Fn.append(data.get(idx), label);
            final DenseInstance instance = new DenseInstance(1.0, phi);
            instance_list.add(instance);
            counts.set(label, c + 1);
            if (qtable != null) {
                abridged_qtable.add(qtable.get(idx));
            }
        }
    }

    final int Nlabels = int_to_action.size();
    final ArrayList<Attribute> labeled_attributes = addLabelToAttributes(attributes, Nlabels);

    final Instances instances = new Instances(deriveDatasetName(config.training_data_single, iter),
            labeled_attributes, counts.sum());
    instances.setClassIndex(instances.numAttributes() - 1);
    for (final DenseInstance instance : instance_list) {
        instances.add(instance);
        instance.setDataset(instances);
    }

    return new SingleInstanceDataset<A>(instances, action_to_int, int_to_action, abridged_qtable);
}

From source file:edu.oregonstate.eecs.mcplan.abstraction.PairDataset.java

License:Open Source License

public static <S, X extends FactoredRepresentation<S>, A extends VirtualConstructor<A>> Instances makePairDataset(
        final RandomGenerator rng, final int max_pairwise_instances, final Instances single,
        final InstanceCombiner combiner) {
    //      final int max_pairwise = config.getInt( "training.max_pairwise" );
    final ReservoirSampleAccumulator<Instance> negative = new ReservoirSampleAccumulator<Instance>(rng,
            max_pairwise_instances);
    final ReservoirSampleAccumulator<Instance> positive = new ReservoirSampleAccumulator<Instance>(rng,
            max_pairwise_instances);

    for (int i = 0; i < single.size(); ++i) {
        //         if( i % 100 == 0 ) {
        //            System.out.println( "i = " + i );
        //         }
        for (int j = i + 1; j < single.size(); ++j) {
            final Instance ii = single.get(i);
            final Instance ij = single.get(j);
            final int label;
            if (ii.classValue() == ij.classValue()) {
                label = 1;
                if (positive.acceptNext()) {
                    final Instance pair_instance = combiner.apply(ii, ij, label);
                    positive.addPending(pair_instance);
                }
            } else {
                label = 0;
                if (negative.acceptNext()) {
                    final Instance pair_instance = combiner.apply(ii, ij, label);
                    negative.addPending(pair_instance);
                }
            }
        }
    }

    final int N = Math.min(negative.samples().size(), positive.samples().size());
    final String dataset_name = "train_" + combiner.keyword() + "_" + max_pairwise_instances;
    final Instances x = new Instances(dataset_name, combiner.attributes(), 2 * N);
    x.setClassIndex(x.numAttributes() - 1);
    for (final Instance ineg : negative.samples()) {
        x.add(ineg);
    }
    for (final Instance ipos : positive.samples()) {
        x.add(ipos);
    }

    return x;
    //      return new PairDataset( x, combiner );
}