// Java tutorial
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.azrul.langmera;

import io.vertx.core.Vertx;
import io.vertx.core.shareddata.LocalMap;
import java.io.Serializable;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.lang.SerializationUtils;
import org.cfg4j.provider.ConfigurationProvider;

/**
 * {@link Analytics} implementation that keeps one Q value per
 * {@code context + ":" + decision} key in the Vert.x local map {@code "Q"}.
 *
 * <p>Local maps used by this class:
 * <ul>
 *   <li>{@code DECISION_REQUEST} / {@code DECISION_RESPONSE}: requests and responses
 *       cached by decision id until their feedback arrives;</li>
 *   <li>{@code DECISION_FEEDBACK}: retained feedback, keyed by decision id;</li>
 *   <li>{@code FEEDBACK_TRACKER}: timestamp (ms) to decision id, used to evict the
 *       oldest retained feedback;</li>
 *   <li>{@code Q}: the learned values.</li>
 * </ul>
 *
 * <p>Configuration keys read: {@code epsilon}, {@code alpha}, {@code maxHistoryRetained},
 * {@code collect.traces}, {@code display.desktop.chart}.
 *
 * <p>Note that {@code random}, {@code startTime} and {@code traces} are static, so they are
 * shared by every instance of this class.
 *
 * @author Azrul
 */
public class QLearningAnalytics implements Analytics {

    private static Random random;
    private static Long startTime;
    private static Map<String, List<Double>> traces = new HashMap<>();

    private final String chartDesc = "Langemera: Adaptive Real-time Analytical Framework";
    private final ConfigurationProvider config;
    private final Logger logger;

    /**
     * Creates the analytics engine and records the current time as the start of the run.
     *
     * @param random source of randomness for exploratory decisions (stored statically)
     * @param logger logger used for warnings
     * @param config provider of the configuration keys listed on the class
     */
    public QLearningAnalytics(Random random, Logger logger, ConfigurationProvider config) {
        QLearningAnalytics.random = random;
        startTime = (new Date()).getTime();
        this.config = config;
        this.logger = logger;
    }

    /**
     * Answers a decision request, either randomly or from the learned Q values.
     *
     * <p>The choice is driven by the number of requests currently cached in
     * {@code DECISION_REQUEST}: when that count modulo 10 is less than or equal to the
     * configured {@code epsilon} a random decision is made, otherwise a calculated one.
     *
     * @param s              the request to answer
     * @param vertx          Vert.x instance giving access to the shared local maps
     * @param responseAction callback receiving the response
     */
    @Override
    public void getDecision(DecisionRequest s, Vertx vertx, Consumer<DecisionResponse> responseAction) {
        int decisionCount = vertx.sharedData().getLocalMap("DECISION_REQUEST").size();
        if (decisionCount % 10 <= config.getProperty("epsilon", Integer.class)) {
            getRandomDecision(s, vertx, responseAction);
        } else {
            getCalculatedDecision(s, vertx, responseAction);
        }
    }

    /**
     * Responds with the offered option that has the highest Q value for the request's context.
     *
     * <p>Only keys of the form {@code <request context>:<one of the request's options>} are
     * considered, so a value learned for a different context, or for a decision this request
     * did not offer, is never returned. When no such key exists the call falls back to
     * {@link #getRandomDecision}.
     *
     * @param req            the request to answer; must be non-null with a non-empty options array
     * @param vertx          Vert.x instance giving access to the shared local maps
     * @param responseAction callback receiving the response
     * @throws IllegalArgumentException if the request or its options are missing or empty
     */
    public void getCalculatedDecision(DecisionRequest req, Vertx vertx,
            Consumer<DecisionResponse> responseAction) {
        getDecisionPreCondition(req);
        LocalMap<String, Double> q = vertx.sharedData().getLocalMap("Q");

        // Same concatenation learn() uses to build the key, so prefixes line up exactly.
        String keyPrefix = req.getContext() + ":";
        List<?> offered = Arrays.asList(req.getOptions());

        String bestDecision = null;
        Double bestQ = Double.NEGATIVE_INFINITY;
        for (String key : q.keySet()) {
            if (!key.startsWith(keyPrefix)) {
                continue;
            }
            String candidate = key.substring(keyPrefix.length());
            if (!offered.contains(candidate)) {
                continue;
            }
            Double value = q.get(key);
            if (value != null && value > bestQ) {
                bestQ = value;
                bestDecision = candidate;
            }
        }

        if (bestDecision == null) {
            // Nothing learned yet for this context/options: explore instead.
            getRandomDecision(req, vertx, responseAction);
            return;
        }

        DecisionResponse resp = new DecisionResponse();
        resp.setDecisionId(req.getDecisionId());
        resp.setDecision(bestDecision);
        resp.setQValue(bestQ);
        cacheAndRespond(req, resp, vertx, responseAction);
    }

