cerrla.modular.ModularPolicy.java Source code

Java tutorial

Introduction

Here is the source code for cerrla.modular.ModularPolicy.java

Source

/*
 *    This file is part of the CERRLA algorithm
 *
 *    CERRLA is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    CERRLA is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with CERRLA. If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    src/cerrla/modular/ModularPolicy.java
 *    Copyright (C) 2012 Samuel Sarjant
 */
package cerrla.modular;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.collections.BidiMap;
import org.apache.commons.collections.bidimap.DualHashBidiMap;

import cerrla.LocalCrossEntropyDistribution;
import cerrla.ProgramArgument;

import jess.Rete;

import relationalFramework.FiredAction;
import relationalFramework.PolicyActions;
import relationalFramework.RelationalArgument;
import relationalFramework.RelationalPolicy;
import relationalFramework.RelationalRule;
import rrlFramework.RRLExperiment;
import rrlFramework.RRLObservations;
import util.ArgumentComparator;
import util.GoalConditionComparator;
import util.MultiMap;
import util.Recursive;

public class ModularPolicy extends RelationalPolicy {
    /** The minimum 'goal-not-achieved' value. */
    private static final double MINIMUM_REWARD = -100000;

    private static final long serialVersionUID = 7855536761222318011L;

    /** The reward received at every step by the sub-goal. */
    private static final double SUB_GOAL_REWARD = -1;

    /** The distribution that created this modular policy. */
    private LocalCrossEntropyDistribution ceDistribution_;

    /** The collection of policies this policy directly contains. */
    private MultiMap<RelationalRule, ModularSubGoal> childrenPolicies_;

    /** The goal replacements for this episode ('?G_0 -> a' format). */
    private transient BidiMap episodeGoalReplacements_;

    /** The reward received this episode. */
    private transient double[] episodeReward_;

    /** If this learning episode has started where the goal is unachieved. */
    private transient boolean episodeStarted_;

    /** The rules that fired during the last step. */
    private transient Set<RelationalRule> firedLastStep_;

    /** If the internal goal of this policy is currently achieved. */
    private transient boolean goalAchievedCurrently_;

    /** If the internal goal of this policy was achieved this episode. */
    private transient boolean goalAchievedEpisode_;

    /** A map for transforming goal replacements into the appropriate args. */
    private Map<RelationalArgument, RelationalArgument> moduleParamReplacements_;

    /** The rewards this policy achieved for each episode. */
    private ArrayList<double[]> policyRewards_;

    /** The rules that have fired. */
    private Set<RelationalRule> triggeredRules_;

    /** A unique ID for this modular policy. */
    private String uniqueID_;

    /** If this policy has been evaluated yet. */
    private boolean isEvaluated_;

    /**
     * Creates an empty modular policy: no triggered rules, no child sub-goal
     * policies and no recorded rewards. The policy starts out unevaluated.
     *
     * @param policyGenerator
     *            The distribution that generated this policy. It is also asked
     *            for the unique ID of this policy.
     */
    public ModularPolicy(LocalCrossEntropyDistribution policyGenerator) {
        // The superclass no-arg constructor is invoked implicitly.
        ceDistribution_ = policyGenerator;
        uniqueID_ = policyGenerator.generateUniquePolicyID();

        // Fresh, empty bookkeeping structures.
        isEvaluated_ = false;
        policyRewards_ = new ArrayList<double[]>();
        childrenPolicies_ = MultiMap.createListMultiMap();
        triggeredRules_ = new HashSet<RelationalRule>();
    }

    /**
     * A constructor for a new policy using the same rules from an old policy.
     * 
     * @param policy
     *            The old policy.
     */
    public ModularPolicy(ModularPolicy policy) {
        this(policy.ceDistribution_);
        for (PolicyItem reo : policy.policyRules_)
            policyRules_.add(reo);

        policySize_ = policy.policySize_;
        childrenPolicies_ = new MultiMap<RelationalRule, ModularSubGoal>(policy.childrenPolicies_);
        isEvaluated_ = false;
    }

