Example usage for weka.classifiers.functions PLSClassifier PLSClassifier

List of usage examples for weka.classifiers.functions PLSClassifier PLSClassifier

Introduction

On this page you can find example usage for weka.classifiers.functions PLSClassifier PLSClassifier.

Prototype

PLSClassifier

Source Link

Usage

From source file:adams.opt.optimise.GeneticAlgorithm.java

License:Open Source License

/**
 * Command-line demo: runs a genetic-algorithm attribute-subset search,
 * scoring candidate subsets with a Gaussian-process regressor (GPD).
 *
 * NOTE(review): 'reg', 'fc' and 'mf' below are configured but never used --
 * 'as.setClassifier(gp)' installs the bare GP, and 'fc.setFilter(pf)'
 * ignores the assembled MultiFilter. They look like leftovers from
 * experimentation; confirm before deleting.
 */
public static void main(String[] args) {
    Environment.setEnvironmentClass(Environment.class);
    GeneticAlgorithm ga = new GeneticAlgorithm();
    ga.setBits(1); // one bit per gene: attribute in/out
    ga.setNumChrom(8); // population size (number of chromosomes)
    ga.setIterations(10000);
    ga.setFavorZeroes(true); // bias toward smaller attribute subsets

    AttributeSelection as = new AttributeSelection();
    //as.setDataset(new PlaceholderFile("/home/dale/blgg/conversion/merged/m_5_.75.arff"));
    ArrayConsumer.setOptions(as, args); // CLI args configure the fitness function
    PLSClassifier pls = new PLSClassifier();
    PLSFilter pf = (PLSFilter) pls.getFilter(); // reuse the classifier's internal PLS filter
    pf.setNumComponents(11);

    LinearRegressionJ reg = new LinearRegressionJ();
    reg.setEliminateColinearAttributes(false);
    reg.setAttributeSelectionMethod(
            new SelectedTag(LinearRegressionJ.SELECTION_NONE, LinearRegressionJ.TAGS_SELECTION));

    // Gaussian-process regressor used as the GA's fitness evaluator.
    GPD gp = new GPD();
    gp.setNoise(.01);
    //RBFKernel rbf = new RBFKernel();
    //rbf.setChecksTurnedOff(true);
    //rbf.setGamma(.01);
    //gp.setKernel(rbf);

    Remove remove = new Remove();
    remove.setAttributeIndices("1"); // drop the first attribute
    FilteredClassifier fc = new FilteredClassifier();

    MultiFilter mf = new MultiFilter();
    Filter[] filters = new Filter[2];
    filters[0] = remove;
    filters[1] = pf;
    mf.setFilters(filters);

    fc.setClassifier(gp);
    fc.setFilter(pf); // NOTE(review): 'mf' (remove + pf) looks intended here -- TODO confirm

    as.setClassifier(gp);
    as.setClassIndex("last");
    //as.setDataset(new PlaceholderFile("/home/dale/OMD_clean.arff"));
    //as.setOutputDirectory(new PlaceholderFile("/research/dale"));
    ga.setLoggingLevel(LoggingLevel.INFO);
    as.setLoggingLevel(LoggingLevel.INFO);
    ga.optimise(as.getDataDef(), as);

}

From source file:jjj.asap.sas.models1.job.BuildPLSModels.java

License:Open Source License

/**
 * Builds a PLS regression model for every dataset in the input bucket and
 * writes the results to the output bucket. Three PLS prototypes are used,
 * chosen per dataset: plain (default), mean-centered (essay set 10 or
 * 1-3 gram datasets) and standardized (essay set 7).
 *
 * @throws Exception if either bucket is missing or a build cannot be started
 */
