com.mapr.stats.bandit.ContextualBayesBandit.java Source code

Java tutorial

Introduction

Here is the source code for com.mapr.stats.bandit.ContextualBayesBandit.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mapr.stats.bandit;

import com.mapr.stats.random.BetaDistribution;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.VectorFunction;

/**
 * Solves the contextual bandit problem using Bayesian sampling.
 * <p>
 * The bandit keeps an {@code m x 2} state matrix, where {@code m} is the number of
 * columns of the feature map. Each state row holds the two values that are handed to
 * {@link BetaDistribution#nextDouble(double, double)} when sampling (presumably the
 * alpha and beta shape parameters of a Beta posterior for that feature). A sample
 * draws one value per feature, maps it through the logit function and multiplies the
 * feature map by the resulting vector to obtain one score per feature-map row.
 * <p>
 * The feature map is stored by reference, not copied. No synchronization is performed
 * by this class.
 */
public class ContextualBayesBandit {
    /**
     * Smallest probability handed to the logit. Every double strictly between 0 and 1
     * already lies inside [MIN_PROBABILITY, MAX_PROBABILITY], so clamping only alters
     * draws that are exactly 0.0 or 1.0 (or outside the unit interval).
     */
    private static final double MIN_PROBABILITY = Double.MIN_VALUE;

    /** Largest probability handed to the logit: the greatest double below 1.0. */
    private static final double MAX_PROBABILITY = Math.nextAfter(1.0, 0.0);

    /** One row per bandit, one column per feature. */
    private final Matrix featureMap;

    /** One row per feature; column 0 accumulates successes, column 1 failures. */
    private final Matrix state;

    /** Number of features, i.e. number of columns of {@link #featureMap}. */
    private final int m;

    private final BetaDistribution rand;

    /** Reused for every sample so that no function objects are allocated per draw. */
    private final VectorFunction thetaSampler;

    /**
     * Creates a bandit whose two state columns both start at 1.
     *
     * @param featureMap matrix with one row per bandit and one column per feature
     */
    public ContextualBayesBandit(Matrix featureMap) {
        this(featureMap, 1, 1);
    }

    /**
     * Creates a bandit with explicit starting values for the two state columns.
     *
     * @param featureMap matrix with one row per bandit and one column per feature;
     *                   kept by reference
     * @param alpha_0    initial value of every entry in state column 0
     * @param beta_0     initial value of every entry in state column 1
     */
    public ContextualBayesBandit(Matrix featureMap, double alpha_0, double beta_0) {
        this.featureMap = featureMap;
        m = featureMap.numCols();
        this.state = new DenseMatrix(m, 2);
        this.state.viewColumn(0).assign(alpha_0);
        this.state.viewColumn(1).assign(beta_0);
        this.rand = new BetaDistribution(1, 1);
        this.thetaSampler = new ClampedLogitSampler();
    }

    /**
     * Draws a sample and maps every per-bandit score through the logistic function.
     *
     * @return a vector with one entry per feature-map row, each the logistic of that
     *         row's sampled score
     */
    public Vector samplePi() {
        return sampleNoLink().assign(new LogisticFunction());
    }

    /**
     * Draws a sample and picks the bandit with the largest sampled score.
     *
     * @return the index reported by {@code maxValueIndex()} on the sampled scores
     */
    public int sample() {
        final Vector pi = sampleNoLink();
        return pi.maxValueIndex();
    }

    /**
     * Samples one logit-scale value per feature and projects it through the feature map.
     *
     * @return {@code featureMap.times(theta)}, one score per feature-map row
     */
    private Vector sampleNoLink() {
        final Vector theta = state.aggregateRows(thetaSampler);
        return featureMap.times(theta);
    }

    /**
     * Records the outcome of playing a bandit.
     * <p>
     * The bandit's feature row, scaled by {@code 1 / m}, is added to state column 0 on
     * success and to state column 1 otherwise.
     *
     * @param bandit  row index of the played bandit in the feature map
     * @param success whether the play was a success
     */
    public void train(int bandit, boolean success) {
        state.viewColumn(success ? 0 : 1).assign(featureMap.viewRow(bandit), Functions.plusMult(1.0 / m));
    }

    /**
     * Draws a value for one state row and converts it to the logit scale.
     * <p>
     * The draw is clamped away from exactly 0 and 1 first: the logit of those values is
     * infinite, and an infinite theta entry multiplied by a zero feature-map entry
     * yields NaN in the matrix-vector product. A NaN draw passes through unchanged.
     */
    private final class ClampedLogitSampler implements VectorFunction {
        private final DoubleFunction inverseLink = new InverseLogisticFunction();

        @Override
        public double apply(Vector f) {
            final double p = rand.nextDouble(f.get(0), f.get(1));
            final double clamped = Math.min(MAX_PROBABILITY, Math.max(MIN_PROBABILITY, p));
            return inverseLink.apply(clamped);
        }
    }

    /**
     * The logistic function {@code 1 / (1 + exp(-x))}.
     * <p>
     * Kept as a non-static inner class so that existing code which instantiates it
     * through an enclosing instance continues to compile.
     */
    public class LogisticFunction implements DoubleFunction {
        @Override
        public double apply(double x) {
            return 1 / (1 + Math.exp(-x));
        }
    }

    /**
     * The logit function {@code log(p / (1 - p))}, the inverse of
     * {@link LogisticFunction}. Returns negative infinity for {@code p == 0} and
     * positive infinity for {@code p == 1}.
     * <p>
     * Kept as a non-static inner class so that existing code which instantiates it
     * through an enclosing instance continues to compile.
     */
    public class InverseLogisticFunction implements DoubleFunction {
        @Override
        public double apply(double p) {
            return Math.log(p / (1 - p));
        }
    }
}