    /**
     * Creates a new modular policy from an existing basic relational policy.
     * Every {@link RelationalRule} of the given policy is copied across in
     * order; items that are not rules are skipped. Directly after each rule, a
     * {@link ModularSubGoal} is inserted for every goal condition of that rule
     * that has not already been seen on an earlier rule (or earlier on the same
     * rule), so each distinct goal condition yields at most one sub-goal.
     *
     * @param newPol
     *            The basic policy with rules to transfer to this policy.
     * @param policyGenerator
     *            The generator that created this policy.
     */
    public ModularPolicy(RelationalPolicy newPol, LocalCrossEntropyDistribution policyGenerator) {
        this(policyGenerator);
        policySize_ = newPol.size();

        // Goal conditions already turned into sub-goals; uniqueness is
        // decided by the GoalConditionComparator.
        SortedSet<GoalCondition> subGoals = new TreeSet<GoalCondition>(new GoalConditionComparator());
        for (PolicyItem reo : newPol.getRules()) {
            if (!(reo instanceof RelationalRule))
                continue;

            RelationalRule rule = (RelationalRule) reo;
            policyRules_.add(reo);

            // Specific sub-goals.
            if (ProgramArgument.USE_MODULES.booleanValue())
                addSubGoals(rule.getSpecificSubGoals(), rule, subGoals);

            // General sub-goals: both groups of generalised conditions are
            // registered, first index 0 and then index 1.
            if (ProgramArgument.USE_GENERAL_MODULES.booleanValue()) {
                Collection<GeneralGoalCondition>[] generalisedConds = rule.getGeneralisedConditions();
                addSubGoals(generalisedConds[0], rule, subGoals);
                addSubGoals(generalisedConds[1], rule, subGoals);
            }
        }
    }

    /**
     * Registers a sub-goal for every goal condition not seen before.
     *
     * @param goalConds
     *            The candidate goal conditions of the rule.
     * @param rule
     *            The rule the goal conditions came from; it becomes the parent
     *            key of the created sub-goals.
     * @param subGoals
     *            The goal conditions already registered; updated in place.
     */
    private void addSubGoals(Collection<? extends GoalCondition> goalConds, RelationalRule rule,
            SortedSet<GoalCondition> subGoals) {
        for (GoalCondition gc : goalConds) {
            // add() returns false when the condition is already registered.
            if (subGoals.add(gc)) {
                ModularSubGoal subGoal = new ModularSubGoal(gc, rule);
                policyRules_.add(subGoal);
                childrenPolicies_.put(rule, subGoal);
            }
        }
    }

    /**
     * Evaluates the items of this policy in order, recursing into the policies
     * attached to modular sub-goals. Fired actions are collected in
     * {@code policyActions}; nothing is returned.
     * <p>
     * Before evaluation the generating distribution is given the chance to
     * cover the state; any rules it returns are shuffled, appended to this
     * policy and parameterised. If the goal of this policy is currently
     * achieved, no rules are evaluated.
     * <p>
     * NOTE(review): {@code actionsFound} is a by-value counter. Actions found
     * inside a recursive call are not reported back to the caller, so the
     * caller's loop condition only reflects actions fired by its own rules.
     * Confirm whether this is intended.
     *
     * @param observations
     *            The relational observations for the state.
     * @param policyActions
     *            The collection to add the fired actions to.
     * @param activatedActions
     *            The actions the RLGG rules return.
     * @param actionsFound
     *            The number of actions found so far.
     * @param actionsRequired
     *            The number of actions after which evaluation stops.
     * @throws Exception
     *             Should something go awry...
     */
    private void evaluateInternalPolicy(RRLObservations observations, PolicyActions policyActions,
            MultiMap<String, String[]> activatedActions, int actionsFound, int actionsRequired) throws Exception {
        // Forget what fired on the previous step.
        firedLastStep_.clear();

        // Covering may or may not scan the state; that decision is made by
        // the distribution.
        List<RelationalRule> coveredRules = ceDistribution_.coverState(this, observations, activatedActions,
                episodeGoalReplacements_);
        boolean hasCoveredRules = coveredRules != null && !coveredRules.isEmpty();
        if (hasCoveredRules) {
            Collections.shuffle(coveredRules, RRLExperiment.random_);
            // Append the newly covered rules to this policy.
            for (RelationalRule coveredRule : coveredRules) {
                policyRules_.add(coveredRule);
                policySize_++;
            }
            // Parameterise using the (possibly transformed) goal replacements.
            parameterArgs(transformGoalReplacements(observations.getGoalReplacements()));
        }

        // An achieved goal means this policy has nothing to do this step.
        if (goalAchievedCurrently_)
            return;

        Rete state = observations.getState();
        Iterator<PolicyItem> itemIter = policyRules_.iterator();
        while (itemIter.hasNext() && actionsFound < actionsRequired) {
            PolicyItem item = itemIter.next();
            if (item instanceof RelationalRule) {
                RelationalRule rule = (RelationalRule) item;
                Collection<FiredAction> fired = evaluateRule(rule, state,
                        observations.getValidActions(rule.getActionPredicate()), null, true);
                policyActions.addFiredRule(fired, this);
                actionsFound += fired.size();

                // Sub-goals spawned by this rule are flagged achieved exactly
                // when the rule fired at least one action.
                if (childrenPolicies_.containsKey(rule)) {
                    boolean ruleFired = !fired.isEmpty();
                    for (ModularSubGoal subGoal : childrenPolicies_.get(rule))
                        subGoal.setGoalAchieved(ruleFired);
                }
            } else if (item instanceof ModularSubGoal) {
                // Descend into the sub-goal's policy, if one is attached.
                ModularPolicy subPolicy = ((ModularSubGoal) item).getModularPolicy();
                if (subPolicy != null)
                    subPolicy.evaluateInternalPolicy(observations, policyActions, activatedActions,
                            actionsFound, actionsRequired);
            }
        }
    }

