edu.oregonstate.eecs.mcplan.domains.toy.CliffWorld.java Source code

Java tutorial

Introduction

Here is the source code for edu.oregonstate.eecs.mcplan.domains.toy.CliffWorld.java

Source

/* LICENSE
Copyright (c) 2013-2016, Jesse Hostetler (jessehostetler@gmail.com)
All rights reserved.
    
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
    
1. Redistributions of source code must retain the above copyright notice,
   this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.
    
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/**
 * 
 */
package edu.oregonstate.eecs.mcplan.domains.toy;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;

import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

import weka.core.Attribute;
import edu.oregonstate.eecs.mcplan.ActionGenerator;
import edu.oregonstate.eecs.mcplan.FactoredRepresentation;
import edu.oregonstate.eecs.mcplan.FactoredRepresenter;
import edu.oregonstate.eecs.mcplan.JointAction;
import edu.oregonstate.eecs.mcplan.Representation;
import edu.oregonstate.eecs.mcplan.Representer;
import edu.oregonstate.eecs.mcplan.UndoableAction;
import edu.oregonstate.eecs.mcplan.VirtualConstructor;
import edu.oregonstate.eecs.mcplan.abstraction.IndexRepresentation;
import edu.oregonstate.eecs.mcplan.sim.UndoSimulator;
import edu.oregonstate.eecs.mcplan.util.Fn;
import edu.oregonstate.eecs.mcplan.util.KeyValueStore;

/**
 * @author jhostetler
 *
 */
public class CliffWorld {
    /** Penalty for falling off the cliff (ending up in {@link Path#Dead}). */
    public static final double fall_penalty = -30.0;

    /**
     * The route the agent is on.
     * <ul>
     * <li>{@code Start}: no route chosen yet.</li>
     * <li>{@code Road}: the long route; the episode ends at location {@code 3 * L}.</li>
     * <li>{@code Cliff}: the short route; the episode ends at location {@code L}.</li>
     * <li>{@code Dead}: the agent fell; always terminal.</li>
     * </ul>
     */
    public enum Path {
        Start,
        Road,
        Cliff,
        Dead
    }

    /**
     * Mutable CliffWorld state.
     *
     * <p>{@code L}, {@code W} and {@code F} are fixed problem parameters. {@code L}
     * sets the route lengths (the cliff ends at {@code L}, the road at {@code 3 * L}).
     * {@code W} scales the wind-driven slip probability. {@code F} is the slip level
     * at which the agent falls. The remaining public fields are the dynamic state.
     */
    public static class State implements edu.oregonstate.eecs.mcplan.State {
        public final RandomGenerator rng;
        public final int L;
        public final int W;
        public final int F;

        public Path path = Path.Start;
        public int location = 0;
        public int wind = 0;
        public int slip = 0;

        /** Creates a state at {@link Path#Start}, location 0, with no wind and no slip. */
        public State(final RandomGenerator rng, final int L, final int W, final int F) {
            this.rng = rng;
            this.L = L;
            this.W = W;
            this.F = F;
        }

        /**
         * Copy constructor. The dynamic fields are copied. The random generator is
         * shared with {@code that}, not duplicated.
         */
        public State(final State that) {
            this(that.rng, that.L, that.W, that.F);
            path = that.path;
            location = that.location;
            wind = that.wind;
            slip = that.slip;
        }

        /** Terminal once dead, or once the end of the chosen route is reached. */
        @Override
        public boolean isTerminal() {
            switch (path) {
            case Dead:
                return true;
            case Road:
                return location == 3 * L;
            case Cliff:
                return location == L;
            case Start:
                return false;
            default:
                throw new AssertionError("unreachable");
            }
        }

        @Override
        public String toString() {
            return new StringBuilder()
                    .append("path: ").append(path)
                    .append(", location: ").append(location)
                    .append(", wind: ").append(wind)
                    .append(", slip: ").append(slip)
                    .toString();
        }

        /** Nothing to release. */
        @Override
        public void close() {
            // intentionally empty
        }
    }

    // -----------------------------------------------------------------------

    /**
     * Base class for CliffWorld actions: an undoable action on {@link State} that
     * also reports the immediate reward for taking it.
     */
    public abstract static class Action extends UndoableAction<State> implements VirtualConstructor<Action> {
        /** @return the immediate reward associated with this action */
        public abstract double reward();
    }

