Java tutorial
/* LICENSE
 * Copyright (c) 2013-2016, Jesse Hostetler (jessehostetler@gmail.com)
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this list of
 *    conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice, this list of
 *    conditions and the following disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
 * WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package edu.oregonstate.eecs.mcplan.domains.blackjack;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;

import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ArffSaver;
import weka.core.converters.Saver;

import edu.oregonstate.eecs.mcplan.ActionGenerator;
import edu.oregonstate.eecs.mcplan.FactoredRepresentation;
import edu.oregonstate.eecs.mcplan.JointAction;
import edu.oregonstate.eecs.mcplan.Pair;
import edu.oregonstate.eecs.mcplan.Policy;
import edu.oregonstate.eecs.mcplan.RandomPolicy;
import edu.oregonstate.eecs.mcplan.Representer;
import edu.oregonstate.eecs.mcplan.abstraction.ClusterAbstraction;
import edu.oregonstate.eecs.mcplan.abstraction.PairwiseSimilarityRepresenter;
import edu.oregonstate.eecs.mcplan.domains.cards.Deck;
import edu.oregonstate.eecs.mcplan.domains.cards.InfiniteDeck;
import edu.oregonstate.eecs.mcplan.ml.InformationTheoreticMetricLearner;
import edu.oregonstate.eecs.mcplan.ml.MetricConstrainedKMeans;
import edu.oregonstate.eecs.mcplan.ml.VoronoiClassifier;
import edu.oregonstate.eecs.mcplan.search.BackupRule;
import edu.oregonstate.eecs.mcplan.search.BackupRules;
import edu.oregonstate.eecs.mcplan.search.DefaultMctsVisitor;
import edu.oregonstate.eecs.mcplan.search.GameTree;
import edu.oregonstate.eecs.mcplan.search.GameTreeFactory;
import edu.oregonstate.eecs.mcplan.search.MctsVisitor;
import edu.oregonstate.eecs.mcplan.search.SearchPolicy;
import edu.oregonstate.eecs.mcplan.search.UctSearch;
import edu.oregonstate.eecs.mcplan.sim.Episode;
import edu.oregonstate.eecs.mcplan.sim.EpisodeListener;
import edu.oregonstate.eecs.mcplan.util.Csv;
import edu.oregonstate.eecs.mcplan.util.Fn;
import edu.oregonstate.eecs.mcplan.util.MeanVarianceAccumulator;

import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntProcedure;

import hr.irb.fastRandomForest.FastRandomForest;

/**
 * @author jhostetler
 */
public class AbstractionDiscovery {

    /**
     * Trivial representer that maps each BlackjackState to its own token,
     * i.e. no abstraction (the "flat" representation).
     */
    public static class IdentityRepresenter implements Representer<BlackjackState, BlackjackStateToken> {
        @Override
        public Representer<BlackjackState, BlackjackStateToken> create() {
            return new IdentityRepresenter();
        }

        @Override
        public BlackjackStateToken encode(final BlackjackState s) {
            return new BlackjackStateToken(s);
        }

        @Override
        public String toString() {
            return "flat";
        }
    }

    /**
     * Episode listener that records, at every decision point, the factored
     * representation of the current state and the action chosen by the policy.
     * These (state, action) pairs are the labeled training examples.
     */
    public static class SolvedStateAccumulator<X extends FactoredRepresentation<BlackjackState>>
            implements EpisodeListener<BlackjackState, BlackjackAction> {
        public ArrayList<X> states_ = new ArrayList<X>();
        public ArrayList<RealVector> Phi_ = new ArrayList<RealVector>();
        public ArrayList<BlackjackAction> actions_ = new ArrayList<BlackjackAction>();

        private final Representer<BlackjackState, X> repr_;
        private X x_ = null;

        public SolvedStateAccumulator(final Representer<BlackjackState, X> repr) {
            repr_ = repr;
        }

        @Override
        public <P extends Policy<BlackjackState, JointAction<BlackjackAction>>> void startState(
                final BlackjackState s, final P pi) {
            x_ = repr_.encode(s);
        }

        @Override
        public void preGetAction() { }

        @Override
        public void postGetAction(final JointAction<BlackjackAction> a) {
            states_.add(x_);
            Phi_.add(new ArrayRealVector(x_.phi()));
            actions_.add(a.get(0).create());
        }

        @Override
        public void onActionsTaken(final BlackjackState sprime) {
            x_ = repr_.encode(sprime);
        }

        @Override
        public void endState(final BlackjackState s) { }
    }

    /**
     * MCTS visitor that collects feature vectors for states expanded inside
     * the search tree. These states carry no action label and serve as
     * unlabeled examples.
     */
    public static class UnlabeledStateAccumulator<X extends FactoredRepresentation<BlackjackState>>
            extends DefaultMctsVisitor<BlackjackState, X, BlackjackAction> {
        public ArrayList<RealVector> Phi_ = new ArrayList<RealVector>();

        private final Representer<BlackjackState, X> repr_;
        private RealVector s_ = null;

        public UnlabeledStateAccumulator(final Representer<BlackjackState, X> repr) {
            repr_ = repr;
        }

        @Override
        public void treeAction(final JointAction<BlackjackAction> a, final BlackjackState sprime,
                final int[] next_turn) {
            // This has the effect of adding states only if an action was
            // chosen within the tree for that state.
            // It will *not* add the root state, since that is labeled and
            // will be added by SolvedStateAccumulator.
            if (s_ != null) {
                Phi_.add(s_);
            }
            s_ = new ArrayRealVector(repr_.encode(sprime).phi());
        }
    }

    // -----------------------------------------------------------------------

    /**
     * Builds must-link constraints (pairs of labeled states with the same
     * chosen action) and cannot-link constraints (pairs with different chosen
     * actions), optionally learns a Mahalanobis metric with ITML from those
     * constraints, and then runs metric-constrained k-means over the labeled
     * and unlabeled feature vectors.
     */
    private static MetricConstrainedKMeans makeClustering(final RandomGenerator rng, final RealMatrix A0,
            final ArrayList<RealVector> XL, final ArrayList<BlackjackAction> y, final ArrayList<RealVector> XU,
            final boolean with_metric_learning) {
        final int K = 4;
        final int d = XL.get(0).getDimension(); //52 + 52;
        final double u = 4.0; //1.0;
        final double ell = 16.0; //2.0;
        final double gamma = 1.0;

        final ArrayList<RealVector> X = new ArrayList<RealVector>();
        X.addAll(XL);
        X.addAll(XU);

        // M[i] holds the must-link partners of labeled state i; C[i] holds its cannot-link partners.
        final TIntObjectMap<Pair<int[], double[]>> M = new TIntObjectHashMap<Pair<int[], double[]>>();
        final TIntObjectMap<Pair<int[], double[]>> C = new TIntObjectHashMap<Pair<int[], double[]>>();
        for (int i = 0; i < XL.size(); ++i) {
            final TIntList m = new TIntArrayList();
            final TIntList c = new TIntArrayList();
            for (int j = i + 1; j < XL.size(); ++j) {
                if (y.get(i).equals(y.get(j))) {
                    m.add(j);
                } else {
                    c.add(j);
                }
            }
            M.put(i, Pair.makePair(m.toArray(), Fn.repeat(1.0, m.size())));
            C.put(i, Pair.makePair(c.toArray(), Fn.repeat(1.0, c.size())));
        }

        // Flatten the constraint maps into lists of index pairs for the metric learner.
        final ArrayList<int[]> S = new ArrayList<int[]>();
        M.forEachKey(new TIntProcedure() {
            @Override
            public boolean execute(final int i) {
                final Pair<int[], double[]> p = M.get(i);
                if (p != null) {
                    for (final int j : p.first) {
                        S.add(new int[] { i, j });
                    }
                }
                return true;
            }
        });
        final ArrayList<int[]> D = new ArrayList<int[]>();
        C.forEachKey(new TIntProcedure() {
            @Override
            public boolean execute(final int i) {
                final Pair<int[], double[]> p = C.get(i);
                if (p != null) {
                    for (final int j : p.first) {
                        D.add(new int[] { i, j });
                    }
                }
                return true;
            }
        });

        final RealMatrix A;
        if (with_metric_learning) {
            final InformationTheoreticMetricLearner itml = new InformationTheoreticMetricLearner(
                    X, S, D, u, ell, A0, gamma, rng);
            itml.run();
            A = itml.A();
        } else {
            A = MatrixUtils.createRealIdentityMatrix(d);
        }

        final MetricConstrainedKMeans kmeans = new MetricConstrainedKMeans(K, d, X, A, M, C, rng);
        kmeans.run();
        return kmeans;
    }

    private static ArrayList<RealVector> enumerateStates(final BlackjackParameters params) {
        // FIXME: The proliferation of different state representations is
        // alarming. I'm using them here for convenience, but they are too
        // subtly different to be intermixed like this safely!
        final ArrayList<RealVector> states = new ArrayList<RealVector>();
        final BlackjackStateSpace ss = new BlackjackStateSpace(params);
        for (final BlackjackMdpState s : Fn.in(ss.generator())) {
            if (!s.player_passed) {
                states.add(new ArrayRealVector(HandValueAbstraction.makePhi(
                        params, s.player_value, s.player_high_aces, s.dealer_value)));
            }
        }
        return states;
    }

    /**
     * Writes the learned metric, the cluster centroids, and the per-point
     * cluster assignments (tagged with the MDP-optimal action) to CSV files
     * under 'root'.
     */
    private static void writeClustering(final MetricConstrainedKMeans kmeans, final File root, final int iter,
            final BlackjackParameters params, final String[][] hard_actions, final String[][] soft_actions)
            throws FileNotFoundException {
        Csv.write(new PrintStream(new File(root, "M" + iter + ".csv")), kmeans.metric);
        {
            final Csv.Writer writer = new Csv.Writer(new PrintStream(new File(root, "mu" + iter + ".csv")));
            for (int i = 0; i < kmeans.d; ++i) {
                for (int j = 0; j < kmeans.k; ++j) {
                    writer.cell(kmeans.mu()[j].getEntry(i));
                }
                writer.newline();
            }
        }
        // Lt.operate( x ) maps x to the space defined by the metric
        final RealMatrix Lt = new CholeskyDecomposition(kmeans.metric).getLT();
        {
            final Csv.Writer writer = new Csv.Writer(new PrintStream(new File(root, "X" + iter + ".csv")));
            writer.cell("cluster").cell("label").cell("x1").cell("x2").cell("x3").cell("Ax1").cell("Ax2")
                    .cell("Ax3").newline();
            for (int cluster = 0; cluster < kmeans.k; ++cluster) {
                for (int i = 0; i < kmeans.N; ++i) {
                    if (kmeans.assignments()[i] == cluster) {
                        writer.cell(cluster);
                        final RealVector phi = kmeans.X_.get(i); //Phi.get( i );
                        final int pv = (int) phi.getEntry(0);
                        final int paces = (int) phi.getEntry(1);
                        final int dv = (int) phi.getEntry(2);
                        if (paces > 0) {
                            writer.cell(soft_actions[pv - params.soft_hand_min][dv - params.dealer_showing_min]);
                        } else {
                            writer.cell(hard_actions[pv - params.hard_hand_min][dv - params.dealer_showing_min]);
                        }
                        for (int j = 0; j < phi.getDimension(); ++j) {
                            writer.cell(phi.getEntry(j));
                        }
                        final RealVector trans = Lt.operate(phi);
                        for (int j = 0; j < trans.getDimension(); ++j) {
                            writer.cell(trans.getEntry(j));
                        }
                        writer.newline();
                    }
                }
            }
        }
    }

    /**
     * Builds a pairwise-similarity training set: each instance is the
     * element-wise absolute difference of two state feature vectors, labeled
     * 1 if the recorded actions agree and 0 otherwise.
     */
    private static <X extends FactoredRepresentation<BlackjackState>> Instances makeTrainingSet(
            final SolvedStateAccumulator<X> acc, final ArrayList<Attribute> attributes, final int iter) {
        final int[] num_instances = new int[2];
        final ArrayList<Instance> negative = new ArrayList<Instance>();
        final ArrayList<Instance> positive = new ArrayList<Instance>();
        final ArrayList<String> nominal = new ArrayList<String>();
        nominal.add("0");
        nominal.add("1");
        attributes.add(new Attribute("__label__", nominal));
        final int d = attributes.size() - 1; // Minus 1 for label
        for (int i = 0; i < acc.Phi_.size(); ++i) {
            final double[] phi_i = acc.Phi_.get(i).toArray();
            for (int j = i + 1; j < acc.Phi_.size(); ++j) {
                final double[] phi_j = acc.Phi_.get(j).toArray();
                final double[] phi_labeled = new double[d + 1];
                for (int k = 0; k < d; ++k) {
                    phi_labeled[k] = Math.abs(phi_i[k] - phi_j[k]);
                }
                final int label;
                if (acc.actions_.get(i).equals(acc.actions_.get(j))) {
                    label = 1;
                } else {
                    label = 0;
                }
                final double weight = 1.0; // TODO: Weights?
                final String label_string = Integer.toString(label);
                phi_labeled[d] = label; //attributes.get( label_index ).indexOfValue( label_string );
                num_instances[label] += 1;
                final Instance instance = new DenseInstance(weight, phi_labeled);
                if (label == 0) {
                    negative.add(instance);
                } else {
                    positive.add(instance);
                }
            }
        }
        System.out.println("num_instances = " + Arrays.toString(num_instances));
        final Instances x = new Instances("train" + iter, attributes, negative.size() + positive.size());
        x.setClassIndex(d);
        x.addAll(negative);
        x.addAll(positive);
        return x;
    }

    private static Classifier makeClassifier(final Instances train) {
        try {
            final FastRandomForest rf = new FastRandomForest();
            // final Classifier rf = new J48();
            rf.buildClassifier(train);
            return rf;
        } catch (final Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    private static void writeDataset(final File root, final Instances x) {
        final File dataset_file = new File(root, x.relationName() + ".arff");
        final Saver saver = new ArffSaver();
        try {
            saver.setFile(dataset_file);
            saver.setInstances(x);
            saver.writeBatch();
        } catch (final IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    private static final RandomGenerator rng = new MersenneTwister(42);

    /**
     * Runs the abstraction-discovery experiment: in each iteration, collect
     * labeled examples by running UCT under the current abstraction, learn a
     * new abstraction from those examples (metric-constrained k-means or a
     * pairwise-similarity random forest), and then evaluate UCT with the
     * learned abstraction.
     */
    private static <X extends FactoredRepresentation<BlackjackState>, R extends Representer<BlackjackState, X>>
    void runExperiment(final BlackjackParameters params, final int Niterations, final int Ntrain_games,
            final int Ntrain_episodes, final int Ntest_games, final int Ntest_episodes, final File root)
            throws Exception {
        final BlackjackAggregator repr = new BlackjackAggregator();
        final BlackjackMdp mdp = new BlackjackMdp(params);
        System.out.println("Solving MDP");
        final Pair<String[][], String[][]> soln = mdp.solve();
        final String[][] hard_actions = soln.first;
        final String[][] soft_actions = soln.second;

        System.out.println("****************************************");
        System.out.println("game = " + params.max_score + " x (" + Ntrain_games + "(" + Ntrain_episodes + ")"
                + " / " + Ntest_games + "(" + Ntest_episodes + ")) " + ": " + repr);

        final Csv.Writer data_out = new Csv.Writer(new PrintStream(new File(root, "data.csv")));
        data_out.cell("abstraction").cell("game").cell("iteration").cell("Ntrain_games").cell("Ntrain_episodes")
                .cell("Ntest_games").cell("Ntest_episodes").cell("mean").cell("var").cell("conf").newline();

        final ActionGenerator<BlackjackState, JointAction<BlackjackAction>> action_gen
                = new BlackjackJointActionGenerator(1);
        final Policy<BlackjackState, JointAction<BlackjackAction>> rollout_policy
                = new RandomPolicy<BlackjackState, JointAction<BlackjackAction>>(
                        0 /*Player*/, rng.nextInt(), action_gen.create());
        final double c = 1.0;
        final int rollout_width = 1;
        final int rollout_depth = 1;
        // Optimistic default value
        final double[] default_value = new double[] { 1.0 };
        Representer<BlackjackState, ClusterAbstraction<BlackjackState>> Crepr
                = new TrivialClusterRepresenter(params, mdp.S());
        // NOTE: In the Blackjack domain, we can easily enumerate all
        // legal states, so I'm punting the issue of how to collect them
        // properly during search. In reality, there's a question of
        // whether we should be weighting them, e.g. by their reachability.
        final ArrayList<RealVector> Phi = enumerateStates(params);
        RealMatrix A0 = MatrixUtils.createRealIdentityMatrix(Phi.get(0).getDimension());

        for (int iter = 0; iter < Niterations; ++iter) {
            System.out.println("Iteration " + iter);
            // final UnlabeledStateAccumulator<ClusterAbstraction<BlackjackState>> train_visitor
            //     = new UnlabeledStateAccumulator<ClusterAbstraction<BlackjackState>>( Crepr.create() );
            final MctsVisitor<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction> train_visitor
                    = new DefaultMctsVisitor<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction>();
            final BackupRule<ClusterAbstraction<BlackjackState>, BlackjackAction> train_backup
                    = BackupRule.<ClusterAbstraction<BlackjackState>, BlackjackAction>MaxQ();

            // Gather training examples
            System.out.println("Gathering training examples...");
            final SolvedStateAccumulator<HandValueAbstraction> acc
                    = new SolvedStateAccumulator<HandValueAbstraction>(repr);
            for (int i = 0; i < Ntrain_games; ++i) {
                if (i % 100000 == 0) {
                    System.out.println("Episode " + i);
                }
                final Deck deck = new InfiniteDeck();
                final BlackjackSimulator sim = new BlackjackSimulator(deck, 1, params);
                final GameTreeFactory<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction> factory
                        = new UctSearch.Factory<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction>(
                                sim, Crepr.create(), action_gen.create(), c, Ntrain_episodes, rng, rollout_policy,
                                rollout_width, rollout_depth, train_backup, default_value);
                final SearchPolicy<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction> search_policy
                        = new SearchPolicy<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction>(
                                factory, train_visitor, null) {
                    @Override
                    protected JointAction<BlackjackAction> selectAction(
                            final GameTree<ClusterAbstraction<BlackjackState>, BlackjackAction> tree) {
                        return BackupRules.MaxAction(tree.root()).a();
                    }

                    @Override
                    public int hashCode() {
                        return System.identityHashCode(this);
                    }

                    @Override
                    public boolean equals(final Object that) {
                        return this == that;
                    }
                };
                final Episode<BlackjackState, BlackjackAction> episode
                        = new Episode<BlackjackState, BlackjackAction>(sim, search_policy);
                episode.addListener(acc);
                episode.run();
            }

            // Train classifier
            System.out.println("Training classifier...");
            final String algorithm = "kmeans"; //"rf";
            final boolean with_metric_learning = true;
            if ("kmeans".equals(algorithm)) {
                // final ArrayList<RealVector> Phi = train_visitor.Phi_;
                final MetricConstrainedKMeans kmeans = makeClustering(
                        rng, A0, acc.Phi_, acc.actions_, Phi, with_metric_learning);
                final VoronoiClassifier classifier = new VoronoiClassifier(kmeans.mu()) {
                    @Override
                    protected double distance(final RealVector x1, final RealVector x2) {
                        return kmeans.distance(x1, x2);
                    }
                };
                // Update reference matrix. This has the effect of keeping some of
                // the information from previous training episodes.
                A0 = kmeans.metric.copy();
                writeClustering(kmeans, root, iter, params, hard_actions, soft_actions);
                Crepr = new ClusterRepresenter(classifier, repr.create());
            } else if ("rf".equals(algorithm)) {
                final Instances train = makeTrainingSet(acc, HandValueAbstraction.makeAttributes(params), iter);
                writeDataset(root, train);
                final Classifier classifier = makeClassifier(train);
                SerializationHelper.write(new File(root, "rf" + iter + ".model").getAbsolutePath(), classifier);
                Crepr = new PairwiseSimilarityRepresenter<BlackjackState, HandValueAbstraction>(
                        repr.create(), new Instances(train), classifier);
            }

            // Test
            System.out.println("Testing...");
            final MctsVisitor<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction> test_visitor
                    = new DefaultMctsVisitor<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction>();
            final BackupRule<ClusterAbstraction<BlackjackState>, BlackjackAction> test_backup
                    = BackupRule.<ClusterAbstraction<BlackjackState>, BlackjackAction>MaxQ();
            final MeanVarianceAccumulator ret = new MeanVarianceAccumulator();
            for (int i = 0; i < Ntest_games; ++i) {
                if (i % 10000 == 0) {
                    System.out.println("Episode " + i);
                }
                final Deck deck = new InfiniteDeck();
                final BlackjackSimulator sim = new BlackjackSimulator(deck, 1, params);
                final GameTreeFactory<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction> factory
                        = new UctSearch.Factory<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction>(
                                sim, Crepr.create(), action_gen.create(), c, Ntest_episodes, rng, rollout_policy,
                                rollout_width, rollout_depth, test_backup, default_value);
                final SearchPolicy<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction> search_policy
                        = new SearchPolicy<BlackjackState, ClusterAbstraction<BlackjackState>, BlackjackAction>(
                                factory, test_visitor, null) {
                    @Override
                    protected JointAction<BlackjackAction> selectAction(
                            final GameTree<ClusterAbstraction<BlackjackState>, BlackjackAction> tree) {
                        return BackupRules.MaxAction(tree.root()).a();
                    }

                    @Override
                    public int hashCode() {
                        return System.identityHashCode(this);
                    }

                    @Override
                    public boolean equals(final Object that) {
                        return this == that;
                    }
                };
                final Episode<BlackjackState, BlackjackAction> episode
                        = new Episode<BlackjackState, BlackjackAction>(sim, search_policy);
                episode.run();
                // System.out.println( sim.state().token().toString() );
                // System.out.println( "Reward: " + sim.reward()[0] );
                ret.add(sim.reward()[0]);
            }

            System.out.println("****************************************");
            System.out.println("Average return: " + ret.mean());
            System.out.println("Return variance: " + ret.variance());
            // Note: this value is written to the "conf" column as-is; a standard 95% confidence
            // half-width would instead be 1.96 * Math.sqrt(ret.variance() / Ntest_games).
            final double conf = 0.975 * ret.variance() / Math.sqrt(Ntest_games);
            System.out.println("Confidence: " + conf);
            System.out.println();
            // data_out.println( "abstraction,game,iterations,Ntrain_games,Ntrain_episodes,Ntest_games,Ntest_episodes,mean,var,conf" );
            data_out.cell(repr).cell(params.max_score).cell(iter).cell(Ntrain_games).cell(Ntrain_episodes)
                    .cell(Ntest_games).cell(Ntest_episodes).cell(ret.mean()).cell(ret.variance()).cell(conf)
                    .newline();
        }
    }

    /**
     * @param args
     * @throws FileNotFoundException
     */
    public static void main(final String[] args) throws Exception {
        final BlackjackParameters params = new BlackjackParameters();
        final int Niterations = 4;
        final int Ntrain_games = 100;
        final int Ntest_games = 100000;
        final int Ntrain_episodes = 2048;
        final int Ntest_episodes = 256;
        final File root = new File("discovery/sandbox");
        runExperiment(params, Niterations, Ntrain_games, Ntrain_episodes, Ntest_games,
                Ntest_episodes, root);
    }
}
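A note on the pairing logic in makeClustering: two labeled states form a must-link pair when the search chose the same action in both, and a cannot-link pair otherwise; these pairs drive both the ITML metric learner and the constrained k-means step. The standalone sketch below illustrates just that pairing rule with plain Java collections. It is a simplified illustration, not part of the mcplan API; the names PairwiseConstraintsSketch and buildConstraints are invented for this example.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical, simplified illustration of the constraint construction in
// AbstractionDiscovery.makeClustering(): labeled states i < j are must-linked
// when their recorded actions agree and cannot-linked otherwise.
final class PairwiseConstraintsSketch {
    static final class Constraints {
        final List<int[]> mustLink = new ArrayList<int[]>();   // plays the role of S
        final List<int[]> cannotLink = new ArrayList<int[]>(); // plays the role of D
    }

    static Constraints buildConstraints(final List<String> actions) {
        final Constraints c = new Constraints();
        for (int i = 0; i < actions.size(); ++i) {
            for (int j = i + 1; j < actions.size(); ++j) {
                if (actions.get(i).equals(actions.get(j))) {
                    c.mustLink.add(new int[] { i, j });
                } else {
                    c.cannotLink.add(new int[] { i, j });
                }
            }
        }
        return c;
    }

    public static void main(final String[] args) {
        // States 0 and 2 share the chosen action "HIT", so they form the only must-link pair.
        final Constraints c = buildConstraints(Arrays.asList("HIT", "STAND", "HIT"));
        System.out.println("must-link pairs:   " + c.mustLink.size());   // 1
        System.out.println("cannot-link pairs: " + c.cannotLink.size()); // 2
    }
}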