    /**
     * Recursively writes this policy into the buffer, one rule per line.
     * Rules of nested policies are prefixed with {@code depth - 1} spaces
     * followed by the goal condition of the owning distribution between
     * {@code |} characters; rules at depth 0 have no prefix.
     *
     * @param buffer
     *            The buffer to print to.
     * @param depth
     *            The nesting depth of this policy (0 for the top-level policy).
     * @param onlyTriggered
     *            If only rules present in the triggered set are printed.
     * @return The current contents of the buffer.
     */
    @Recursive
    private String recursePolicyToString(StringBuffer buffer, int depth, boolean onlyTriggered) {
        for (PolicyItem item : policyRules_) {
            if (item instanceof RelationalRule) {
                // Skip untriggered rules when only triggered ones are wanted.
                if (onlyTriggered && !triggeredRules_.contains(item))
                    continue;

                // Indentation, then the goal marker for nested policies.
                for (int i = 1; i < depth; i++)
                    buffer.append(" ");
                if (depth > 0)
                    buffer.append(" |").append(ceDistribution_.getGoalCondition()).append("|");

                buffer.append(((RelationalRule) item).toNiceString(moduleParamReplacements_)).append("\n");
            } else if (item instanceof ModularSubGoal) {
                ModularPolicy subPolicy = ((ModularSubGoal) item).getModularPolicy();
                if (subPolicy != null)
                    subPolicy.recursePolicyToString(buffer, depth + 1, onlyTriggered);
            }
        }
        return buffer.toString();
    }

    /**
     * Re-expresses the given goal replacements in terms of this policy's
     * module parameters. For every (rule parameter, goal parameter) pair in
     * the module parameter map, the key that the original replacements
     * associate with the goal parameter is mapped to the rule parameter
     * instead.
     * <p>
     * When no module parameter map is set (or it is empty), the original map
     * instance is returned unchanged.
     * <p>
     * NOTE(review): if a goal parameter is absent from the original
     * replacements, the looked-up key is presumably null and is inserted as
     * such; verify against callers.
     *
     * @param originalGoalReplacements
     *            The original goal replacements; not modified.
     * @return The original map, or a new map holding one entry per module
     *         parameter (fewer if looked-up keys coincide).
     */
    private BidiMap transformGoalReplacements(BidiMap originalGoalReplacements) {
        // Nothing to remap without module parameters.
        if (moduleParamReplacements_ == null || moduleParamReplacements_.isEmpty())
            return originalGoalReplacements;

        BidiMap remapped = new DualHashBidiMap();
        for (Map.Entry<RelationalArgument, RelationalArgument> entry : moduleParamReplacements_.entrySet()) {
            RelationalArgument ruleParam = entry.getKey();
            RelationalArgument goalParam = entry.getValue();
            // Reverse lookup: which key currently points at the goal param?
            remapped.put(originalGoalReplacements.getKey(goalParam), ruleParam);
        }
        return remapped;
    }