    /** Leaves {@link Path#Start} by committing to the long, safe {@link Path#Road}. */
    public static class TakeRoadAction extends Action {
        /** Immediate reward for committing to the road. */
        public static final double reward = 0;

        /** True while {@link #doAction} has been applied and not yet undone. */
        private boolean done = false;

        /** Reverts the route choice, returning the state to {@link Path#Start}. */
        @Override
        public void undoAction(final State s) {
            assert (done);
            assert (s.path == Path.Road);
            s.path = Path.Start;
            done = false;
        }

        /** Switches the route from {@link Path#Start} to {@link Path#Road}; {@code rng} is unused. */
        @Override
        public void doAction(final RandomGenerator rng, final State s) {
            assert (!done);
            assert (s.path == Path.Start);
            s.path = Path.Road;
            done = true;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        @Override
        public TakeRoadAction create() {
            return new TakeRoadAction();
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof TakeRoadAction;
        }

        @Override
        public int hashCode() {
            return TakeRoadAction.class.hashCode();
        }

        @Override
        public String toString() {
            return "TakeRoadAction";
        }

        @Override
        public double reward() {
            // Report the declared constant, as the other action classes do, so the
            // two cannot drift apart.
            return TakeRoadAction.reward;
        }
    }

    /** Leaves {@link Path#Start} by committing to the short, risky {@link Path#Cliff}. */
    public static class TakeCliffAction extends Action {
        /** Immediate reward for committing to the cliff. */
        public static final double reward = 0;

        /** True while {@link #doAction} has been applied and not yet undone. */
        private boolean done = false;

        /** Reverts the route choice, returning the state to {@link Path#Start}. */
        @Override
        public void undoAction(final State s) {
            assert (done);
            assert (s.path == Path.Cliff);
            s.path = Path.Start;
            done = false;
        }

        /** Switches the route from {@link Path#Start} to {@link Path#Cliff}; {@code rng} is unused. */
        @Override
        public void doAction(final RandomGenerator rng, final State s) {
            assert (!done);
            assert (s.path == Path.Start);
            s.path = Path.Cliff;
            done = true;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        @Override
        public TakeCliffAction create() {
            return new TakeCliffAction();
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof TakeCliffAction;
        }

        @Override
        public int hashCode() {
            return TakeCliffAction.class.hashCode();
        }

        @Override
        public String toString() {
            return "TakeCliffAction";
        }

        @Override
        public double reward() {
            // Report the declared constant, as the other action classes do, so the
            // two cannot drift apart.
            return TakeCliffAction.reward;
        }
    }

    /**
     * Moves one step forward without any risk of slipping. If the agent is
     * already slipping, it makes no progress and the slip grows by 2.
     */
    public static class CautiousAction extends Action {
        public static final double reward = -2;

        private boolean done = false;
        // Values overwritten by doAction, restored by undoAction.
        private int old_location = -1;
        private int old_slip = 0;

        @Override
        public void undoAction(final State s) {
            assert (done);
            assert (s.location > 0);
            s.slip = old_slip;
            s.location = old_location;
            done = false;
        }

        @Override
        public void doAction(final RandomGenerator rng, final State s) {
            assert (!done);
            old_location = s.location;
            old_slip = s.slip;

            final boolean slipping = s.slip != 0;
            if (slipping) {
                s.slip += 2;
            } else {
                s.location += 1;
            }

            done = true;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        @Override
        public CautiousAction create() {
            return new CautiousAction();
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof CautiousAction;
        }

        @Override
        public int hashCode() {
            return CautiousAction.class.hashCode();
        }

        @Override
        public String toString() {
            return "CautiousAction";
        }

        /** @return {@link #reward} */
        @Override
        public double reward() {
            return CautiousAction.reward;
        }
    }

    /**
     * Moves one step forward cheaply. On the cliff, the step starts a slip with
     * probability {@code wind / W}. If already slipping, the agent makes no
     * progress and the slip grows by 2.
     */
    public static class FastAction extends Action {
        public static final double reward = -1;

        private boolean done = false;
        // Values overwritten by doAction, restored by undoAction.
        private int old_location = -1;
        private int old_slip = 0;

        @Override
        public void undoAction(final State s) {
            assert (done);
            assert (s.location > 0 || s.path == Path.Dead);
            s.slip = old_slip;
            s.location = old_location;
            done = false;
        }