    /**
     * Responds with one of the request's options picked uniformly at random.
     *
     * @param req            the request to answer; must be non-null with a non-empty options array
     * @param vertx          Vert.x instance giving access to the shared local maps
     * @param responseAction callback receiving the response
     * @throws IllegalArgumentException if the request or its options are missing or empty
     */
    public void getRandomDecision(DecisionRequest req, Vertx vertx,
            Consumer<DecisionResponse> responseAction) {
        getDecisionPreCondition(req);
        int pick = random.nextInt(req.getOptions().length);
        DecisionResponse resp = new DecisionResponse();
        resp.setDecision(req.getOptions()[pick]);
        resp.setDecisionId(req.getDecisionId());
        cacheAndRespond(req, resp, vertx, responseAction);
    }

    /** Caches the request/response pair so the later feedback can be matched, then replies. */
    private void cacheAndRespond(DecisionRequest req, DecisionResponse resp, Vertx vertx,
            Consumer<DecisionResponse> responseAction) {
        vertx.sharedData().getLocalMap("DECISION_REQUEST").put(req.getDecisionId(), req);
        vertx.sharedData().getLocalMap("DECISION_RESPONSE").put(req.getDecisionId(), resp);
        responseAction.accept(resp);
    }

    /**
     * Validates a decision request before it is used.
     *
     * @throws IllegalArgumentException if the request is null, or its options array is null or empty
     */
    private void getDecisionPreCondition(DecisionRequest s) {
        if (s == null) {
            throw new IllegalArgumentException("Decision Request is null");
        }
        if (s.getOptions() == null) {
            throw new IllegalArgumentException("Options array is null");
        }
        if (s.getOptions().length == 0) {
            throw new IllegalArgumentException("Options array is empty");
        }
    }

    /**
     * Folds a feedback into the Q value of the (context, decision) pair it belongs to.
     *
     * <p>The feedback is matched to its cached request/response by decision id. If either is
     * missing a warning is logged and {@code responseAction} is run without learning.
     * Otherwise the feedback is stored, the Q value for its {@code context:decision} key is
     * recomputed from all retained feedback for that pair, the oldest retained feedback is
     * evicted once more than {@code maxHistoryRetained} entries were held, the cached
     * request/response are removed, traces are published when {@code collect.traces} is true,
     * and finally {@code responseAction} is run. A feedback flagged terminal additionally
     * prints throughput figures to standard output.
     *
     * @param currentFeedback the feedback to learn from
     * @param vertx           Vert.x instance giving access to the shared local maps and event bus
     * @param responseAction  callback run once the feedback has been processed
     */
    @Override
    public void learn(DecisionFeedback currentFeedback, Vertx vertx, Runnable responseAction) {
        LocalMap<String, DetailDecisionFeedback> decisionFeedbackMap =
                vertx.sharedData().getLocalMap("DECISION_FEEDBACK");
        LocalMap<String, DecisionRequest> decisionRequestMap =
                vertx.sharedData().getLocalMap("DECISION_REQUEST");
        LocalMap<String, DecisionResponse> decisionResponseMap =
                vertx.sharedData().getLocalMap("DECISION_RESPONSE");
        LocalMap<String, Double> q = vertx.sharedData().getLocalMap("Q");
        LocalMap<Long, String> trackers = vertx.sharedData().getLocalMap("FEEDBACK_TRACKER");

        // Number of feedbacks retained BEFORE this one is stored; used for the history
        // limit and for the throughput report.
        int feedbackCount = decisionFeedbackMap.size();

        DecisionRequest cachedRequest = decisionRequestMap.get(currentFeedback.getDecisionId());
        DecisionResponse cachedResponse = decisionResponseMap.get(currentFeedback.getDecisionId());
        if (cachedRequest == null || cachedResponse == null) {
            logger.log(Level.WARNING,
                    "Attempt to learn from a feedback with no corresponding request/response");
            responseAction.run();
            return;
        }

        String context = cachedRequest.getContext();
        String decision = cachedResponse.getDecision();

        DetailDecisionFeedback detailFB = new DetailDecisionFeedback();
        detailFB.setFeedback(currentFeedback);
        detailFB.setContext(context);
        detailFB.setDecision(decision);
        decisionFeedbackMap.put(currentFeedback.getDecisionId(), detailFB);

        // NOTE(review): the tracker key is a millisecond timestamp, so two feedbacks stored
        // in the same millisecond share a key and the later put replaces the earlier entry.
        Long trackerKey = (new Date()).getTime();
        trackers.put(trackerKey, currentFeedback.getDecisionId());

        // Collect every retained reward for this (context, decision) pair.
        // NOTE(review): rewards are taken in the map's iteration order; the recency weighting
        // in estimateQ presumably expects oldest-first - confirm the ordering guarantees.
        List<Double> rewards = new ArrayList<>();
        for (DetailDecisionFeedback fb : decisionFeedbackMap.values()) {
            if (Objects.equals(context, fb.getContext()) && Objects.equals(decision, fb.getDecision())) {
                rewards.add(fb.getFeedback().getScore());
            }
        }

        double newQ = estimateQ(rewards, config.getProperty("alpha", Double.class));

        // Save what we learned, unless the estimate is not a usable number.
        if (!Double.isInfinite(newQ) && !Double.isNaN(newQ)) {
            q.put(context + ":" + decision, newQ);
        }

        // Limit the number of history entries taken into account - prevents memory leak.
        if (feedbackCount > config.getProperty("maxHistoryRetained", Integer.class)) {
            Long oldestTracker = Collections.min(trackers.keySet());
            String decisionIdWithOldestTracker = trackers.get(oldestTracker);
            decisionFeedbackMap.remove(decisionIdWithOldestTracker);
            trackers.remove(oldestTracker);
        }

        // Clear cached req/resp once the feedback has come back.
        decisionRequestMap.remove(currentFeedback.getDecisionId());
        decisionResponseMap.remove(currentFeedback.getDecisionId());

        boolean collectTraces = Boolean.TRUE.equals(config.getProperty("collect.traces", Boolean.class));
        if (collectTraces) {
            recordTraces(currentFeedback, q, vertx);
        }

        responseAction.run();

        if (Boolean.TRUE.equals(currentFeedback.getTerminal())) {
            reportThroughput(feedbackCount, collectTraces);
        }

        // Example query left by the original author for reading the published traces back:
        // select qmovies,qsports,qconcerts from
        //   (select t1.qvalue as qsports,t1.decisionid from trace t1 where t1.decision='SPORTS' order by t1.decisiontime) as A1
        //   join (select t2.qvalue as qmovies,t2.decisionid from trace t2 where t2.decision='MOVIES' order by t2.decisiontime) as A2 on A1.decisionid = A2.decisionid
        //   join (select t3.qvalue as qconcerts,t3.decisionid from trace t3 where t3.decision='CONCERTS' order by t3.decisiontime) as A3 on A1.decisionid = A3.decisionid
    }