    /**
     * Records that a rule has fired: the rule is added both to the set of all
     * triggered rules and to the set of rules that fired during the last step.
     *
     * @param rule
     *            The rule to be added.
     * @return Always true; this implementation never rejects a rule.
     */
    public boolean addTriggeredRule(RelationalRule rule) {
        triggeredRules_.add(rule);
        // firedLastStep_ is transient and otherwise only created by
        // startEpisode(), so it is null before the first episode and after
        // deserialisation. Create it on demand instead of failing.
        if (firedLastStep_ == null)
            firedLastStep_ = new HashSet<RelationalRule>();
        firedLastStep_.add(rule);
        return true;
    }

    /**
     * Detaches the modular policy from every child sub-goal by setting it to
     * null. The sub-goals themselves stay registered in this policy, so
     * policies can be attached to them again later.
     */
    public void clearChildren() {
        for (ModularSubGoal subGoal : childrenPolicies_.values())
            subGoal.setModularPolicy(null);
    }

    /**
     * Closes the current episode for this policy and, recursively, for the
     * policies attached to its child sub-goals.
     * <p>
     * If this policy's episode never started, nothing is recorded and the
     * children are not visited. Otherwise the accumulated episode reward is
     * stored; a sub-goal policy whose goal was not achieved during the episode
     * stores {@code MINIMUM_REWARD} in both reward slots instead. Once enough
     * episode rewards have been collected (per the distribution's policy
     * repeats), the sample is handed to the distribution.
     * <p>
     * NOTE(review): for a main-goal policy the stored reward may be null if no
     * step reward was noted during the episode; confirm the distribution
     * copes with that.
     *
     * @return True if this policy, or any descendant policy, recorded its
     *         sample and so the modular policy needs to be regenerated.
     */
    @Recursive
    public boolean endEpisode() {
        // Nothing to record for an episode that never began.
        if (!episodeStarted_)
            return false;

        boolean mainGoal = ceDistribution_.getGoalCondition().isMainGoal();

        // A sub-goal that was never achieved gets the minimum reward.
        if (!mainGoal && !goalAchievedEpisode_) {
            if (episodeReward_ == null)
                episodeReward_ = new double[2];
            episodeReward_[0] = MINIMUM_REWARD;
            episodeReward_[1] = MINIMUM_REWARD;
        }

        // Store the episode reward and reset the accumulator.
        if (mainGoal || episodeReward_ != null) {
            policyRewards_.add(episodeReward_);
            episodeReward_ = null;
        }

        // Every attached child policy must end its episode too.
        boolean regenerate = false;
        for (ModularSubGoal child : childrenPolicies_.values()) {
            ModularPolicy childPolicy = child.getModularPolicy();
            if (childPolicy == null)
                continue;
            if (childPolicy.endEpisode())
                regenerate = true;
        }

        // Enough repeats collected: hand the sample to the distribution.
        if (policyRewards_.size() >= ceDistribution_.getPolicyRepeats()) {
            ceDistribution_.recordSample(this, policyRewards_);
            regenerate = true;
        }

        return regenerate;
    }