@Override
protected void run() throws Exception {

    // Fail fast when either bucket is missing.
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // Prototype 1: plain PLS, 5 components, no preprocessing.
    PLSClassifier plainPls = new PLSClassifier();
    PLSFilter plainFilter = (PLSFilter) plainPls.getFilter();
    plainFilter.setNumComponents(5);
    plainFilter.setPreprocessing(NONE);

    // Prototype 2: mean-centered PLS, 5 components.
    PLSClassifier centeredPls = new PLSClassifier();
    PLSFilter centeredFilter = (PLSFilter) centeredPls.getFilter();
    centeredFilter.setNumComponents(5);
    centeredFilter.setPreprocessing(CENTER);

    // Prototype 3: standardized PLS, 10 components.
    PLSClassifier standardizedPls = new PLSClassifier();
    PLSFilter standardizedFilter = (PLSFilter) standardizedPls.getFilter();
    standardizedFilter.setNumComponents(10);
    standardizedFilter.setPreprocessing(STANDARDIZE);

    // Spin up the worker pool and submit one build job per dataset.
    Job.startService();
    final Queue<Future<Object>> pending = new LinkedList<Future<Object>>();

    for (String dsn : Bucket.getBucketItems("datasets", this.inputBucket)) {

        int essaySet = Contest.getEssaySet(dsn);

        // Default to plain PLS; later checks override (the essay-set-7 rule
        // wins over the centered rule when both match).
        Classifier chosen = plainPls;
        if (essaySet == 10 || dsn.contains("1grams-thru-3grams")) {
            chosen = centeredPls;
        }
        if (essaySet == 7) {
            chosen = standardizedPls;
        }

        pending.add(Job.submit(new RegressionModelBuilder(dsn, "PLS",
                AbstractClassifier.makeCopy(chosen), this.outputBucket)));
    }

    // Drain the queue, logging failures without aborting the batch.
    Progress progress = new Progress(pending.size(), this.getClass().getSimpleName());
    while (!pending.isEmpty()) {
        Future<Object> job = pending.remove();
        try {
            job.get();
        } catch (Exception e) {
            Job.log("ERROR", e.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();

}

From source file:jjj.asap.sas.models1.job.RGramModels.java

License:Open Source License

/**
 * Builds a battery of regression models -- M5 linear regression, random
 * subspace over plain linear regression, boosted stumps, boosted trees and
 * PLS -- for every dataset in the input bucket, writing each trained model
 * to the output bucket.
 *
 * @throws Exception if either bucket is missing or a build cannot be started
 */
@Override
protected void run() throws Exception {

    // Fail fast when either bucket is missing.
    if (!Bucket.isBucket("datasets", inputBucket)) {
        throw new FileNotFoundException(inputBucket);
    }
    if (!Bucket.isBucket("models", outputBucket)) {
        throw new FileNotFoundException(outputBucket);
    }

    // Assemble the prototype classifiers, in build order.
    List<Classifier> prototypes = new ArrayList<Classifier>();

    // (An SGD prototype was previously considered here but is not used.)

    LinearRegression m5Regression = new LinearRegression();
    m5Regression.setAttributeSelectionMethod(M5);
    prototypes.add(m5Regression);

    // Plain linear regression, used only as the base of the random subspace.
    LinearRegression plainRegression = new LinearRegression();
    plainRegression.setAttributeSelectionMethod(NONE);

    RandomSubSpace subSpace = new RandomSubSpace();
    subSpace.setClassifier(plainRegression);
    subSpace.setNumIterations(30);
    prototypes.add(subSpace);

    AdditiveRegression stumpBooster = new AdditiveRegression();
    stumpBooster.setClassifier(new DecisionStump());
    stumpBooster.setNumIterations(1000);

    AdditiveRegression treeBooster = new AdditiveRegression();
    treeBooster.setClassifier(new REPTree());
    treeBooster.setNumIterations(100);

    prototypes.add(stumpBooster);
    prototypes.add(treeBooster);

    prototypes.add(new PLSClassifier());

    // Spin up the worker pool and submit one job per (dataset, model) pair.
    Job.startService();
    final Queue<Future<Object>> pending = new LinkedList<Future<Object>>();

    for (String dsn : Bucket.getBucketItems("datasets", this.inputBucket)) {
        for (Classifier prototype : prototypes) {

            // Meta-classifiers are tagged with their base classifier's name too.
            String tag = (prototype instanceof SingleClassifierEnhancer)
                    ? prototype.getClass().getSimpleName() + "-"
                            + ((SingleClassifierEnhancer) prototype).getClassifier().getClass().getSimpleName()
                    : prototype.getClass().getSimpleName();

            pending.add(Job.submit(new RegressionModelBuilder(dsn, tag,
                    AbstractClassifier.makeCopy(prototype), this.outputBucket)));
        }
    }

    // Drain the queue, logging failures without aborting the batch.
    Progress progress = new Progress(pending.size(), this.getClass().getSimpleName());
    while (!pending.isEmpty()) {
        Future<Object> job = pending.remove();
        try {
            job.get();
        } catch (Exception e) {
            Job.log("ERROR", e.toString());
        }
        progress.tick();
    }
    progress.done();
    Job.stopService();

}

From source file:org.jaqpot.algorithm.resource.WekaPLS.java

License:Open Source License

/**
 * Trains a Weka PLS regression model on the posted dataset.
 *
 * Optional parameters: "components" (number of PLS components) and
 * "algorithm" (PLS variant), falling back to the class defaults
 * {@code _components}/{@code _algorithm}. On success the serialized model is
 * returned base64-encoded inside a TrainingResponse.
 *
 * @param request training request carrying the dataset, prediction feature
 *                and algorithm parameters
 * @return 200 with a TrainingResponse; 400 on an empty dataset; 500 on any
 *         training or serialization failure
 */
@POST
@Path("training")
public Response training(TrainingRequest request) {
    try {
        // Reject datasets with no entries, or entries with no values.
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST).entity(
                    ErrorReportFactory.badRequest("Dataset is empty", "Cannot train model on empty dataset"))
                    .build();
        }
        // Feature URIs, taken from the first data entry's value keys.
        List<String> features = request.getDataset().getDataEntry().stream().findFirst().get().getValues()
                .keySet().stream().collect(Collectors.toList());

        Instances data = InstanceUtils.createFromDataset(request.getDataset(), request.getPredictionFeature());
        Map<String, Object> parameters = request.getParameters() != null ? request.getParameters()
                : new HashMap<>();

        Integer components = Integer.parseInt(parameters.getOrDefault("components", _components).toString());
        String algorithm = parameters.getOrDefault("algorithm", _algorithm).toString();

        PLSClassifier classifier = new PLSClassifier();
        classifier.setOptions(new String[] { "-C", components.toString(), "-A", algorithm });
        classifier.buildClassifier(data);

        WekaModel model = new WekaModel();
        model.setClassifier(classifier);

        TrainingResponse response = new TrainingResponse();
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // FIX: close (and thereby flush) the ObjectOutputStream before reading
        // the byte array -- ObjectOutputStream buffers block data internally,
        // so reading baos without a flush can yield a truncated serialization.
        try (ObjectOutputStream out = new ObjectOutputStream(baos)) {
            out.writeObject(model);
        }
        String base64Model = Base64.getEncoder().encodeToString(baos.toByteArray());
        response.setRawModel(base64Model);
        // Independent features = all features except the prediction target.
        List<String> independentFeatures = features.stream()
                .filter(feature -> !feature.equals(request.getPredictionFeature()))
                .collect(Collectors.toList());
        response.setIndependentFeatures(independentFeatures);
        //            response.setPmmlModel(pmml);
        response.setAdditionalInfo(request.getPredictionFeature());
        response.setPredictedFeatures(
                Arrays.asList("Weka PLS prediction of " + request.getPredictionFeature()));

        return Response.ok(response).build();
    } catch (Exception ex) {
        // Boundary catch: log and map any failure to a 500 response.
        LOG.log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file:org.jaqpot.algorithms.resource.WekaPLS.java

License:Open Source License

/**
 * Trains a Weka PLS regression model on the posted dataset.
 *
 * Optional parameters: "components" (number of PLS components) and
 * "algorithm" (PLS variant), falling back to the class defaults
 * {@code _components}/{@code _algorithm}. On success the serialized model is
 * returned base64-encoded inside a TrainingResponse.
 *
 * @param request training request carrying the dataset, prediction feature
 *                and algorithm parameters
 * @return 200 with a TrainingResponse; 400 on an empty dataset; 500 on any
 *         training or serialization failure
 */
@POST
@Path("training")
public Response training(TrainingRequest request) {
    try {
        // Reject datasets with no entries, or entries with no values.
        if (request.getDataset().getDataEntry().isEmpty()
                || request.getDataset().getDataEntry().get(0).getValues().isEmpty()) {
            return Response.status(Response.Status.BAD_REQUEST)
                    .entity("Dataset is empty. Cannot train model on empty dataset.").build();
        }
        // Feature URIs, taken from the first data entry's value keys.
        List<String> features = request.getDataset().getDataEntry().stream().findFirst().get().getValues()
                .keySet().stream().collect(Collectors.toList());

        Instances data = InstanceUtils.createFromDataset(request.getDataset(), request.getPredictionFeature());
        Map<String, Object> parameters = request.getParameters() != null ? request.getParameters()
                : new HashMap<>();

        Integer components = Integer.parseInt(parameters.getOrDefault("components", _components).toString());
        String algorithm = parameters.getOrDefault("algorithm", _algorithm).toString();

        PLSClassifier classifier = new PLSClassifier();
        classifier.setOptions(new String[] { "-C", components.toString(), "-A", algorithm });
        classifier.buildClassifier(data);

        WekaModel model = new WekaModel();
        model.setClassifier(classifier);

        TrainingResponse response = new TrainingResponse();
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // FIX: close (and thereby flush) the ObjectOutputStream before reading
        // the byte array -- ObjectOutputStream buffers block data internally,
        // so reading baos without a flush can yield a truncated serialization.
        try (ObjectOutputStream out = new ObjectOutputStream(baos)) {
            out.writeObject(model);
        }
        String base64Model = Base64.getEncoder().encodeToString(baos.toByteArray());
        response.setRawModel(base64Model);
        // Independent features = all features except the prediction target.
        List<String> independentFeatures = features.stream()
                .filter(feature -> !feature.equals(request.getPredictionFeature()))
                .collect(Collectors.toList());
        response.setIndependentFeatures(independentFeatures);
        //            response.setPmmlModel(pmml);
        response.setAdditionalInfo(request.getPredictionFeature());
        response.setPredictedFeatures(
                Arrays.asList("Weka PLS prediction of " + request.getPredictionFeature()));

        return Response.ok(response).build();
    } catch (Exception ex) {
        // Boundary catch: log and map any failure to a 500 response.
        LOG.log(Level.SEVERE, null, ex);
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(ex.getMessage()).build();
    }
}

From source file:org.opentox.jaqpot3.qsar.trainer.PLSTrainer.java

License:Open Source License

/**
 * Trains a PLS (partial least squares) filter on the given instances and
 * wraps it, together with evaluation statistics and parameter metadata,
 * into a Jaqpot Model.
 *
 * @param data training instances; the attribute named after the target URI
 *             is set as the class attribute
 * @return the populated Model (independent/predicted features, parameters,
 *         actual PLS model, summary statistics)
 * @throws JaqpotException if classifier training/evaluation fails or the
 *                         trained model cannot be serialized
 */
@Override
public Model train(Instances data) throws JaqpotException {
    Model model = new Model(Configuration.getBaseUri().augment("model", getUuid().toString()));

    // The attribute named after the target URI becomes the class attribute.
    data.setClass(data.attribute(targetUri.toString()));

    // Ensure the target feature itself is listed among the independent
    // features (in PLS the target is not excluded -- see note at the end).
    Boolean targetURIIncluded = false;
    for (Feature tempFeature : independentFeatures) {
        if (StringUtils.equals(tempFeature.getUri().toString(), targetUri.toString())) {
            targetURIIncluded = true;
            break;
        }
    }
    if (!targetURIIncluded) {
        independentFeatures.add(new Feature(targetUri));
    }
    model.setIndependentFeatures(independentFeatures);

    /*
     * Train the PLS filter
     */
    PLSFilter pls = new PLSFilter();
    try {
        pls.setInputFormat(data);
        pls.setOptions(new String[] { "-C", Integer.toString(numComponents), "-A", pls_algorithm, "-P",
                preprocessing, "-U", doUpdateClass });
        // NOTE(review): the filtered output is discarded; this call appears to
        // be used only for its side effect of fitting 'pls' -- confirm intended.
        PLSFilter.useFilter(data, pls);
    } catch (Exception ex) {
        // NOTE(review): failure here is logged but NOT rethrown, so training
        // continues with a possibly unfitted filter -- confirm intended.
        Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex);
    }

    PLSModel actualModel = new PLSModel(pls);
    try {

        // Build a PLS classifier around the fitted filter and evaluate it on
        // the training data to produce summary statistics.
        PLSClassifier cls = new PLSClassifier();
        cls.setFilter(pls);
        cls.buildClassifier(data);

        // evaluate classifier and collect summary statistics
        Evaluation eval = new Evaluation(data);
        eval.evaluateModel(cls, data);
        String stats = eval.toSummaryString("", false);

        ActualModel am = new ActualModel(actualModel);
        am.setStatistics(stats);

        model.setActualModel(am);
    } catch (NotSerializableException ex) {
        // Serialization of the trained model failed.
        Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex);
        throw new JaqpotException(ex);
    } catch (Exception ex) {
        // Any other training/evaluation failure.
        Logger.getLogger(PLSTrainer.class.getName()).log(Level.SEVERE, null, ex);
        throw new JaqpotException(ex);
    }

    model.setDataset(datasetUri);
    model.setAlgorithm(Algorithms.plsFilter());
    model.getMeta().addTitle("PLS Model for " + datasetUri);

    // Parameter metadata describing how the model was trained; random longs
    // make each parameter URI unique.
    Set<Parameter> parameters = new HashSet<Parameter>();
    Parameter targetPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "target", new LiteralValue(targetUri.toString(), XSDDatatype.XSDstring))
                    .setScope(Parameter.ParameterScope.MANDATORY);
    Parameter nComponentsPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "numComponents", new LiteralValue(numComponents, XSDDatatype.XSDpositiveInteger))
                    .setScope(Parameter.ParameterScope.MANDATORY);
    Parameter preprocessingPrm = new Parameter(
            Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()), "preprocessing",
            new LiteralValue(preprocessing, XSDDatatype.XSDstring)).setScope(Parameter.ParameterScope.OPTIONAL);
    Parameter algorithmPrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "algorithm", new LiteralValue(pls_algorithm, XSDDatatype.XSDstring))
                    .setScope(Parameter.ParameterScope.OPTIONAL);
    Parameter doUpdatePrm = new Parameter(Configuration.getBaseUri().augment("parameter", RANDOM.nextLong()),
            "doUpdateClass", new LiteralValue(doUpdateClass, XSDDatatype.XSDboolean))
                    .setScope(Parameter.ParameterScope.OPTIONAL);

    parameters.add(targetPrm);
    parameters.add(nComponentsPrm);
    parameters.add(preprocessingPrm);
    parameters.add(doUpdatePrm);
    parameters.add(algorithmPrm);
    model.setParameters(parameters);

    // One predicted feature per PLS component: PLS-0 .. PLS-(numComponents-1).
    for (int i = 0; i < numComponents; i++) {
        Feature f = publishFeature(model, "", "PLS-" + i, datasetUri, featureService);
        model.addPredictedFeatures(f);
    }

    //save the instances being predicted to abstract trainer for calculating DoA
    predictedInstances = data;
    //in pls target is not excluded

    return model;
}