com.insightml.evaluation.simulation.CrossValidation.java Source code

Java tutorial

Introduction

Here is the source code for com.insightml.evaluation.simulation.CrossValidation.java

Source

/*
 * Copyright (C) 2016 Stefan Hen
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.insightml.evaluation.simulation;

import java.util.LinkedList;
import java.util.List;

import org.apache.commons.math3.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.insightml.data.FeaturesConfig;
import com.insightml.data.samples.Sample;
import com.insightml.data.utils.InstancesFilter;
import com.insightml.models.ILearnerPipeline;
import com.insightml.models.ModelPipeline;
import com.insightml.models.Predictions;
import com.insightml.utils.Check;
import com.insightml.utils.jobs.AbstractJob;
import com.insightml.utils.jobs.IJobBatch;

public final class CrossValidation<I extends Sample> extends AbstractSimulation<I> {
    private static final long serialVersionUID = 2206245632537034091L;

    private final int folds;
    private final int repetitions;

    public CrossValidation(final int folds, final int repetitions, final SimulationResultConsumer database) {
        super((repetitions == 1 ? "" : repetitions + "x ") + folds + "-Fold Cross-Validation", database);
        this.folds = Check.num(folds, 2, 20);
        this.repetitions = Check.num(repetitions, 1, 10);
    }

    @Override
    public String getDescription() {
        return (repetitions > 1 ? repetitions + " repetitions of " : "") + folds + "-fold cross-validation";
    }

    @Override
    public <E, P> SimulationResults<E, P>[] run(final Iterable<I> instances, final SimulationSetup<I, E, P> setup) {
        SimulationResults<E, P>[] lastResult = null;
        for (final int foldss : folds == -1 ? new int[] { 2, 5, 10 } : new int[] { folds }) {
            lastResult = runCv(instances, setup);
            for (final SimulationResults<E, P> element : lastResult) {
                notify(element, setup, foldss * repetitions + "");
            }
        }
        return lastResult;
    }

    private <E, P> SimulationResults<E, P>[] runCv(final Iterable<I> instances,
            final SimulationSetup<I, E, P> setup) {
        final ILearnerPipeline<I, P>[] learner = setup.getLearner();
        final IJobBatch<Predictions<E, P>[]> batch = ((SimulationSetupImpl) setup).createBatch();
        final int numLabels = Check.num(instances.iterator().next().getExpected().length, 1, 10);
        final Integer labelIndex = ((SimulationSetupImpl) setup).getLabelIndex();
        final FeaturesConfig<I, P> config = setup.getConfig();
        for (int repetition = 0; repetition < repetitions; ++repetition) {
            // repetition == 0 ? instances : instances.randomize(random);
            final Iterable<I> shuffled = instances;
            for (int fold = 1; fold <= folds; ++fold) {
                final int actualFold = fold + repetition * folds;
                if (labelIndex == null) {
                    for (int index = 0; index < numLabels; ++index) {
                        batch.addJob(new Fold<I, E, P>(config, shuffled, learner, fold, actualFold, folds, index));
                    }
                } else {
                    batch.addJob(new Fold<I, E, P>(config, shuffled, learner, fold, actualFold, folds, labelIndex));
                }
            }
        }
        final SimulationResultsBuilder<E, P>[] builders = new SimulationResultsBuilder[learner.length];
        for (int i = 0; i < learner.length; ++i) {
            builders[i] = new SimulationResultsBuilder<>(learner[i].getName(), folds * repetitions, numLabels,
                    setup);
        }
        for (final Predictions<E, P>[] preds : batch.run()) {
            for (int i = 0; i < preds.length; ++i) {
                builders[i].add(preds[i]);
            }
        }
        final SimulationResults<E, P>[] results = new SimulationResults[learner.length];
        for (int i = 0; i < results.length; ++i) {
            results[i] = builders[i].build();
        }
        return results;
    }

    private static final class Fold<I extends Sample, E, P> extends AbstractJob<Predictions<E, P>[]> {
        private static final long serialVersionUID = 8592592353685668153L;

        private final int fold;
        private final int actualFold;
        private final int folds;
        private final int label;

        private final FeaturesConfig<I, P> config;
        private final Iterable<I> labled;
        private final ILearnerPipeline<I, P>[] learner;
        private final Logger logger = LoggerFactory.getLogger(Fold.class);

        Fold(final FeaturesConfig<I, P> config, final Iterable<I> shuffled, final ILearnerPipeline<I, P>[] loader,
                final int fold, final int actualFold, final int folds, final int label) {
            super("CrossValidation Fold #" + actualFold);
            this.fold = fold;
            this.actualFold = actualFold;
            this.folds = folds;
            this.label = label;

            this.config = config;
            labled = InstancesFilter.hasLabelSet(shuffled, label);
            learner = loader;
        }

        @Override
        public Predictions<E, P>[] run() {
            final Pair<List<I>, List<I>> sets = partition();
            final Predictions<E, P>[] preds = new Predictions[learner.length];
            for (int i = 0; i < preds.length; ++i) {
                final long start = System.currentTimeMillis();
                final ModelPipeline<I, P> model = learner[i].run(sets.getFirst(), sets.getSecond(), config, label);
                preds[i] = Predictions.create(actualFold, model, sets.getSecond(),
                        (int) (System.currentTimeMillis() - start));
                logger.info("Completed {} on {} train and {} test samples", getTitle(), sets.getFirst().size(),
                        sets.getSecond().size());
                if (false) {
                    logger.info(model.info());
                }
            }
            return preds;
        }

        private Pair<List<I>, List<I>> partition() {
            Check.num(fold, 1, folds);
            final List<I> train = new LinkedList<>();
            final List<I> test = new LinkedList<>();
            int i = -1;
            for (final I lab : labled) {
                final int bucket = ++i % folds;
                if (bucket == fold - 1) {
                    test.add(lab);
                } else {
                    train.add(lab);
                }
            }
            return new Pair<>(train, test);
        }
    }

}