org.apache.mahout.classifier.sgd.CrossFoldLearner.java Source code

Java tutorial

Introduction

Here is the source code for the class org.apache.mahout.classifier.sgd.CrossFoldLearner (file CrossFoldLearner.java).

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Does cross-fold validation of log-likelihood and AUC on several online logistic regression
 * models. Each record is passed to all but one of the models for training and to the remaining
 * model for evaluation; the evaluating model is the one whose index equals the record's tracking
 * key modulo the number of folds. In order to maintain proper segregation between the different
 * folds across training data iterations, data should either be passed to this learner in the same
 * order each time the training data is traversed (the two-argument {@code train} uses an internal
 * record counter as the tracking key) or a tracking key such as the file offset of the training
 * record should be passed with each training example.
 */
public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable {
    // minimum score to be used for computing log likelihood (keeps log() finite)
    private static final double MIN_SCORE = 1.0e-50;
    // length of the parameters array; write/readFields serialize exactly this many doubles
    private static final int NUM_PARAMETERS = 4;

    // number of records passed to train(); also the implicit tracking key of train(int, Vector)
    private int record;
    private OnlineAuc auc = new GlobalOnlineAuc();
    private double logLikelihood;
    private final List<OnlineLogisticRegression> models = new ArrayList<>();

    // lambda, learningRate, perTermOffset, perTermExponent
    private double[] parameters = new double[NUM_PARAMETERS];
    private int numFeatures;
    private PriorFunction prior;
    private double percentCorrect;

    private int windowSize = Integer.MAX_VALUE;

    /**
     * Creates an empty learner with no fold models. Models can be supplied afterwards through
     * {@link #readFields(DataInput)} or {@link #addModel(OnlineLogisticRegression)}.
     */
    public CrossFoldLearner() {
    }

    /**
     * Creates a learner with {@code folds} fold models, each initialised with
     * {@code alpha(1)}, {@code stepOffset(0)} and {@code decayExponent(0)}.
     *
     * @param folds         number of fold models; each record is held out by exactly one of them
     * @param numCategories number of target categories passed to each fold model
     * @param numFeatures   number of features passed to each fold model
     * @param prior         prior function shared by all fold models
     */
    public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
        this.numFeatures = numFeatures;
        this.prior = prior;
        for (int i = 0; i < folds; i++) {
            OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
            model.alpha(1).stepOffset(0).decayExponent(0);
            models.add(model);
        }
    }

    // -------- builder-like configuration methods

    /** Forwards {@code lambda} to every fold model. @return this learner, for chaining */
    public CrossFoldLearner lambda(double v) {
        for (OnlineLogisticRegression model : models) {
            model.lambda(v);
        }
        return this;
    }

    /** Forwards the learning rate to every fold model. @return this learner, for chaining */
    public CrossFoldLearner learningRate(double x) {
        for (OnlineLogisticRegression model : models) {
            model.learningRate(x);
        }
        return this;
    }

    /** Forwards the step offset to every fold model. @return this learner, for chaining */
    public CrossFoldLearner stepOffset(int x) {
        for (OnlineLogisticRegression model : models) {
            model.stepOffset(x);
        }
        return this;
    }

    /** Forwards the decay exponent to every fold model. @return this learner, for chaining */
    public CrossFoldLearner decayExponent(double x) {
        for (OnlineLogisticRegression model : models) {
            model.decayExponent(x);
        }
        return this;
    }

    /** Forwards {@code alpha} to every fold model. @return this learner, for chaining */
    public CrossFoldLearner alpha(double alpha) {
        for (OnlineLogisticRegression model : models) {
            model.alpha(alpha);
        }
        return this;
    }

    // -------- training methods

    /**
     * Trains on one record, using the internal record counter as the tracking key. Fold
     * assignment is therefore only stable if records are presented in the same order each pass.
     */
    @Override
    public void train(int actual, Vector instance) {
        train(record, null, actual, instance);
    }

    /** Trains on one record with an explicit tracking key and no group key. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /**
     * Trains on one record. The fold model at index {@code trackingKey mod folds} evaluates the
     * record and updates the log-likelihood, percent-correct and (for 2 categories) AUC
     * statistics; every other fold model is trained on it. With no models, only the record
     * counter advances.
     *
     * @param trackingKey key that determines which fold holds this record out
     * @param groupKey    group key forwarded to the models and the AUC evaluator; may be null
     * @param actual      index of the true category
     * @param instance    feature vector
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        record++;
        int folds = models.size();
        if (folds == 0) {
            // Nothing to train or evaluate; also avoids the modulo-by-zero in mod().
            return;
        }
        // Loop-invariant: the single fold that holds this record out for evaluation.
        long heldOutFold = mod(trackingKey, folds);
        int k = 0;
        for (OnlineLogisticRegression model : models) {
            if (k == heldOutFold) {
                evaluate(model, groupKey, actual, instance);
            } else {
                model.train(trackingKey, groupKey, actual, instance);
            }
            k++;
        }
    }

    /** Scores one held-out record with {@code model} and folds the result into the statistics. */
    private void evaluate(OnlineLogisticRegression model, String groupKey, int actual, Vector instance) {
        Vector v = model.classifyFull(instance);
        double score = Math.max(v.get(actual), MIN_SCORE);
        // Running mean over the first windowSize records, then an exponential moving average
        // with weight 1 / windowSize. windowSize is kept positive by setWindowSize.
        int divisor = Math.min(record, windowSize);
        logLikelihood += (Math.log(score) - logLikelihood) / divisor;

        int correct = v.maxValueIndex() == actual ? 1 : 0;
        percentCorrect += (correct - percentCorrect) / divisor;
        if (numCategories() == 2) {
            auc.addSample(actual, groupKey, v.get(1));
        }
    }

    /** Returns {@code x mod y} in the range {@code [0, y)}, even for negative {@code x}. */
    private static long mod(long x, int y) {
        long r = x % y;
        return r < 0 ? r + y : r;
    }

    /** Closes every fold model. */
    @Override
    public void close() {
        for (OnlineLogisticRegression m : models) {
            m.close();
        }
    }

    /**
     * Resets the record counter to zero. The statistics themselves are not cleared, but because
     * the averaging divisor restarts at 1, the next evaluated record replaces
     * {@link #logLikelihood()} and {@link #percentCorrect()}.
     */
    public void resetLineCounter() {
        record = 0;
    }

    /** Returns true iff every fold model reports {@code validModel()}. */
    public boolean validModel() {
        boolean r = true;
        for (OnlineLogisticRegression model : models) {
            r &= model.validModel();
        }
        return r;
    }

    // -------- classification methods

    /** Returns the element-wise mean of the fold models' {@code classify} outputs. */
    @Override
    public Vector classify(Vector instance) {
        Vector r = new DenseVector(numCategories() - 1);
        DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
        for (OnlineLogisticRegression model : models) {
            r.assign(model.classify(instance), scale);
        }
        return r;
    }

    /** Returns the element-wise mean of the fold models' {@code classifyNoLink} outputs. */
    @Override
    public Vector classifyNoLink(Vector instance) {
        Vector r = new DenseVector(numCategories() - 1);
        DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
        for (OnlineLogisticRegression model : models) {
            r.assign(model.classifyNoLink(instance), scale);
        }
        return r;
    }

    /** Returns the mean of the fold models' {@code classifyScalar} outputs (NaN with no models). */
    @Override
    public double classifyScalar(Vector instance) {
        double r = 0;
        int n = 0;
        for (OnlineLogisticRegression model : models) {
            n++;
            r += model.classifyScalar(instance);
        }
        return r / n;
    }

    // -------- status reporting methods

    /**
     * Returns the category count of the first fold model.
     *
     * @throws IndexOutOfBoundsException if this learner has no models
     */
    @Override
    public int numCategories() {
        return models.get(0).numCategories();
    }

    /** Returns the current AUC estimate; samples are only added when there are 2 categories. */
    public double auc() {
        return auc.auc();
    }

    /** Returns the windowed average log-likelihood of held-out records. */
    public double logLikelihood() {
        return logLikelihood;
    }

    /** Returns the windowed fraction of held-out records whose top-scoring category was correct. */
    public double percentCorrect() {
        return percentCorrect;
    }

    // -------- evolutionary optimization

    /**
     * Returns a new learner whose fold models are copies of this learner's models. Each source
     * model is {@code close()}d before it is copied. Only the models, numFeatures and prior carry
     * over; the record counter, AUC evaluator, statistics, parameters and window size of the
     * copy start from their defaults.
     */
    public CrossFoldLearner copy() {
        CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
        r.models.clear();
        for (OnlineLogisticRegression model : models) {
            model.close();
            OnlineLogisticRegression newModel = new OnlineLogisticRegression(model.numCategories(),
                    model.numFeatures(), model.prior);
            newModel.copyFrom(model);
            r.models.add(newModel);
        }
        return r;
    }

    public int getRecord() {
        return record;
    }

    public void setRecord(int record) {
        this.record = record;
    }

    public OnlineAuc getAucEvaluator() {
        return auc;
    }

    /** Replaces the AUC evaluator. The current window size is not applied to it. */
    public void setAucEvaluator(OnlineAuc auc) {
        this.auc = auc;
    }

    public double getLogLikelihood() {
        return logLikelihood;
    }

    public void setLogLikelihood(double logLikelihood) {
        this.logLikelihood = logLikelihood;
    }

    /** Returns the live, mutable list of fold models (not a copy). */
    public List<OnlineLogisticRegression> getModels() {
        return models;
    }

    /** Appends a fold model. */
    public void addModel(OnlineLogisticRegression model) {
        models.add(model);
    }

    /** Returns the live parameters array (lambda, learningRate, perTermOffset, perTermExponent). */
    public double[] getParameters() {
        return parameters;
    }

    /**
     * Stores the parameters array by reference. It must hold exactly 4 values
     * (lambda, learningRate, perTermOffset, perTermExponent) for {@link #write(DataOutput)} to succeed.
     */
    public void setParameters(double[] parameters) {
        this.parameters = parameters;
    }

    public int getNumFeatures() {
        return numFeatures;
    }

    public void setNumFeatures(int numFeatures) {
        this.numFeatures = numFeatures;
    }

    /**
     * Sets the averaging window for the log-likelihood and percent-correct statistics and
     * forwards it to the current AUC evaluator.
     *
     * @param windowSize number of records in the averaging window; must be positive
     * @throws IllegalArgumentException if {@code windowSize <= 0} (it is used as a divisor)
     */
    public void setWindowSize(int windowSize) {
        if (windowSize <= 0) {
            throw new IllegalArgumentException("windowSize must be positive: " + windowSize);
        }
        this.windowSize = windowSize;
        auc.setWindowSize(windowSize);
    }

    public PriorFunction getPrior() {
        return prior;
    }

    public void setPrior(PriorFunction prior) {
        this.prior = prior;
    }

    /**
     * Serializes, in order: record, AUC evaluator, log-likelihood, model count and models,
     * the 4 parameters, numFeatures, prior, percent-correct and window size.
     *
     * @throws IllegalStateException if the parameters array is null or does not hold exactly 4
     *                               values, since {@link #readFields(DataInput)} reads exactly 4
     */
    @Override
    public void write(DataOutput out) throws IOException {
        // Validate before writing anything so a bad state never yields a half-written record.
        if (parameters == null || parameters.length != NUM_PARAMETERS) {
            throw new IllegalStateException("Expected " + NUM_PARAMETERS + " parameters but found "
                    + (parameters == null ? "null" : String.valueOf(parameters.length)));
        }
        out.writeInt(record);
        PolymorphicWritable.write(out, auc);
        out.writeDouble(logLikelihood);
        out.writeInt(models.size());
        for (OnlineLogisticRegression model : models) {
            model.write(out);
        }

        for (double x : parameters) {
            out.writeDouble(x);
        }
        out.writeInt(numFeatures);
        PolymorphicWritable.write(out, prior);
        out.writeDouble(percentCorrect);
        out.writeInt(windowSize);
    }

    /**
     * Restores the state written by {@link #write(DataOutput)}, replacing any models this
     * instance already held.
     *
     * @throws IOException if reading fails or the stream declares a negative model count
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        record = in.readInt();
        auc = PolymorphicWritable.read(in, OnlineAuc.class);
        logLikelihood = in.readDouble();
        int n = in.readInt();
        if (n < 0) {
            throw new IOException("Invalid model count in serialized CrossFoldLearner: " + n);
        }
        // This instance may already hold models (folds constructor, or a reused Writable);
        // start from an empty list so the result reflects only the stream.
        models.clear();
        for (int i = 0; i < n; i++) {
            OnlineLogisticRegression olr = new OnlineLogisticRegression();
            olr.readFields(in);
            models.add(olr);
        }
        parameters = new double[NUM_PARAMETERS];
        for (int i = 0; i < NUM_PARAMETERS; i++) {
            parameters[i] = in.readDouble();
        }
        numFeatures = in.readInt();
        prior = PolymorphicWritable.read(in, PriorFunction.class);
        percentCorrect = in.readDouble();
        windowSize = in.readInt();
    }
}