br.unicamp.cst.learning.QLearning.java Source code

Introduction

Here is the source code for br.unicamp.cst.learning.QLearning.java, part of the CST (Cognitive Systems Toolkit) learning package.

Source

/*******************************************************************************
 * Copyright (c) 2012  DCA-FEEC-UNICAMP
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Lesser Public License v3
 * which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/lgpl.html
 * 
 * Contributors:
 *     K. Raizer, A. L. O. Paraense, R. R. Gudwin - initial API and implementation
 *     E. M. Froes - documentation
 ******************************************************************************/

package br.unicamp.cst.learning;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Random;
import org.json.JSONException;
import org.json.JSONObject;

/**
 * The update equation for TD Q-learning is:
 * Q(s,a) = Q(s,a) + alpha * (r(s) + gamma * max_{a'} Q(s',a') - Q(s,a))
 * 
 * which is applied whenever action a is executed in state s leading to state s'.
 * By acting randomly on some fraction of steps, where the fraction decreases
 * over time, we can dispense with keeping statistics about taken actions.
 *   
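 * As a worked example (illustrative numbers, not from the original source):
 * with alpha = 0.5, gamma = 0.9, Q(s,a) = 1.0, r = 0.5 and max_{a'} Q(s',a') = 2.0,
 * the update gives Q(s,a) = 1.0 + 0.5 * (0.5 + 0.9 * 2.0 - 1.0) = 1.65.
 * 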
 * [1] Ganapathy 2009 "Utilization of Webots and the Khepera II as a Platform for Neural Q-Learning Controllers"
 * [2] http://people.revoledu.com/kardi/tutorial/ReinforcementLearning/Q-Learning-Matlab.htm
 * [3] Russell & Norvig, "Artificial Intelligence: A Modern Approach"
 * @author klaus
 *
 */

public class QLearning {

    private boolean showDebugMessages = false;
    private ArrayList<String> statesList;
    private ArrayList<String> actionsList;
    private String fileName = "QTable.txt";
    private HashMap<String, HashMap<String, Double>> Q;

    private double e = 0.1; //probability of choosing the best action instead of a random one
    private double alpha = 0.5; //learning rate
    private double gamma = 0.9; //discount factor
    private double b = 0.95; //probability that a random action choice repeats the previous action instead of drawing a new one from the action list
    private String s = "", a = "", sl = "", al = "";
    private double reward = 0;
    private Random r = new Random();

    /**
     * Default Constructor.
     */
    public QLearning() {
        statesList = new ArrayList<String>();
        actionsList = new ArrayList<String>();
        Q = new HashMap<String, HashMap<String, Double>>(); // Q learning
    }

    /**
     * Sets the Q value for the given state/action pair.
     * @param Qval
     * @param state
     * @param action 
     */
    public void setQ(double Qval, String state, String action) {
        HashMap<String, Double> tempS = this.Q.get(state);
        if (tempS != null) {
            //This state already exists, So I have to check if it already contains this action
            if (tempS.get(action) != null) {
                //the action already exists, So I just update it to the new one
                tempS.put(action, Qval);
            } else {
                if (!actionsList.contains(action)) {//TODO something wicked here. I shouldn't need to perform this test...
                    actionsList.add(action);
                }
                tempS.put(action, Qval);
            }
        } else {
            //this state doesn't exist yet, so create it holding this single
            //action/value pair; actions with no recorded value are treated as 0 elsewhere
            if (!actionsList.contains(action)) {
                actionsList.add(action);
            }
            HashMap<String, Double> tempNew = new HashMap<String, Double>();
            tempNew.put(action, Qval);
            statesList.add(state);
            this.Q.put(state, tempNew);
        }
    }

    /**
    * Returns the utility value Q related to the given state/action pair.
    * @param state
    * @param action
    * @return the stored Q value, or 0 if the state/action pair is unknown
    */
    public double getQ(String state, String action) {
        double dQ = 0;
        if (!(Q.get(state) == null || Q.get(state).get(action) == null)) {
            dQ = Q.get(state).get(action);
        }
        return dQ;
    }

    /**
    * Returns the maximum Q value over the actions recorded for state sl.
    * Unexplored actions count as 0 while any remain, and unknown states return 0.
    * @param sl
    * @return maximum Q value
    */
    public double maxQsl(String sl) {
        double maxQinSl = 0; //unknown states default to 0
        if (this.Q.get(sl) != null) {
            HashMap<String, Double> tempSl = this.Q.get(sl);
            ArrayList<String> tempA = new ArrayList<String>();
            tempA.addAll(this.actionsList);

            // Finds the maximum value among the actions recorded for sl
            maxQinSl = Double.NEGATIVE_INFINITY;
            for (Entry<String, Double> pairs : tempSl.entrySet()) {
                double val = pairs.getValue();
                tempA.remove(pairs.getKey());
                if (val > maxQinSl) {
                    maxQinSl = val;
                }
            }
            //unexplored actions count as 0, so while any remain
            //the maximum can never fall below 0
            if (!tempA.isEmpty() && maxQinSl < 0) {
                maxQinSl = 0;
            }
        }
        return maxQinSl;
    }

    /**
    * Performs the Q-learning update. The state/action pair passed to the
    * previous call is taken as (s,a); the state passed now is s'; and
    * rewardIGot is the reward received on that transition. The pair passed
    * now is then stored as the (s,a) for the next call.
    * @param stateIWas the state just observed
    * @param actionIDid the action just chosen there
    * @param rewardIGot reward received after moving from the previous state to the present one
    */
    public void update(String stateIWas, String actionIDid, double rewardIGot) {
        //which is calculated whenever action a is executed in state s leading to state s'
        this.sl = stateIWas;
        this.al = actionIDid;

        if (!a.equals("") && !s.equals("")) {
            //TODO should this update only when the state changes, i.e. !s.equals(sl)?
            double Qas = this.getQ(s, a);
            double MaxQ = this.maxQsl(this.sl);
            double newQ = Qas + alpha * (rewardIGot + gamma * MaxQ - Qas); //TODO not sure if it's reward or rewardIGot
            this.setQ(newQ, s, a);
        }

        a = this.al;
        s = this.sl;
        reward = rewardIGot;
    }

    /**
    * Prints the Q table to standard output.
    */
    public void printQ() {
        System.out.println("------ Printed Q -------");
        Iterator<Entry<String, HashMap<String, Double>>> itS = this.Q.entrySet().iterator();
        while (itS.hasNext()) {
            Entry<String, HashMap<String, Double>> pairs = itS.next();
            HashMap<String, Double> tempA = pairs.getValue();
            Iterator<Entry<String, Double>> itA = tempA.entrySet().iterator();
            double val = 0;
            System.out.print("State(" + pairs.getKey() + ") actions: ");
            while (itA.hasNext()) {
                Entry<String, Double> pairsA = itA.next();
                val = pairsA.getValue();
                System.out.print("[" + pairsA.getKey() + ": " + val + "] ");
            }
            System.out.println("");
        }

        System.out.println("----------------------------");
    }

    /**
     *  Stores the Q table to a file as JSON.
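     *  The file holds one top-level JSON object keyed by state, each value
     *  being an action-to-Q-value map, e.g. (illustrative):
     *  {"start":{"left":0.0,"right":0.5}}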
     */
    public void storeQ() {
        JSONObject actionValuePair = new JSONObject();
        JSONObject actionsStatePair = new JSONObject();
        try {
            Iterator<Entry<String, HashMap<String, Double>>> itS = this.Q.entrySet().iterator();
            while (itS.hasNext()) {
                Entry<String, HashMap<String, Double>> pairs = itS.next();
                HashMap<String, Double> tempA = pairs.getValue();
                Iterator<Entry<String, Double>> itA = tempA.entrySet().iterator();
                double val = 0;
                actionValuePair = new JSONObject();
                while (itA.hasNext()) {
                    Entry<String, Double> pairsA = itA.next();
                    val = pairsA.getValue();
                    actionValuePair.put(pairsA.getKey(), val);
                }
                actionsStatePair.put(pairs.getKey(), actionValuePair);
            }

        } catch (JSONException e) {
            e.printStackTrace();
        }

        //use buffering
        Writer output;
        try {
            output = new BufferedWriter(new FileWriter(fileName));

            try {
                //FileWriter always assumes default encoding is OK!
                output.write(actionsStatePair.toString());
            } finally {
                output.close();
            }

        } catch (IOException e) {
            e.printStackTrace();
        }

    }

    /**
     *  Recovers the Q table from the JSON file written by storeQ().
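     *  Each state/action/value triple found is replayed through setQ(), so
     *  recovered entries are merged into the current table. Illustrative
     *  round trip: ql.storeQ(); then, on a fresh instance, ql2.recoverQ().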
     */
    public void recoverQ() {
        StringBuilder contents = new StringBuilder();

        try {
            //use buffering, reading one line at a time
            //FileReader always assumes default encoding is OK!
            BufferedReader input = new BufferedReader(new FileReader(fileName));
            try {
                String line = null; //not declared within while loop
                /*
                 * readLine is a bit quirky :
                 * it returns the content of a line MINUS the newline.
                 * it returns null only for the END of the stream.
                 * it returns an empty String if two newlines appear in a row.
                 */
                while ((line = input.readLine()) != null) {
                    contents.append(line);
                    //contents.append(System.getProperty("line.separator"));
                }
            } finally {
                input.close();
            }
        } catch (IOException ex) {
            ex.printStackTrace();
        }

        JSONObject actionsStatePairs;
        try {
            actionsStatePairs = new JSONObject(contents.toString());

            Iterator itS = actionsStatePairs.keys();
            while (itS.hasNext()) {
                String state = itS.next().toString();
                JSONObject pairAS = (JSONObject) actionsStatePairs.get(state);

                Iterator itA = pairAS.keys();
                while (itA.hasNext()) {
                    String action = itA.next().toString();
                    double value = pairAS.getDouble(action);
                    this.setQ(value, state, action);
                }
            }

        } catch (JSONException e1) {
            e1.printStackTrace();
        }

    }

    /**
     * Clear Q values.
     */
    public void clearQ() {
        this.Q.clear();
    }

    /**
     * Gets alpha value.
     * @return 
     */
    public double getAlpha() {
        return alpha;
    }

    /**
    *  Sets the learning rate parameter alpha.
    *  Should be between 0 and 1
    * @param alpha
    */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Gets gamma value.
     * @return 
     */
    public double getGamma() {
        return gamma;
    }

    /**
    * Sets the discount factor.
    * Should be between 0 and 1.
    * @param gamma
    */
    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    /**
    * Selects the best action for this state with probability "e",
    * and a random one with probability (1-e).
    *  If a given state has no record of one or more actions, it will consider them as valued 0.
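    *  For instance (illustrative), with e = 0.9 a call such as getAction("s1")
    *  returns the highest-valued recorded action for "s1" about 90% of the time.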
    * @param state
     * @return selectedAction
    */
    public String getAction(String state) {//TODO should improve this. It should consider all non explored actions as being equally 0 for all purposes
        String selectedAction = null;
        if (r.nextDouble() <= e) { //TODO Use boltzmann distribution here?

            if (this.Q.get(state) != null) {
                ArrayList<String> actionsLeft = new ArrayList<String>();
                actionsLeft.addAll(this.actionsList);
                HashMap<String, Double> actionsQ = this.Q.get(state);
                double bestQval = Double.NEGATIVE_INFINITY;
                Iterator<Entry<String, Double>> it = actionsQ.entrySet().iterator();

                while (it.hasNext()) {
                    Entry<String, Double> pairs = it.next();
                    double qVal = pairs.getValue();
                    String qAct = pairs.getKey();
                    if (qVal > bestQval) {
                        bestQval = qVal;
                        selectedAction = qAct;
                    }
                    actionsLeft.remove(qAct);
                }
                if ((bestQval < 0) && (actionsLeft.size() > 0)) {
                    //this means we should randomly choose from the other actions;
                    selectedAction = selectRandomAction(actionsLeft);
                }
                if (showDebugMessages) {
                    System.out.println("Selected the best available action.");
                }
            } else {
                selectedAction = selectRandomAction(actionsList);
                if (showDebugMessages) {
                    System.out.println("Selected a random action because there was no available suggestion.");
                }
            }
        } else {
            if (showDebugMessages) {
                System.out.println("Naturally selected a random action.");
            }
            selectedAction = selectRandomAction(actionsList);
        }
        return selectedAction;
    }

    /**
     * Gets states list.
     * @return the statesList
     */
    public ArrayList<String> getStatesList() {
        return statesList;
    }

    /**
     * Sets the states list.
     * @param statesList the statesList to set
     */
    public void setStatesList(ArrayList<String> statesList) {
        this.statesList = statesList;
    }

    /**
     * Gets the action list.
     * @return the actionsList
     */
    public ArrayList<String> getActionsList() {
        return actionsList;
    }

    /**
     * Sets the actions list; the given list is copied.
     * @param actionsList the actionsList to set
     */
    public void setActionsList(ArrayList<String> actionsList) {
        this.actionsList = new ArrayList<String>();
        this.actionsList.addAll(actionsList);
    }

    /**
     * Gets E value.
     * @return e
     */
    public double getE() {
        return e;
    }

    /**
     * Sets the chances of getting the best possible action.
     * With e=0.9 for instance, there is a .9 chance of getting the best action for the given state, and .1 probability of getting a random action.
     * 
     * @param e the e to set
     */
    public void setE(double e) {
        this.e = e;
    }

    /**
     * Returns a string listing every recorded action/Q-value pair for the given state.
     * @param state
     * @return actions
     */
    public String getAllActionsFromState(String state) {
        String actions = "";
        if (this.Q.get(state) != null) {
            HashMap<String, Double> actionsH = this.Q.get(state);

            Iterator<Entry<String, Double>> it = actionsH.entrySet().iterator();

            while (it.hasNext()) {
                Entry<String, Double> pairs = it.next();
                double qVal = (Double) pairs.getValue();
                String act = (String) pairs.getKey();
                actions = actions + "{" + act + ":" + qVal + "} ";

            }

        } else {
            actions = "{}";
        }

        return actions;

    }

    /**
     * Randomly selects an action: with probability b it repeats the previous
     * action; otherwise it draws one uniformly from the given list.
     * @param localActionsList
     * @return actionR 
     */
    private String selectRandomAction(ArrayList<String> localActionsList) {
        String actionR = this.a;
        double pseudoRandomNumber = r.nextDouble();

        if ((pseudoRandomNumber >= b) || actionR == null || actionR.equals("")) {
            int actionI = r.nextInt(localActionsList.size());
            actionR = localActionsList.get(actionI);
        }

        return actionR;//TODO should I use boltzman distribution?

        /*
         * Simulating a Pareto random variable. The Pareto distribution is often used to model
         * insurance claim damages, financial option holding times, and Internet traffic activity.
         * The probability that a Pareto random variable with parameter a is less than x is
         * F(x) = 1 - (1 + x)^(-a) for x >= 0. To generate a random deviate from the
         * distribution, use the inverse transform method: output (1 - U)^(-1/a) - 1,
         * where U is a uniform random number between 0 and 1.
         */

    }
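
    /**
     * Illustrative helper (not part of the original class): samples a Pareto
     * deviate with parameter a using the inverse transform method described
     * in the comment above, i.e. it returns (1 - U)^(-1/a) - 1 for uniform U.
     */
    private double paretoDeviate(double paretoA) {
        double u = r.nextDouble(); //uniform random number in [0,1)
        return Math.pow(1.0 - u, -1.0 / paretoA) - 1.0;
    }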

    /**
     * Returns a pseudo-random value from a Maxwell-Boltzmann distribution:
     * the square root of the sum of the squares of three Gaussian random
     * variables with mean 0. As implemented, each Gaussian has standard
     * deviation 1 (Random.nextGaussian()), i.e. sigma = 1.
     * 
     * @return the sampled value
     */
    public double maxwellBoltzmann() {
        double sum = 0;
        sum = sum + Math.pow(r.nextGaussian(), 2);
        sum = sum + Math.pow(r.nextGaussian(), 2);
        sum = sum + Math.pow(r.nextGaussian(), 2);

        sum = Math.sqrt(sum);

        return sum;

    }

    /**
     * Gets B value.
     * @return the b
     */
    public double getB() {
        return b;
    }

    /**
     * Sets B value.
     * @param b the b to set
     */
    public void setB(double b) {
        this.b = b;
    }

    /**
     * Gets S value.
     * @return the s
     */
    public String getS() {
        return s;
    }

    /**
     * Sets S value.
     * @param s the s to set
     */
    public void setS(String s) {
        this.s = s;
    }

    /**
     * Gets A value.
     * @return the a
     */
    public String getA() {
        return a;
    }

    /**
     * Sets A value.
     * @param a the a to set
     */
    public void setA(String a) {
        this.a = a;
    }

    /**
     * Gets SL value.
     * @return the sl
     */
    public String getSl() {
        return sl;
    }

    /**
     * Sets SL value.
     * @param sl the sl to set
     */
    public void setSl(String sl) {
        this.sl = sl;
    }

    /**
     * Gets the AL value.
     * @return the al
     */
    public String getAl() {
        return al;
    }

    /**
     * Sets AL value.
     * @param al the al to set
     */
    public void setAl(String al) {
        this.al = al;
    }

    /**
     * Gets the whole Q table.
     * @return the state -> (action -> Q value) map
     */
    public HashMap<String, HashMap<String, Double>> getAllQ() {
        return this.Q;
    }
}
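
Example Usage

The snippet below is a minimal usage sketch, not part of the original file. It assumes a hypothetical two-room environment ("start" and "middle", with a goal reached by going "right" twice) and shows the intended loop: observe the state, pick an action with getAction, report the observation through update, then act. Note that update() learns about the pair passed on the previous call, so it is given the state just observed and the reward just received.

import br.unicamp.cst.learning.QLearning;

import java.util.ArrayList;
import java.util.Arrays;

public class QLearningExample {

    public static void main(String[] args) {
        QLearning ql = new QLearning();
        ql.setActionsList(new ArrayList<String>(Arrays.asList("left", "right")));
        ql.setE(0.9); //pick the best known action 90% of the time

        String state = "start";
        double reward = 0.0;
        for (int step = 0; step < 200; step++) {
            String action = ql.getAction(state);
            //updates the PREVIOUS state/action pair using the reward and
            //state observed now, then stores (state, action) for next time
            ql.update(state, action, reward);

            //hypothetical environment dynamics
            if (state.equals("start")) {
                state = action.equals("right") ? "middle" : "start";
                reward = 0.0;
            } else { //"middle"
                reward = action.equals("right") ? 1.0 : 0.0; //goal pays 1
                state = "start"; //episode restarts
            }
        }

        ql.printQ(); //inspect the learned table
        ql.storeQ(); //persist it to QTable.txt as JSON
    }
}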