        @Override
        public void doAction(final RandomGenerator rng, final State s) {
            assert (!done);
            old_location = s.location;
            old_slip = s.slip;

            if (s.slip != 0) {
                // Already slipping: no progress, and the slip worsens.
                s.slip += 2;
            } else {
                s.location += 1;
                // Only the cliff is exposed to the wind; rng is consumed only there.
                if (s.path == Path.Cliff && rng.nextDouble() < s.wind / ((double) s.W)) {
                    s.slip += 1;
                }
            }

            done = true;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        @Override
        public FastAction create() {
            return new FastAction();
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof FastAction;
        }

        @Override
        public int hashCode() {
            return FastAction.class.hashCode();
        }

        @Override
        public String toString() {
            return "FastAction";
        }

        /** @return {@link #reward} */
        @Override
        public double reward() {
            return FastAction.reward;
        }
    }

    /**
     * Tries to recover from a slip without moving. A slip of exactly 1 is fully
     * recovered. A larger slip still grows, by 1. No slip stays at 0.
     */
    public static class SteadyAction extends Action {
        public static final double reward = -2.5;

        private boolean done = false;
        // Slip overwritten by doAction, restored by undoAction.
        private int old_slip = 0;

        @Override
        public void undoAction(final State s) {
            assert (done);
            s.slip = old_slip;
            done = false;
        }

        @Override
        public void doAction(final RandomGenerator rng, final State s) {
            assert (!done);
            old_slip = s.slip;
            // Net effect of "recover one unit, then worsen by 2 if still slipping".
            if (s.slip == 1) {
                s.slip = 0;
            } else if (s.slip > 1) {
                s.slip += 1;
            }
            done = true;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        @Override
        public SteadyAction create() {
            return new SteadyAction();
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof SteadyAction;
        }

        @Override
        public int hashCode() {
            return SteadyAction.class.hashCode();
        }

        @Override
        public String toString() {
            return "SteadyAction";
        }

        /** @return {@link #reward} */
        @Override
        public double reward() {
            return SteadyAction.reward;
        }
    }

    /**
     * Enumerates the legal actions of a state. The set depends only on the path:
     * Start offers TakeRoad and TakeCliff; Road offers Fast and Cautious; Cliff
     * offers Fast, Cautious and Steady; Dead offers nothing.
     */
    public static class Actions extends ActionGenerator<State, Action> {
        /** Identifies which action set applies to {@code s} (its path's ordinal). */
        public static int actionSetIndex(final State s) {
            return s.path.ordinal();
        }

        private Path path = Path.Dead;
        // Index of the next action to hand out.
        private int n = 0;

        private final RandomGenerator rng;
        private final int Nactions = 3;

        public Actions(final RandomGenerator rng) {
            this.rng = rng;
        }

        @Override
        public Actions create() {
            return new Actions(rng);
        }

        /** Restarts enumeration for {@code s}; {@code t} is not used. */
        @Override
        public void setState(final State s, final long t) {
            path = s.path;
            n = 0;
        }

        @Override
        public int size() {
            switch (path) {
            case Dead:
                return 0;
            case Start:
            case Road:
                return 2;
            default:
                return Nactions;
            }
        }

        @Override
        public boolean hasNext() {
            return n < size();
        }

        @Override
        public Action next() {
            assert (hasNext());

            final Action a;
            if (path == Path.Start) {
                a = (n == 0) ? new TakeRoadAction() : new TakeCliffAction();
            } else {
                switch (n) {
                case 0:
                    a = new FastAction();
                    break;
                case 1:
                    a = new CautiousAction();
                    break;
                default:
                    // Only the cliff has a third action.
                    assert (path == Path.Cliff);
                    a = new SteadyAction();
                    break;
                }
            }

            ++n;
            return a;
        }
    }

    /**
     * Environment step applied after each agent action. It sets the wind to a
     * caller-supplied value, and sends the agent to {@link Path#Dead} once its
     * slip has reached {@code F}. The previous wind and path are saved for undo.
     */
    private static class PostDynamicsAction extends Action {
        private boolean done = false;
        private int old_wind = -1;
        private Path old_path = Path.Start;

        private final int new_wind;

        public PostDynamicsAction(final int new_wind) {
            this.new_wind = new_wind;
        }

        @Override
        public void undoAction(final State s) {
            assert (done);
            s.path = old_path;
            s.wind = old_wind;
            done = false;
        }