    /**
     * Two modular policies are equal when the superclass considers them equal,
     * they are of the same runtime class, and both their module parameter maps
     * and their unique IDs match (null matching only null).
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (!super.equals(obj) || getClass() != obj.getClass())
            return false;

        ModularPolicy that = (ModularPolicy) obj;

        boolean sameParams = (moduleParamReplacements_ == null) ? that.moduleParamReplacements_ == null
                : moduleParamReplacements_.equals(that.moduleParamReplacements_);
        if (!sameParams)
            return false;

        return (uniqueID_ == null) ? that.uniqueID_ == null : uniqueID_.equals(that.uniqueID_);
    }

    /**
     * Evaluates this policy against the observed state and marks the policy as
     * evaluated. The RLGG rules of the policy generator are evaluated first to
     * gather the activated actions per action predicate; the policy's own
     * items (and nested policies) are evaluated afterwards.
     * <p>
     * Any exception raised during evaluation is printed to standard error and
     * the actions collected up to that point are returned.
     *
     * @param observations
     *            The relational observations for the state.
     * @param actionsReturned
     *            The number of actions to find; -1 or lower means no limit.
     * @return The actions fired by the policy.
     */
    @Override
    public PolicyActions evaluatePolicy(RRLObservations observations, int actionsReturned) {
        isEvaluated_ = true;

        // A negative request means "as many as possible".
        int actionLimit = actionsReturned;
        if (actionsReturned <= -1)
            actionLimit = Integer.MAX_VALUE;

        PolicyActions policyActions = new PolicyActions();
        MultiMap<String, String[]> activatedActions = MultiMap
                .createSortedSetMultiMap(ArgumentComparator.getInstance());
        Rete state = observations.getState();

        try {
            // RLGG rules first: their actions are what covering compares
            // against.
            for (RelationalRule rlgg : ceDistribution_.getPolicyGenerator().getRLGGRules().values()) {
                SortedSet<String[]> rlggActions = new TreeSet<String[]>(ArgumentComparator.getInstance());
                evaluateRule(rlgg, state, observations.getValidActions(rlgg.getActionPredicate()), rlggActions,
                        false);
                activatedActions.putCollection(rlgg.getActionPredicate(), rlggActions);
            }

            // Then the policy proper, starting with zero actions found.
            evaluateInternalPolicy(observations, policyActions, activatedActions, 0, actionLimit);
        } catch (Exception e) {
            e.printStackTrace();
        }

        return policyActions;
    }

    /**
     * Collects this policy and, recursively, the policies attached to its
     * child sub-goals. A policy excluded by one of the filters is not added
     * and its children are not visited.
     *
     * @param undertestedOnly
     *            If policies that have already gathered enough rewards (see
     *            {@link #shouldRegenerate()}) are excluded.
     * @param firedOnly
     *            If policies with no triggered rules are excluded.
     * @param recursiveCollection
     *            The collection to add to, or null to have a new set created.
     * @return The collection the policies were added to; empty when this
     *         policy is excluded and nothing was passed in.
     */
    public Collection<ModularPolicy> getAllPolicies(boolean undertestedOnly, boolean firedOnly,
            Collection<ModularPolicy> recursiveCollection) {
        Collection<ModularPolicy> collected = (recursiveCollection != null) ? recursiveCollection
                : new HashSet<ModularPolicy>();

        // Filtered-out policies contribute neither themselves nor children.
        if ((undertestedOnly && shouldRegenerate()) || (firedOnly && triggeredRules_.isEmpty()))
            return collected;

        collected.add(this);

        for (ModularSubGoal child : childrenPolicies_.values()) {
            ModularPolicy childPolicy = child.getModularPolicy();
            if (childPolicy == null)
                continue;
            childPolicy.getAllPolicies(undertestedOnly, firedOnly, collected);
        }

        return collected;
    }

    /**
     * Gets the rules of this policy that have been triggered (see
     * {@link #addTriggeredRule(RelationalRule)}).
     *
     * @return The internal set of triggered rules itself, not a copy, so
     *         changes made by the caller affect this policy.
     */
    public Set<RelationalRule> getFiringRules() {
        return triggeredRules_;
    }

    /**
     * Gets the distribution that created this modular policy.
     *
     * @return The distribution passed in at construction.
     */
    public LocalCrossEntropyDistribution getLocalCEDistribution() {
        return ceDistribution_;
    }

    /**
     * Gets the modular replacement map.
     *
     * @return The replacement map for this policy, as last given to
     *         {@link #setModularParameters(Map)}; null if none has been set.
     */
    public Map<RelationalArgument, RelationalArgument> getModularReplacementMap() {
        return moduleParamReplacements_;
    }

