edu.oregonstate.eecs.mcplan.rl.QLearner.java Source code

Introduction

Here is the source code for edu.oregonstate.eecs.mcplan.rl.QLearner.java, a basic tabular Q-learning agent that uses a constant learning rate and epsilon-greedy exploration. Its main() method trains the learner on the Taxi domain of Dietterich (2000) until the Q-values have essentially stopped changing.
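
Before the full listing, here is a minimal usage sketch distilled from the main() method at the bottom of the class; it assumes the same imports as the listing. It builds a QLearner for the Taxi domain and runs one training episode. The constants (gamma = 0.9, Vmax = 20.0, epsilon = 0.1, alpha = 0.1, horizon T = 100000) are the values used in main(); everything else shown is illustrative rather than prescriptive.

final RandomGenerator rng = new MersenneTwister(43);
final int Nother_taxis = 0;
final double slip = 0.0;
final int T = 100000;

final TaxiState prototype = TaxiWorlds.dietterich2000(rng, Nother_taxis, slip);
final QLearner<TaxiState, PrimitiveTaxiRepresentation, TaxiAction> learner =
        new QLearner<TaxiState, PrimitiveTaxiRepresentation, TaxiAction>(
                new int[] { 0 }, rng, new PrimitiveTaxiRepresenter(prototype),
                new TaxiActionGenerator(), 0.9, 20.0, 0.1, 0.1);

// One training episode, driven by the Episode/JointPolicy machinery as in main().
final TaxiState state = TaxiWorlds.dietterich2000(rng, Nother_taxis, slip);
final TaxiSimulator sim = new TaxiSimulator(rng, state, slip, T);
final Episode<TaxiState, TaxiAction> episode =
        new Episode<TaxiState, TaxiAction>(sim, JointPolicy.create(learner), T);
episode.run();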

Source

/* LICENSE
Copyright (c) 2013-2016, Jesse Hostetler (jessehostetler@gmail.com)
All rights reserved.
    
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
    
1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.
    
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/**
 * 
 */
package edu.oregonstate.eecs.mcplan.rl;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

import edu.oregonstate.eecs.mcplan.ActionGenerator;
import edu.oregonstate.eecs.mcplan.JointPolicy;
import edu.oregonstate.eecs.mcplan.Policy;
import edu.oregonstate.eecs.mcplan.Representation;
import edu.oregonstate.eecs.mcplan.Representer;
import edu.oregonstate.eecs.mcplan.State;
import edu.oregonstate.eecs.mcplan.VirtualConstructor;
import edu.oregonstate.eecs.mcplan.domains.taxi.PrimitiveTaxiRepresentation;
import edu.oregonstate.eecs.mcplan.domains.taxi.PrimitiveTaxiRepresenter;
import edu.oregonstate.eecs.mcplan.domains.taxi.TaxiAction;
import edu.oregonstate.eecs.mcplan.domains.taxi.TaxiActionGenerator;
import edu.oregonstate.eecs.mcplan.domains.taxi.TaxiSimulator;
import edu.oregonstate.eecs.mcplan.domains.taxi.TaxiState;
import edu.oregonstate.eecs.mcplan.domains.taxi.TaxiWorlds;
import edu.oregonstate.eecs.mcplan.sim.AverageRewardAccumulator;
import edu.oregonstate.eecs.mcplan.sim.Episode;
import edu.oregonstate.eecs.mcplan.util.Fn;
import gnu.trove.iterator.TObjectDoubleIterator;
import gnu.trove.map.TObjectDoubleMap;
import gnu.trove.map.hash.TObjectDoubleHashMap;

/**
 * A basic tabular Q-learning implementation using a constant learning rate
 * and epsilon-greedy exploration. The Q-values of newly encountered states
 * are initialized optimistically to Vmax.
 * 
 * TODO: Generalize
 * 
 * @author jhostetler
 */
public class QLearner<S extends State, X extends Representation<S>, A extends VirtualConstructor<A>>
        extends Policy<S, A> {
    public final Map<X, TObjectDoubleMap<A>> values = new HashMap<X, TObjectDoubleMap<A>>();

    private final RandomGenerator rng_;
    private final ActionGenerator<S, A> action_gen_;

    public final Representer<S, X> repr;
    public final double gamma;
    public final double Vmax;
    public final double epsilon;
    public final double alpha;

    private S s_ = null;
    private X x_ = null;
    private A a_ = null;
    private TObjectDoubleMap<A> Qfunction_ = null;

    private long t_ = 0L;
    private final int[] turn_;

    public QLearner(final int[] turn, final RandomGenerator rng, final Representer<S, X> repr,
            final ActionGenerator<S, A> action_gen, final double gamma, final double Vmax, final double epsilon,
            final double alpha) {
        turn_ = turn;
        rng_ = rng;
        action_gen_ = action_gen;

        this.repr = repr;
        this.gamma = gamma;
        this.Vmax = Vmax;
        this.epsilon = epsilon;
        this.alpha = alpha;
    }

    // Lazily create the Q-function for representation 'x', initializing every
    // applicable action optimistically to Vmax.
    private TObjectDoubleMap<A> getQFunction(final S s, final X x, final long t, final int[] turn) {
        TObjectDoubleMap<A> Qfunction = values.get(x);
        if (Qfunction == null) {
            Qfunction = new TObjectDoubleHashMap<A>();
            action_gen_.setState(s, t);
            while (action_gen_.hasNext()) {
                Qfunction.put(action_gen_.next(), Vmax);
            }
            values.put(x, Qfunction);
        }
        return Qfunction;
    }

    // Return the greedy action with respect to Qfunction.
    private A maxA(final TObjectDoubleMap<A> Qfunction) {
        double best = -Double.MAX_VALUE;
        A best_a = null;
        final TObjectDoubleIterator<A> itr = Qfunction.iterator();
        while (itr.hasNext()) {
            itr.advance();
            final double d = itr.value();
            if (d > best) {
                best = d;
                best_a = itr.key();
            }
        }
        return best_a;
    }

    // Return the maximum Q-value in Qfunction.
    private double maxQ(final TObjectDoubleMap<A> Qfunction) {
        double best = -Double.MAX_VALUE;
        final TObjectDoubleIterator<A> itr = Qfunction.iterator();
        while (itr.hasNext()) {
            itr.advance();
            final double d = itr.value();
            if (d > best) {
                best = d;
            }
        }
        return best;
    }

    // Constant learning rate; the (s, a) parameters are currently unused.
    private double learningRate(final S s, final A a) {
        return alpha;
    }

    @Override
    public void setState(final S s, final long t) {
        s_ = s;
        t_ = t;
        x_ = repr.encode(s);
        Qfunction_ = getQFunction(s_, x_, t_, turn_);
    }

    @Override
    public A getAction() {
        final double r = rng_.nextDouble();

        // Epsilon-greedy: explore uniformly at random with probability epsilon,
        // otherwise act greedily with respect to the current Q-function.
        if (r < epsilon) {
            action_gen_.setState(s_, t_);
            a_ = Fn.uniform_choice(rng_, action_gen_);
        } else {
            a_ = maxA(Qfunction_).create();
        }

        return a_;
    }

    @Override
    public void actionResult(final S sprime, final double[] r) {
        final X xprime = repr.encode(sprime);
        // FIXME: We can't know 'tprime' in general since we don't know how
        // many other agents there are. As of [2014/06/23], we have no
        // nonstationary domains, so it doesn't matter.
        final TObjectDoubleMap<A> Qfunction_prime = getQFunction(sprime, xprime, 0xDEADBEEF, turn_);

        // Standard Q-learning update:
        //   Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
        final double Qa = Qfunction_.get(a_);
        final double max_q = maxQ(Qfunction_prime);
        final double err = r[0] + gamma * max_q - Qa;
        Qfunction_.put(a_, Qa + learningRate(s_, a_) * err);
    }

    @Override
    public String getName() {
        return "QLearning";
    }

    @Override
    public int hashCode() {
        return QLearner.class.hashCode();
    }

    @Override
    public boolean equals(final Object obj) {
        return obj instanceof QLearner;
    }

    // -----------------------------------------------------------------------

    public static void main(final String[] argv) {
        final RandomGenerator rng = new MersenneTwister(43);

        final int Nother_taxis = 0;
        final double slip = 0.0;
        final TaxiState state_prototype = TaxiWorlds.dietterich2000(rng, Nother_taxis, slip);

        final int T = 100000;

        final double gamma = 0.9;
        final double Vmax = 20.0;
        final double epsilon = 0.1;
        final double alpha = 0.1;
        final QLearner<TaxiState, PrimitiveTaxiRepresentation, TaxiAction> learner = new QLearner<TaxiState, PrimitiveTaxiRepresentation, TaxiAction>(
                new int[] { 0 }, rng, new PrimitiveTaxiRepresenter(state_prototype), new TaxiActionGenerator(),
                gamma, Vmax, epsilon, alpha);

        //      final int scale = 20;
        //      final TaxiVisualization vis = new TaxiVisualization( null, state_prototype.topology, state_prototype.locations, scale );
        //      final EpisodeListener<TaxiState, TaxiAction> updater = vis.updater( 0 );

        final AverageRewardAccumulator<TaxiState, TaxiAction> avg = new AverageRewardAccumulator<TaxiState, TaxiAction>(
                1);
        final double lag = -Double.MAX_VALUE;

        final Map<PrimitiveTaxiRepresentation, TObjectDoubleMap<TaxiAction>> old_values = new HashMap<PrimitiveTaxiRepresentation, TObjectDoubleMap<TaxiAction>>();

        int ns = 500;
        for (int i = 0; i < Nother_taxis; ++i) {
            ns *= 25 - i - 1;
        }
        final int Nstates = ns;

        int count = 0;
        while (true) {
            final TaxiState state = TaxiWorlds.dietterich2000(rng, Nother_taxis, slip);
            final TaxiSimulator sim = new TaxiSimulator(rng, state, slip, T);
            final Episode<TaxiState, TaxiAction> episode = new Episode<TaxiState, TaxiAction>(sim,
                    JointPolicy.create(learner), T);
            episode.addListener(avg);
            //         episode.addListener( updater );
            //      episode.addListener( new LoggingEpisodeListener<TaxiState, TaxiAction>() );
            episode.run();

            //         final double diff = Math.abs( avg.reward[0].mean() - lag );

            //         System.out.println( "Episode " + count + ": avg reward = " + avg.reward[0].mean() );

            count += 1;

            if ((count % 10000 == 0) && learner.values.size() == Nstates) {
                //            System.out.println( "learner.values.size() == " + Nstates );
                boolean complete = true;
                double norm = 0.0;
                for (final Map.Entry<PrimitiveTaxiRepresentation, TObjectDoubleMap<TaxiAction>> e : learner.values
                        .entrySet()) {
                    final TObjectDoubleMap<TaxiAction> new_q = e.getValue();
                    TObjectDoubleMap<TaxiAction> old_q = old_values.get(e.getKey());
                    if (old_q == null) {
                        old_q = new TObjectDoubleHashMap<TaxiAction>();
                        old_values.put(e.getKey(), old_q);
                        complete = false;
                    }
                    final TObjectDoubleIterator<TaxiAction> itr = new_q.iterator();
                    while (itr.hasNext()) {
                        itr.advance();
                        final TaxiAction a = itr.key();
                        final double new_qa = itr.value();
                        final double old_qa = old_q.get(a);
                        final double diff = new_qa - old_qa;
                        norm += diff * diff;
                        old_q.put(a, new_qa);
                    }
                }
                System.out.println("Qnorm = " + norm);
                if (complete && norm < 1e-6) {
                    break;
                }
            }
        }

    }
}
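
Example

After training, the learned Q-table is exposed through the public values map. The sketch below reuses the trove iteration pattern from maxA() to print the greedy action for each state representation seen during training; it is illustrative only and assumes a trained learner like the one built in main().

for (final Map.Entry<PrimitiveTaxiRepresentation, TObjectDoubleMap<TaxiAction>> e : learner.values.entrySet()) {
    TaxiAction best_a = null;
    double best_q = -Double.MAX_VALUE;
    final TObjectDoubleIterator<TaxiAction> itr = e.getValue().iterator();
    while (itr.hasNext()) {
        itr.advance();
        if (itr.value() > best_q) {
            best_q = itr.value();
            best_a = itr.key();
        }
    }
    System.out.println(e.getKey() + " -> " + best_a + " (Q = " + best_q + ")");
}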