        @Override
        public void doAction(final RandomGenerator rng, final State s) {
            assert (!done);
            old_wind = s.wind;
            old_path = s.path;

            s.wind = new_wind;
            final boolean fell = s.slip >= s.F;
            if (fell) {
                s.path = Path.Dead;
            }

            done = true;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        @Override
        public Action create() {
            return new PostDynamicsAction(new_wind);
        }

        /** The environment step itself carries no reward. */
        @Override
        public double reward() {
            return 0;
        }
    }

    // -----------------------------------------------------------------------

    /**
     * Single-agent undoable simulator. Each step applies the agent's action and
     * then a {@link PostDynamicsAction} with a wind level drawn from
     * {@code rng.nextInt(W)}. Both are recorded so the step can be reverted.
     */
    public static class Simulator implements UndoSimulator<State, Action> {
        private final State s;

        private final RandomGenerator rng;
        private final Deque<Action> action_history = new ArrayDeque<Action>();
        private final Deque<Action> postdynamics_history = new ArrayDeque<Action>();

        public Simulator(final State s, final RandomGenerator rng) {
            this.s = s;
            this.rng = rng;
        }

        @Override
        public State state() {
            return s;
        }

        /** Applies agent 0's action, then the environment step. */
        @Override
        public void takeAction(final JointAction<Action> a) {
            final Action ai = a.get(0);
            ai.doAction(s);
            action_history.push(ai);

            final PostDynamicsAction post = new PostDynamicsAction(rng.nextInt(s.W));
            post.doAction(s);
            postdynamics_history.push(post);
        }

        /** Counts agent and environment sub-steps together (two per takeAction). */
        @Override
        public long depth() {
            return action_history.size() + postdynamics_history.size();
        }

        @Override
        public long t() {
            return action_history.size();
        }

        @Override
        public int nagents() {
            return 1;
        }

        @Override
        public int[] turn() {
            return new int[] { 0 };
        }

        /**
         * Reward of the most recent step: the last action's own {@link Action#reward()},
         * plus {@link #fall_penalty} if the step left the agent {@link Path#Dead}.
         * This uses the same two components FsssModel exposes (reward(s, a) and
         * reward(s)). The original hard-coded values instead charged TakeRoad and
         * TakeCliff -1, although their declared reward is 0. Before any action has
         * been taken, the action component is 0.
         */
        @Override
        public double[] reward() {
            final Action last = action_history.peek();
            double r = (last == null) ? 0 : last.reward();
            if (s.path == Path.Dead) {
                r += fall_penalty;
            }
            return new double[] { r };
        }

        @Override
        public boolean isTerminalState() {
            return s.isTerminal();
        }

        @Override
        public long horizon() {
            return Long.MAX_VALUE;
        }

        @Override
        public String detailString() {
            return "CliffWorldSimulator";
        }

        /** Undoes the last step in reverse order: environment step first, then the action. */
        @Override
        public void untakeLastAction() {
            final Action post = postdynamics_history.pop();
            post.undoAction(s);

            final Action a = action_history.pop();
            a.undoAction(s);
        }
    }

    // -----------------------------------------------------------------------

    /**
     * Raw feature vector of a {@link State}:
     * {@code [path ordinal, location, wind, slip]}.
     */
    public static class PrimitiveRepresentation extends FactoredRepresentation<State> {
        private final float[] phi;

        public PrimitiveRepresentation(final State s) {
            phi = new float[4];
            int idx = 0;
            phi[idx++] = s.path.ordinal();
            phi[idx++] = s.location;
            phi[idx++] = s.wind;
            phi[idx++] = s.slip;
        }

        /** Wraps {@code phi} directly; callers must pass an array they no longer share. */
        private PrimitiveRepresentation(final float[] phi) {
            this.phi = phi;
        }

        /** Returns the internal feature array (not a copy). */
        @Override
        public float[] phi() {
            return phi;
        }

        /**
         * Returns an independent copy. The feature array is cloned, so mutating
         * one representation's {@link #phi()} cannot alter the other.
         */
        @Override
        public FactoredRepresentation<State> copy() {
            return new PrimitiveRepresentation(phi.clone());
        }

        @Override
        public boolean equals(final Object obj) {
            if (!(obj instanceof PrimitiveRepresentation)) {
                return false;
            }
            final PrimitiveRepresentation that = (PrimitiveRepresentation) obj;
            return Arrays.equals(phi, that.phi);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(phi);
        }

        @Override
        public String toString() {
            return Arrays.toString(phi);
        }
    }