    /**
     * Combines the superclass hash with the hashes of the module parameter
     * map and the unique ID (each contributing 0 when null), in line with
     * {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int hash = super.hashCode();

        int paramsHash = 0;
        if (moduleParamReplacements_ != null)
            paramsHash = moduleParamReplacements_.hashCode();
        hash = 31 * hash + paramsHash;

        int idHash = 0;
        if (uniqueID_ != null)
            idHash = uniqueID_.hashCode();
        hash = 31 * hash + idHash;

        return hash;
    }

    /**
     * Checks whether this policy is still fresh, i.e. whether
     * {@code evaluatePolicy} has not been called on it yet.
     *
     * @return True if this policy has never been evaluated.
     */
    public boolean isFresh() {
        return !isEvaluated_;
    }

    /**
     * If the internal goal of this policy is currently achieved. The flag is
     * set by {@link #setGoalAchieved()} and cleared by
     * {@link #setGoalUnachieved(boolean)} and {@link #startEpisode()}.
     *
     * @return True if the goal is currently achieved.
     */
    public boolean isGoalCurrentlyAchieved() {
        return goalAchievedCurrently_;
    }

    /**
     * Notes the reward for the last step, bottom-up through the attached child
     * policies. A policy notes a reward when it is the main-goal policy, when
     * one of its own rules fired last step, or when any descendant noted a
     * reward.
     * <p>
     * The main-goal policy accumulates the environment reward. A sub-goal
     * policy instead accumulates {@code SUB_GOAL_REWARD} per step, and only
     * while its episode has started and its goal has not yet been achieved
     * this episode.
     *
     * @param reward
     *            The environment reward; indices 0 and 1 are read.
     * @return True if this policy qualified for noting a reward this step.
     */
    @Recursive
    public boolean noteStepReward(double[] reward) {
        boolean noteReward = ceDistribution_.getGoalCondition().isMainGoal() || !firedLastStep_.isEmpty();

        // Children first; every attached child is always visited.
        for (ModularSubGoal child : childrenPolicies_.values()) {
            ModularPolicy childPolicy = child.getModularPolicy();
            if (childPolicy != null && childPolicy.noteStepReward(reward))
                noteReward = true;
        }

        if (!noteReward)
            return false;

        if (ceDistribution_.getGoalCondition().isMainGoal()) {
            // Main policy: accumulate the environment reward.
            if (episodeReward_ == null)
                episodeReward_ = new double[2];
            episodeReward_[0] += reward[0];
            episodeReward_[1] += reward[1];
        } else if (episodeStarted_ && !goalAchievedEpisode_) {
            // Unachieved sub-goal: accumulate the internal per-step reward.
            if (episodeReward_ == null)
                episodeReward_ = new double[2];
            episodeReward_[0] += SUB_GOAL_REWARD;
            episodeReward_[1] += SUB_GOAL_REWARD;
        }

        return true;
    }

    /**
     * Marks the goal as currently achieved. If this policy's episode has
     * started, the goal is also marked as achieved for the episode.
     */
    public void setGoalAchieved() {
        goalAchievedCurrently_ = true;
        // Only counts towards the episode once the episode is under way.
        goalAchievedEpisode_ |= episodeStarted_;
    }

    /**
     * Marks the goal as currently unachieved.
     *
     * @param startEpisode
     *            If the modular sub-goal episode should be started (because
     *            the goal is currently unmet). Passing false never stops an
     *            episode that has already started.
     */
    public void setGoalUnachieved(boolean startEpisode) {
        goalAchievedCurrently_ = false;
        if (startEpisode)
            episodeStarted_ = true;
    }

    /**
     * Sets the modular replacements for this policy. The map is stored by
     * reference (no copy is made) and takes part in {@code equals} and
     * {@code hashCode}.
     *
     * @param paramReplacementMap
     *            The modular replacements to set.
     */
    public void setModularParameters(Map<RelationalArgument, RelationalArgument> paramReplacementMap) {
        moduleParamReplacements_ = paramReplacementMap;
    }

    /**
     * Sets the goal parameters for this episode. The transformed replacements
     * (see {@code transformGoalReplacements}) are remembered as the episode's
     * goal replacements, and every item of this policy is given the goal
     * arguments.
     * <p>
     * NOTE(review): the items receive the original {@code goalArgs}, not the
     * transformed map stored for the episode; confirm this is intended.
     *
     * @param goalArgs
     *            The goal replacements for the episode.
     */
    public void setParameters(BidiMap goalArgs) {
        episodeGoalReplacements_ = transformGoalReplacements(goalArgs);

        Iterator<PolicyItem> itemIter = policyRules_.iterator();
        while (itemIter.hasNext())
            itemIter.next().setParameters(goalArgs);
    }