    /**
     * Computes the weighted sum {@code sum_i alpha * (1 - alpha)^(n - i) * r_i} over the
     * rewards (i counted from 1, n = number of rewards).
     *
     * @param rewards         rewards for one (context, decision) pair
     * @param configuredAlpha step size; when null, {@code 1 / n} is used
     * @return the weighted sum; 0 when {@code rewards} is empty
     */
    private static double estimateQ(List<Double> rewards, Double configuredAlpha) {
        int n = rewards.size();
        double alpha = (configuredAlpha != null) ? configuredAlpha : 1.0 / n;
        double weighted = 0.0;
        int i = 0;
        for (Double reward : rewards) {
            i++;
            weighted = weighted + alpha * Math.pow(1 - alpha, n - i) * reward;
        }
        return weighted;
    }

    /** Appends the current Q values to the in-memory traces and publishes one Trace per key. */
    private void recordTraces(DecisionFeedback currentFeedback, LocalMap<String, Double> q, Vertx vertx) {
        Double maxQ = Double.NEGATIVE_INFINITY;
        for (String contextDecision : q.keySet()) {
            Double value = q.get(contextDecision);
            if (value != null && value > maxQ) {
                maxQ = value;
            }
        }

        Date now = new Date();
        for (String contextDecision : q.keySet()) {
            Double value = q.get(contextDecision);
            traces.computeIfAbsent(contextDecision, k -> new ArrayList<>()).add(value);

            // Keys are built as context + ":" + decision, so a limit of 2 always yields
            // both parts, even when the decision part is empty.
            String[] c = contextDecision.split(":", 2);
            Trace trace = new Trace(currentFeedback.getDecisionId(), c[0], value, maxQ, now, c[1],
                    currentFeedback.getScore());
            vertx.eventBus().publish("SAVE_TRACE_TO_TRACE",
                    SerializationUtils.serialize((Serializable) trace));
        }
    }

    /** Prints timing figures for the run and, when configured, opens the desktop chart. */
    private void reportThroughput(int feedbackCount, boolean collectTraces) {
        long delta = (new Date()).getTime() - startTime;
        System.out.println("Time taken to process " + feedbackCount + " msgs:" + delta + " ms");
        if (feedbackCount > 0) {
            System.out.println("Time taken per msg: " + (delta / feedbackCount) + " ms");
        } else {
            // Integer division by a zero count would throw ArithmeticException.
            System.out.println("Time taken per msg: n/a");
        }
        System.out.println("Msgs per s: " + ((1000.0 * (double) feedbackCount) / ((double) delta)) + " msgs");

        if (collectTraces
                && Boolean.TRUE.equals(config.getProperty("display.desktop.chart", Boolean.class))) {
            final LineChart demo = new LineChart(chartDesc, traces);
            demo.pack();
            demo.setVisible(true);
        }
    }
}