// Java tutorial
/* * This file is part of the CERRLA algorithm * * CERRLA is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 3 of the License, or * (at your option) any later version. * * CERRLA is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with CERRLA. If not, see <http://www.gnu.org/licenses/>. */ /* * src/cerrla/ElitesData.java * Copyright (C) 2012 Samuel Sarjant */ package cerrla; import relationalFramework.RelationalRule; import util.MultiMap; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import java.util.SortedSet; import java.util.TreeSet; import org.apache.commons.math.stat.descriptive.moment.Mean; /** * A class for containing all of the data gathered about the elite solutions, * such as counts, positions, etc. * * @author Sam Sarjant */ public class ElitesData { /** The average positions for each slot within the elite policies. */ private Map<Slot, SlotData> slotData_; /** The elite values. */ private ArrayList<Double> elitesValues_; public ElitesData(int numElites) { slotData_ = new HashMap<Slot, SlotData>(); elitesValues_ = new ArrayList<Double>(numElites); } /** * Adds the weighted count of a slot in and increments the raw count. * * @param slot * The slot for which the data is being recorded. * @param weight * The weight to add to the slot [0-1]. */ public void addSlotCount(Slot slot, double weight) { SlotData sd = getSlotData(slot); sd.addCount(weight); sd.incrementRawCount(); } /** * Adds the weighted count of a rule in. * * @param ruleSlot * * @param rule * The rule for which he data is being recorded. * @param weight * The weight to add to the rule [0-1]. 
*/ public void addRuleCount(Slot ruleSlot, RelationalRule rule, int weight) { Map<RelationalRule, Integer> ruleCounts = getSlotData(ruleSlot).getRuleCounts(); Integer oldWeight = ruleCounts.get(rule); if (oldWeight == null) oldWeight = 0; ruleCounts.put(rule, oldWeight + weight); } /** * Adds a relative ordering value to the slot ordering. * * @param slot * The slot for which the data is being recorded. * @param relValue * The relative ordering value of the slot [0-1]. */ public void addSlotOrdering(Slot slot, double relValue) { SlotData sd = getSlotData(slot); sd.addOrdering(relValue); } /** * Sets the slot numeracy mean of the slot to a value. * * @param slot * The slot for which the data is being recorded. * @param mean * The mean of the slot usage. */ public void setUsageStats(Slot slot, double mean) { SlotData slotData = getSlotData(slot); slotData.setMean(mean); } /** * Gets a slot's average position. * * @param slot The slot. * @return The average position of the slot (default 0.5). */ public Double getSlotPosition(Slot slot) { if (slotData_.containsKey(slot)) return slotData_.get(slot).getAverageOrdering(); return null; } /** * Gets the rule counts. * * @param slot * The slot to get the rule counts for. * @return A mapping of rules to their weighted counts. */ public Map<RelationalRule, Integer> getSlotRuleCounts(Slot slot) { SlotData slotData = slotData_.get(slot); if (slotData == null) return null; return slotData.getRuleCounts(); } /** * Gets the count for a slot. * * @param slot * The slot being searched for. * @return The counts it has, or 0. */ public double getSlotCount(Slot slot) { if (slotData_.containsKey(slot)) return slotData_.get(slot).getCount(); return 0; } /** * Gets the slot numeracy value. * * @param slot * The slot to get the data for. * @return The numeracy value, or 0 if not recorded. 
*/ public double getSlotNumeracyMean(Slot slot) { if (slotData_.containsKey(slot)) return slotData_.get(slot).getNumeracy(); return 0; } /** * Gets or initialises the slot data. * * @param slot * The slot to get the data for. * @return The data for the slot, either new or existing. */ private SlotData getSlotData(Slot slot) { if (!slotData_.containsKey(slot)) slotData_.put(slot, new SlotData(slot.size())); return slotData_.get(slot); } /** * Get the best elite value. * * @return The value of the best elite. */ public Double getMaxEliteValue() { if (elitesValues_.isEmpty()) return null; return elitesValues_.get(0); } /** * Get the average value of the elite samples. * * @return The average value of the elites. */ public Double getMeanEliteValue() { if (elitesValues_.isEmpty()) return null; double[] values = new double[elitesValues_.size()]; int i = 0; for (Double val : elitesValues_) values[i++] = val; Mean m = new Mean(); return m.evaluate(values); } @Override public String toString() { StringBuffer buffer = new StringBuffer(); buffer.append("Slot counts: \n"); MultiMap<Integer, Slot> orderedMap = MultiMap.createSortedSetMultiMap(); SortedSet<Integer> keys = new TreeSet<Integer>(); for (Slot slot : slotData_.keySet()) { int count = (int) slotData_.get(slot).getCount(); orderedMap.put(count, slot); keys.add(count); } for (Integer count : keys) { for (Slot slot : orderedMap.get(count)) { SlotData slotData = slotData_.get(slot); if (slot.getSlotSplitFacts().isEmpty()) buffer.append("\tSlot " + slot.getAction() + ":\n" + slotData + "\n"); else buffer.append("\tSlot " + slot.getSlotSplitFacts() + " => " + slot.getAction() + ":\n" + slotData_.get(slot) + "\n"); } } return buffer.toString(); } /** * A class for holding slot data. * * @author Sam Sarjant */ private class SlotData { /** The number of times this slot is present in the elite policies. */ private double count_; /** The raw count of the number of slots in the elite policies. 
*/ private int rawCount_; /** The average position of the slot in the policies. */ private Double position_ = null; /** The average number of slots per policy. */ private double mean_; /** The rule counts for this slot. */ private Map<RelationalRule, Integer> ruleCounts_; public SlotData(int numRules) { ruleCounts_ = new HashMap<RelationalRule, Integer>(numRules); } public void addCount(double weight) { count_ += weight; } public Map<RelationalRule, Integer> getRuleCounts() { return ruleCounts_; } public void setMean(double mean) { mean_ = mean; } public double getCount() { return count_; } public void addOrdering(double relValue) { if (position_ == null) position_ = 0d; position_ += relValue; } public void incrementRawCount() { rawCount_++; } public double getAverageOrdering() { if (position_ == null) return 0.5d; return position_ / rawCount_; } public double getNumeracy() { return mean_; } @Override public String toString() { StringBuffer buffer = new StringBuffer("\tCount: " + count_); buffer.append("\tRaw Count: " + rawCount_); buffer.append("\tPosition: " + getAverageOrdering()); buffer.append("\tNumeracy: " + mean_); // Rule counts for the slot. buffer.append("\n\tRule counts: \n"); MultiMap<Integer, RelationalRule> orderedMap = MultiMap.createSortedSetMultiMap(); SortedSet<Integer> keys = new TreeSet<Integer>(); for (RelationalRule rule : ruleCounts_.keySet()) { orderedMap.put(ruleCounts_.get(rule), rule); keys.add(ruleCounts_.get(rule)); } for (Integer count : keys) { for (RelationalRule rule : orderedMap.get(count)) buffer.append("\t\t" + rule + ": " + count + "\n"); } return buffer.toString(); } } public void noteSampleValue(double value) { elitesValues_.add(value); } /** * If all elite values are the same value, then set all probabilities to * one. */ public void setEqualValues() { for (SlotData sd : slotData_.values()) { if (sd.mean_ > 0) { sd.mean_ = 1; sd.count_ = 1; } } } }