Example usage for weka.core Instances setClass

List of usage examples for weka.core Instances setClass

Introduction

On this page you can find example usages of weka.core Instances setClass.

Prototype

public void setClass(Attribute att) 

Document

Sets the class attribute.
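
Before the collected examples, here is a minimal, self-contained sketch of the call. It assumes Weka 3.7 or later (where Instances accepts an ArrayList of attributes); the relation name and attribute names are purely illustrative:

import java.util.ArrayList;

import weka.core.Attribute;
import weka.core.Instances;

public class SetClassExample {
    public static void main(String[] args) {
        // Define one numeric and one nominal attribute (names are illustrative only).
        ArrayList<String> labels = new ArrayList<>();
        labels.add("yes");
        labels.add("no");

        ArrayList<Attribute> attributes = new ArrayList<>();
        attributes.add(new Attribute("temperature"));  // numeric attribute
        attributes.add(new Attribute("play", labels)); // nominal attribute

        // Empty dataset with the structure defined above (capacity 0).
        Instances data = new Instances("ExampleDataSet", attributes, 0);

        // Designate the "play" attribute as the class, looking it up by name.
        data.setClass(data.attribute("play"));

        // Same effect here as data.setClassIndex(data.numAttributes() - 1).
        System.out.println("Class attribute: " + data.classAttribute().name());
    }
}

Passing the Attribute object (rather than an index) is convenient when the class attribute is located by name, which is the pattern most of the examples below follow.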

Usage

From source file: lu.lippmann.cdb.datasetview.tasks.UnsupervisedFeatureSelectionTask.java

License: Open Source License

/**
 * {@inheritDoc}
 */
@Override
Instances process0(final Instances dataSet) throws Exception {
    final int k;
    if (this.ratio == -1)
        k = getFeaturesCountFromInput(null, dataSet.numAttributes());
    else
        k = (int) Math.round(this.ratio * dataSet.numAttributes());

    final List<Integer> attrToKeep = WekaMachineLearningUtil.computeUnsupervisedFeaturesSelection(dataSet, k);
    if (!attrToKeep.contains(dataSet.classIndex()))
        attrToKeep.add(dataSet.classIndex());
    final int[] array = ArraysUtil.transform(attrToKeep);

    System.out.println("unsupervised fs -> before=" + dataSet.numAttributes() + " after=" + array.length);

    final Instances newds = WekaDataProcessingUtil.buildFilteredByAttributesDataSet(dataSet, array);
    final Attribute clsAttr = newds.attribute(dataSet.classAttribute().name());
    System.out.println(clsAttr + " " + dataSet.classAttribute().name());
    newds.setClass(clsAttr);
    return newds;
}

From source file: mlflex.learners.WekaLearner.java

License: Open Source License

/** Creates Weka instances from an ML-Flex collection of predictions.
 *
 * @param predictions ML-Flex collection of predictions
 * @return Weka instances
 * @throws Exception
 */
private static Instances GetEvaluationInstances(Predictions predictions) throws Exception {
    FastVector wekaAttributeVector = GetAttributeVector(predictions);

    Instances wekaInstances = new Instances("DataSet", wekaAttributeVector, predictions.Size());
    wekaInstances.setClass((Attribute) wekaAttributeVector.elementAt(1));

    for (Prediction prediction : predictions.GetAll())
        wekaInstances.add(GetInstance(wekaInstances, wekaAttributeVector, prediction));

    return wekaInstances;
}

From source file: mlflex.WekaInMemoryLearner.java

License: Open Source License

/** Creates Weka instances from ML-Flex collections.
 *
 *
 * @param dependentVariableInstances Dependent variable data instances
 * @param attVector Vector of Weka attributes
 * @param instances ML-Flex collection of instances
 * @return Weka instances
 * @throws Exception
 */
public static Instances GetInstances(DataInstanceCollection dependentVariableInstances, FastVector attVector,
        DataInstanceCollection instances) throws Exception {
    Instances wekaInstances = new Instances("DataSet", attVector, instances.Size());

    if (dependentVariableInstances != null)
        wekaInstances.setClass((Attribute) attVector.elementAt(attVector.size() - 1));

    for (DataValues instance : instances)
        wekaInstances.add(GetInstance(wekaInstances, attVector, instance, dependentVariableInstances));

    return wekaInstances;
}

From source file: mulan.transformations.regression.ChainTransformation.java

License: Open Source License

/**
 * Deletes all target attributes that appear after the first numTargetsToKeep targets in the
 * chain. The target attribute at position numTargetsToKeep in the chain is set as the class
 * attribute.
 *
 * @param data the input data set
 * @param chain a chain (permutation) of the indices of the target attributes
 * @param numTargetsToKeep the number of target attributes from the beginning of the chain that
 *            should be kept, 1&lt;=numTargetsToKeep&lt;=numOfTargets
 * @return the transformed Instances object. The input object is not modified.
 * @throws Exception potential exception thrown; to be handled at an upper level.
 */
public static Instances transformInstances(Instances data, int[] chain, int numTargetsToKeep) throws Exception {
    int numOfTargets = chain.length;
    if (numTargetsToKeep < 1 || numTargetsToKeep > numOfTargets) {
        throw new Exception("numTargetsToKeep should be between 1 and numOfTargets");
    }
    // Indices of attributes to remove
    int[] indicesToRemove = new int[numOfTargets - numTargetsToKeep];
    // mark for removal the target attributes whose position in the chain comes
    // after the first numTargetsToKeep positions
    for (int i = 0; i < numOfTargets - numTargetsToKeep; i++) {
        indicesToRemove[i] = chain[numTargetsToKeep + i];
    }

    Remove remove = new Remove();
    remove.setAttributeIndicesArray(indicesToRemove);
    remove.setInputFormat(data);
    // get the class attribute name, i.e. the name of the target attribute placed at the
    // numTargetsToKeep-th position of the chain
    String classAttributeName = data.attribute(chain[numTargetsToKeep - 1]).name();
    Instances transformed = Filter.useFilter(data, remove);
    transformed.setClass(transformed.attribute(classAttributeName));
    return transformed;
}

From source file: old.CFS.java

/**
 * takes a dataset as first argument
 *
 * @param args        the commandline arguments
 * @throws Exception  if something goes wrong
 */
public static void main(String[] args) throws Exception {
    // load data
    System.out.println("\n0. Loading data");
    DataSource source = new DataSource("D:\\ALL\\imdb_grid_size=1000_MIN=50_genres=5.arff");
    Instances data = source.getDataSet();
    data.setClass(data.attribute("Horror"));
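    // Note: setClass(...) already sets the class index, so the fallback check below will not trigger here.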

    if (data.classIndex() == -1)
        data.setClassIndex(data.numAttributes() - 1);

    //    // 1. meta-classifier
    //    useClassifier(data);
    //
    //    // 2. filter
    //    useFilter(data);

    // 3. low-level
    useLowLevel(data);
}

From source file: org.jaqpot.algorithm.resource.WekaPLS.java

License: Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response
                    .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory
                            .badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));
        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (IOException | ClassNotFoundException ex) {
        Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file: org.jaqpot.algorithm.resource.WekaRBF.java

License: Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {

    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response
                    .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory
                            .badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));

        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka MLR prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka RBF prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (IOException | ClassNotFoundException ex) {
        Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file: org.jaqpot.algorithm.resource.WekaSVM.java

License: Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response
                    .status(Response.Status.BAD_REQUEST).entity(ErrorReportFactory
                            .badRequest("Dataset is empty", "Cannot make predictions on empty dataset"))
                    .build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));
        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka SVM prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST).entity(
                        ErrorReportFactory.badRequest("Error while getting predictions.", ex.getMessage()))
                        .build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (Exception ex) {
        Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file: org.jaqpot.algorithms.resource.WekaPLS.java

License: Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {
    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST)
                    .entity("Dataset is empty. Cannot make predictions on empty dataset.").build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));
        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka PLS prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST)
                        .entity("Error while getting predictions. " + ex.getMessage()).build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (IOException | ClassNotFoundException ex) {
        Logger.getLogger(WekaSVM.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file: org.jaqpot.algorithms.resource.WekaRBF.java

License: Open Source License

@POST
@Path("prediction")
public Response prediction(PredictionRequest request) {

    try {
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST)
                    .entity("Dataset is empty. Cannot make predictions on empty dataset.").build();
        }

        String base64Model = (String) request.getRawModel();
        byte[] modelBytes = Base64.getDecoder().decode(base64Model);
        ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
        ObjectInput in = new ObjectInputStream(bais);
        WekaModel model = (WekaModel) in.readObject();

        Classifier classifier = model.getClassifier();
        Instances data = InstanceUtils.createFromDataset(request.getDataset());
        String dependentFeature = (String) request.getAdditionalInfo();
        data.insertAttributeAt(new Attribute(dependentFeature), data.numAttributes());
        data.setClass(data.attribute(dependentFeature));

        List<LinkedHashMap<String, Object>> predictions = new ArrayList<>();
        //            data.stream().forEach(instance -> {
        //                try {
        //                    double prediction = classifier.classifyInstance(instance);
        //                    Map<String, Object> predictionMap = new HashMap<>();
        //                    predictionMap.put("Weka MLR prediction of " + dependentFeature, prediction);
        //                    predictions.add(predictionMap);
        //                } catch (Exception ex) {
        //                    Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        //                }
        //            });

        for (int i = 0; i < data.numInstances(); i++) {
            Instance instance = data.instance(i);
            try {
                double prediction = classifier.classifyInstance(instance);
                LinkedHashMap<String, Object> predictionMap = new LinkedHashMap<>();
                predictionMap.put("Weka RBF prediction of " + dependentFeature, prediction);
                predictions.add(predictionMap);
            } catch (Exception ex) {
                Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
                return Response.status(Response.Status.BAD_REQUEST)
                        .entity("Error while getting predictions. " + ex.getMessage()).build();
            }
        }

        PredictionResponse response = new PredictionResponse();
        response.setPredictions(predictions);
        return Response.ok(response).build();
    } catch (IOException | ClassNotFoundException ex) {
        Logger.getLogger(WekaMLR.class.getName()).log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}