    /**
     * If this policy should be replaced with another policy.
     *
     * @return True if the number of episode rewards recorded for this policy
     *         has reached the distribution's policy repeats.
     */
    public boolean shouldRegenerate() {
        return policyRewards_.size() >= ceDistribution_.getPolicyRepeats();
    }

    /**
     * Gets the recorded size of this policy.
     *
     * @return The current value of the policy size counter.
     */
    public int size() {
        return policySize_;
    }

    /**
     * Resets the per-episode state of this policy and, recursively, of the
     * policies attached to its child sub-goals, and notifies the distribution
     * that an episode is starting.
     * <p>
     * A main-goal policy's episode starts immediately; a sub-goal policy's
     * episode only starts later via {@link #setGoalUnachieved(boolean)}.
     */
    @Recursive
    public void startEpisode() {
        episodeReward_ = null;
        // Only the main goal is "in episode" from the very first step.
        episodeStarted_ = ceDistribution_.getGoalCondition().isMainGoal();
        goalAchievedEpisode_ = false;
        goalAchievedCurrently_ = false;
        firedLastStep_ = new HashSet<RelationalRule>();

        ceDistribution_.startEpisode();

        for (ModularSubGoal child : childrenPolicies_.values()) {
            ModularPolicy childPolicy = child.getModularPolicy();
            if (childPolicy == null)
                continue;
            childPolicy.startEpisode();
        }
    }

    /**
     * Formats the whole policy, nested policies included, one rule per line.
     *
     * @return The policy in string format, or a placeholder if this policy
     *         holds no items.
     */
    public String toNiceString() {
        return policyRules_.isEmpty() ? "<EMPTY POLICY>"
                : recursePolicyToString(new StringBuffer("Policy:\n"), 0, false);
    }

    /**
     * Formats the policy like {@link #toNiceString()}, but lists only the
     * rules that were triggered.
     *
     * @return The policy in string format, minus the unused rules, or a
     *         placeholder if this policy holds no items.
     */
    public String toOnlyUsedString() {
        return policyRules_.isEmpty() ? "<EMPTY POLICY>"
                : recursePolicyToString(new StringBuffer("Policy:\n"), 0, true);
    }

    /**
     * Returns the same text as {@link #toNiceString()}.
     */
    @Override
    public String toString() {
        return toNiceString();
    }

    /**
     * Checks if this modular policy is essentially equivalent to another
     * policy: both contain equal rules in the same order. Items that are not
     * {@link RelationalRule}s (such as modular sub-goals) are ignored on both
     * sides, so two modular policies with the same rules compare as equivalent
     * regardless of their sub-goal items.
     *
     * @param policy
     *            The policy to check for equivalencies; may be null.
     * @return True if the policy is equivalent (contains the same rules in the
     *         same order); false otherwise, including for a null policy.
     */
    public boolean equivalentTo(RelationalPolicy policy) {
        if (this == policy)
            return true;
        if (policy == null)
            return false;

        Iterator<PolicyItem> policyIter = policy.getRules().iterator();
        for (PolicyItem pItem : policyRules_) {
            if (!(pItem instanceof RelationalRule))
                continue;

            // Pair this rule with the other policy's next rule.
            PolicyItem otherRule = nextRule(policyIter);
            if (otherRule == null || !otherRule.equals(pItem))
                return false;
        }

        // Equivalent only if the other policy has no rules left over.
        return nextRule(policyIter) == null;
    }

    /**
     * Advances the iterator to the next rule, skipping any non-rule items.
     *
     * @param iter
     *            The iterator over policy items.
     * @return The next {@link RelationalRule} item, or null if none remain.
     */
    private static PolicyItem nextRule(Iterator<PolicyItem> iter) {
        while (iter.hasNext()) {
            PolicyItem item = iter.next();
            if (item instanceof RelationalRule)
                return item;
        }
        return null;
    }

    /**
     * Discards all episode rewards recorded for this policy so far.
     */
    public void clearPolicyRewards() {
        policyRewards_.clear();
    }
}