    /** Encodes states as {@link PrimitiveRepresentation}s. */
    public static class PrimitiveRepresenter implements FactoredRepresenter<State, FactoredRepresentation<State>> {
        /**
         * Weka attributes, in the same order as the features of
         * {@link PrimitiveRepresentation}. Declared final so the shared list
         * cannot be reassigned after class initialization.
         */
        private static final ArrayList<Attribute> attributes = makeAttributes();

        /** Builds the attribute list: path, location, wind, slip. */
        private static ArrayList<Attribute> makeAttributes() {
            final ArrayList<Attribute> list = new ArrayList<Attribute>();
            list.add(new Attribute("path"));
            list.add(new Attribute("location"));
            list.add(new Attribute("wind"));
            list.add(new Attribute("slip"));
            return list;
        }

        @Override
        public FactoredRepresenter<State, FactoredRepresentation<State>> create() {
            return new PrimitiveRepresenter();
        }

        @Override
        public PrimitiveRepresentation encode(final State s) {
            return new PrimitiveRepresentation(s);
        }

        /** Returns the shared attribute list (the same instance for all representers). */
        @Override
        public ArrayList<Attribute> attributes() {
            return attributes;
        }
    }

    /** Represents a state by the index of its action set ({@link Actions#actionSetIndex}). */
    public static class ActionSetRepresenter implements Representer<State, Representation<State>> {
        @Override
        public Representer<State, Representation<State>> create() {
            return new ActionSetRepresenter();
        }

        @Override
        public Representation<State> encode(final State s) {
            final int index = Actions.actionSetIndex(s);
            return new IndexRepresentation<State>(index);
        }
    }

    // -----------------------------------------------------------------------

    /**
     * CliffWorld model for the {@code search.fsss} planner. Problem parameters are
     * read from the configuration keys {@code cliffworld.L}, {@code cliffworld.W}
     * and {@code cliffworld.F}.
     */
    public static class FsssModel extends edu.oregonstate.eecs.mcplan.search.fsss.FsssModel<State, Action> {
        private final KeyValueStore config;

        // Value bounds computed for the initial state.
        // NOTE(review): these fields are not read anywhere in this class.
        private final double Vmin;
        private final double Vmax;

        private final PrimitiveRepresenter base_repr = new PrimitiveRepresenter();
        private final ActionSetRepresenter action_repr = new ActionSetRepresenter();

        /** Template initial state; initialState() hands out copies of it. */
        private final State s0;

        /** Number of sampleTransition() calls since construction or the last reset. */
        private int sample_count = 0;

        public FsssModel(final RandomGenerator rng, final KeyValueStore config) {
            this.config = config;
            final int L = config.getInt("cliffworld.L");
            final int W = config.getInt("cliffworld.W");
            final int F = config.getInt("cliffworld.F");
            s0 = new State(rng, L, W, F);
            Vmin = calculateVmin(s0);
            Vmax = calculateVmax(s0);
        }

        @Override
        public edu.oregonstate.eecs.mcplan.search.fsss.FsssModel<State, Action> create(final RandomGenerator rng) {
            return new FsssModel(rng, this.config);
        }

        private double calculateVmin(final State s) {
            // Worst possible plan is either:
            // 1. Do Cautious on the Safe road
            // 2. Do Cautious until the last step on the Risky road, then do
            // Fast, slip, and fail to recover, leading to a fall.
            return Math.min(3 * s.L * CautiousAction.reward, (s.L - 1) * CautiousAction.reward + FastAction.reward
                    + s.F * SteadyAction.reward + fall_penalty);
        }

        private double calculateVmax(final State s) {
            // No reward in this domain is positive, so 0 is a trivial upper bound.
            return 0;
        }

        /**
         * Location steps still needed to finish the route. From Start the route
         * is not yet chosen, so the longer (road) distance is used. Dead is
         * terminal, so nothing remains.
         */
        private static int stepsRemaining(final State s) {
            switch (s.path) {
            case Start:
                return 3 * s.L;
            case Road:
                return Math.max(0, 3 * s.L - s.location);
            case Cliff:
                return Math.max(0, s.L - s.location);
            case Dead:
                return 0;
            default:
                throw new AssertionError("unreachable");
            }
        }

        /**
         * Heuristic lower bound: the most expensive per-step cost
         * ({@link SteadyAction#reward}) times the steps still needed on the
         * current route. Previously the cliff length {@code L} was used on every
         * route. Road states past location {@code L} therefore got a positive
         * "lower bound" above {@link #Vmax(State)}.
         * NOTE(review): on the cliff, slipping can stall progress, and the fall
         * penalty is not included, so this is not guaranteed to be a true bound there.
         */
        @Override
        public double Vmin(final State s) {
            return SteadyAction.reward * stepsRemaining(s);
        }

        @Override
        public double Vmax(final State s) {
            return 0;
        }

        /** Action-level counterpart of {@link #Vmin(State)}. */
        @Override
        public double Vmin(final State s, final Action a) {
            // Choosing a route at Start does not move the agent. Any other action
            // is credited with at most one step of progress.
            final int remaining = stepsRemaining(s);
            final int after = (s.path == Path.Start) ? remaining : Math.max(0, remaining - 1);
            return reward(s, a) + SteadyAction.reward * after;
        }

        @Override
        public double Vmax(final State s, final Action a) {
            return reward(s, a);
        }

        @Override
        public double discount() {
            return 1.0;
        }

        @Override
        public FactoredRepresenter<State, ? extends FactoredRepresentation<State>> base_repr() {
            return base_repr;
        }

        @Override
        public Representer<State, ? extends Representation<State>> action_repr() {
            return action_repr;
        }

        @Override
        public State initialState() {
            return new State(s0);
        }

        @Override
        public Iterable<Action> actions(final State s) {
            final Actions actions = new Actions(s.rng);
            actions.setState(s, 0L);
            return Fn.in(actions);
        }

        /**
         * Samples a successor of {@code s} without modifying {@code s}. A fresh
         * copy of {@code a} and the environment step are applied to a copy of the
         * state. This advances the state's shared random generator.
         */
        @Override
        public State sampleTransition(final State s, final Action a) {
            sample_count += 1;

            final State copy = new State(s);
            a.create().doAction(copy);

            final PostDynamicsAction post = new PostDynamicsAction(s.rng.nextInt(s.W));
            post.doAction(copy);

            return copy;
        }

        /** State reward: {@link #fall_penalty} when dead, otherwise 0. */
        @Override
        public double reward(final State s) {
            return (s.path == Path.Dead) ? fall_penalty : 0;
        }

        @Override
        public double reward(final State s, final Action a) {
            return a.reward();
        }

        @Override
        public int sampleCount() {
            return sample_count;
        }

        /** No domain heuristic is provided; always 0. */
        @Override
        public double heuristic(final State s) {
            return 0;
        }

        /**
         * Returns the generator supplied at construction. Every state produced by
         * initialState() shares it, and sampleTransition draws from it. The
         * previous implementation returned null.
         */
        @Override
        public RandomGenerator rng() {
            return s0.rng;
        }

        /** Resets the counter reported by {@link #sampleCount()}. */
        @Override
        public void resetSampleCount() {
            sample_count = 0;
        }
    }

    // -----------------------------------------------------------------------

    /**
     * Interactive driver. Prints the state and the legal actions, reads an action
     * index from stdin, and repeats until the state is terminal or input ends.
     *
     * @throws NumberFormatException if an input line is not an integer
     * @throws IOException if reading stdin fails
     */
    public static void main(final String[] argv) throws NumberFormatException, IOException {
        final RandomGenerator rng = new MersenneTwister(42);
        final int L = 7;
        final int W = 4;
        final int F = 5;
        final State s = new State(rng, L, W, F);
        final Simulator sim = new Simulator(s, rng);

        final Actions actions = new Actions(rng);
        // One reader for the whole session. A BufferedReader reads ahead, so a
        // fresh reader per prompt could discard input already buffered by the
        // previous one (e.g. when stdin is piped from a file).
        final BufferedReader cin = new BufferedReader(new InputStreamReader(System.in));

        while (!sim.state().isTerminal()) {
            System.out.println(sim.state());
            actions.setState(sim.state(), sim.t());
            final ArrayList<Action> action_list = Fn.takeAll(actions);
            for (int i = 0; i < action_list.size(); ++i) {
                System.out.println(i + ": " + action_list.get(i));
            }
            System.out.print(">>> ");
            final String line = cin.readLine();
            if (line == null) {
                // End of input: stop instead of failing inside parseInt(null).
                return;
            }
            final int choice = Integer.parseInt(line.trim());
            final Action a = action_list.get(choice);
            sim.takeAction(new JointAction<Action>(a));
            System.out.println("Reward: " + Arrays.toString(sim.reward()));
        }
    }
}