Java tutorial
/******************************************************************************* * This file is part of PDTTA, a library for learning Probabilistic deterministic timed-transition Automata. * Copyright (C) 2013-2015 Timo Klerx * * PDTTA is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. * * PDTTA is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. * * You should have received a copy of the GNU General Public License along with PDTTA. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package de.upb.timok.models; import gnu.trove.list.TDoubleList; import gnu.trove.list.TIntList; import gnu.trove.list.array.TDoubleArrayList; import gnu.trove.list.array.TIntArrayList; import gnu.trove.map.TIntDoubleMap; import gnu.trove.map.hash.TIntDoubleHashMap; import gnu.trove.set.hash.TIntHashSet; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.io.Serializable; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import jsat.distributions.Distribution; import org.apache.commons.math3.util.Precision; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import de.upb.timok.constants.AnomalyInsertionType; import de.upb.timok.constants.ClassLabel; import de.upb.timok.interfaces.AutomatonModel; import de.upb.timok.structure.AbnormalTransition; import de.upb.timok.structure.TimedSequence; import de.upb.timok.structure.Transition; import de.upb.timok.structure.ZeroProbTransition; import de.upb.timok.utils.MasterSeed; /** * A PDTTA with two thresholds for anomaly detection (aggregated event and aggregated time probability). * * @author timok * */ public class PDTTA implements AutomatonModel, Serializable { private static final double ANOMALY_3_CHANGE_RATE = 0.5; private static final double ANOMALY_4_CHANGE_RATE = 0.1; protected static final int START_STATE = 0; /** * */ private static final long serialVersionUID = 3017416753740710943L; // TODO implement PDTTA as an extension of a PDFA transient private static Logger logger = LoggerFactory.getLogger(PDTTA.class); transient private static final double ANOMALY_TYPE_TWO_P_1 = 0.7; transient private static final double ANOMALY_TYPE_TWO_P_2 = 0.9; protected static final double NO_TRANSITION_PROBABILITY = 0; private static final boolean DELETE_NO_TIME_INFORMATION_TRANSITIONS = true; TIntHashSet alphabet = new TIntHashSet(); Set<Transition> transitions = new HashSet<>(); TIntDoubleMap finalStateProbabilities = new TIntDoubleHashMap(); Map<ZeroProbTransition, Distribution> transitionDistributions = null; public Map<ZeroProbTransition, Distribution> getTransitionDistributions() { return transitionDistributions; } public void setTransitionDistributions(Map<ZeroProbTransition, Distribution> transitionDistributions) { this.transitionDistributions = transitionDistributions; if (!isConsistent()) { restoreConsistency(); } } protected void restoreConsistency() { logger.warn("Model is not consistent; restoring consistency..."); if (DELETE_NO_TIME_INFORMATION_TRANSITIONS) { deleteIrrelevantTransitions(); fixProbabilities(); } } private boolean isConsistent() { if (transitions.size() != transitionDistributions.size()) { logger.warn("transitions and transitionDistributions must be of same size! {}!={}", transitions.size(), transitionDistributions.size()); return false; } return finalStateProbabilities.keySet().forEach(state -> checkProbability(state)); } private boolean checkProbability(int state) { final List<Transition> outgoingTransitions = getTransitions(state, true); final double sum = outgoingTransitions.stream().mapToDouble(t -> t.getProbability()).sum(); return Precision.equals(sum, 1); } private void deleteIrrelevantTransitions() { logger.info("There are {} many transitions before removing irrelevant ones", transitions.size()); // there may be more transitions than transitionDistributions transitions.removeIf(t -> !transitionDistributions.containsKey(t.toZeroProbTransition())); fixProbabilities(); if (transitions.size() != transitionDistributions.size()) { logger.error( "This should never happen because trainsitions.size() and transitionDistributions.size() should be equal now, but are not! {}!={}", transitions.size(), transitionDistributions.size()); } logger.info("There are {} many transitions after removing irrelevant ones", transitions.size()); } void fixProbabilities() { finalStateProbabilities.keySet().forEach(state -> fixProbability(state)); } boolean fixProbability(int state) { // TODO use Fraction or BigFraction here! final List<Transition> outgoingTransitions = getTransitions(state, true); final double sum = outgoingTransitions.stream().mapToDouble(t -> t.getProbability()).sum(); logger.info("Sum of transition probabilities is {}", sum); // divide every probability by the sum of probabilities s.t. they sum up to 1 outgoingTransitions.forEach(t -> t.setProbability(t.getProbability() / sum)); final double newSum = outgoingTransitions.stream().mapToDouble(t -> t.getProbability()).sum(); logger.info("Corrected sum of transition probabilities is {}", newSum); if (!Precision.equals(newSum, 1.0)) { throw new IllegalStateException(); } return true; } protected PDTTA() { } public PDTTA(Path trebaPath) throws IOException { final BufferedReader inputReader = Files.newBufferedReader(trebaPath, StandardCharsets.UTF_8); String line = ""; // 172 172 3 0,013888888888888892 // from state ; to state ; symbol ; probability while ((line = inputReader.readLine()) != null) { final String[] lineSplit = line.split(" "); if (lineSplit.length == 4) { final int fromState = Integer.parseInt(lineSplit[0]); final int toState = Integer.parseInt(lineSplit[1]); final int symbol = Integer.parseInt(lineSplit[2]); final double probability = Double.parseDouble(lineSplit[3]); addTransition(fromState, toState, symbol, probability); } else if (lineSplit.length == 2) { final int state = Integer.parseInt(lineSplit[0]); final double finalProb = Double.parseDouble(lineSplit[1]); addFinalState(state, finalProb); } } } public int getTransitionCount() { return transitions.size(); } public Transition addTransition(int fromState, int toState, int symbol, double probability) { addState(fromState); addState(toState); alphabet.add(symbol); final Transition t = new Transition(fromState, toState, symbol, probability); transitions.add(t); return t; } public Transition addAbnormalTransition(int fromState, int toState, int symbol, double probability, AnomalyInsertionType anomalyType) { addState(fromState); addState(toState); alphabet.add(symbol); final Transition t = new AbnormalTransition(fromState, toState, symbol, probability, anomalyType); transitions.add(t); return t; } protected void addState(int state) { if (!finalStateProbabilities.containsKey(state)) { // finalStateProbabilities is also the set of states. so add the state to this set with a probability of zero addFinalState(state, NO_TRANSITION_PROBABILITY); } } public void toGraphvizFile(Path graphvizResult, boolean compressed) throws IOException { final BufferedWriter writer = Files.newBufferedWriter(graphvizResult, StandardCharsets.UTF_8); writer.write("digraph G {\n"); // start states writer.write("qi [shape = point ];"); // write states for (final int state : finalStateProbabilities.keys()) { writer.write(Integer.toString(state)); writer.write(" [shape="); final double finalProb = finalStateProbabilities.get(state); if (finalProb > 0 || (compressed && finalProb > 0.01)) { writer.write("double"); } writer.write("circle, label=\""); writer.write(Integer.toString(state)); if (finalProb > 0 || (compressed && finalProb > 0.01)) { writer.write(" - P= "); writer.write(Double.toString(Precision.round(finalProb, 2))); } writer.write("\"];\n"); } writer.write("qi -> 0;"); // write transitions for (final Transition t : transitions) { if (compressed && t.getProbability() <= 0.01) { continue; } // 0 -> 0 [label=0.06]; writer.write(Integer.toString(t.getFromState())); writer.write(" -> "); writer.write(Integer.toString(t.getToState())); writer.write(" [label=\""); writer.write(Integer.toString(t.getSymbol())); if (t.getProbability() > 0) { writer.write(" p="); writer.write(Double.toString(Precision.round(t.getProbability(), 2))); } writer.write("\"];\n"); } writer.write("}"); writer.close(); } public void addFinalState(int state, double probability) { finalStateProbabilities.put(state, probability); } public double getFinalStateProbability(int state) { return finalStateProbabilities.get(state); } public double getTransitionProbability(int fromState, int toState, int symbol) { for (final Transition t : transitions) { if (t.getFromState() == fromState && t.getToState() == toState && t.getSymbol() == symbol) { return t.getProbability(); } } return NO_TRANSITION_PROBABILITY; } public Transition getTransition(int currentState, int event) { Transition result = null; for (final Transition t : transitions) { if (t.getFromState() == currentState && t.getSymbol() == event) { if (result != null) { logger.error( "Found more than one transition for state " + currentState + " and event " + event); } result = t; } } // if (result == null) { // System.err.println("Found no transition for state " + currentState + " and event " + event); // } return result; } protected void removeTransition(Transition t) { final boolean wasRemoved = transitions.remove(t); if (!wasRemoved) { logger.warn("Tried to remove a non existing transition={}", t); } } private static final int MAX_SEQUENCE_LENGTH = 1000; @Deprecated public TimedSequence createAbnormalEventSequence(Random mutation) { // choose very unlikely sequences final TIntList eventList = new TIntArrayList(); final TDoubleList timeList = new TDoubleArrayList(); boolean choseFinalState = false; int currentState = 0; while (!choseFinalState) { final List<Transition> possibleTransitions = getTransitions(currentState, true); possibleTransitions.sort((o1, o2) -> Double.compare(o1.getProbability(), o2.getProbability())); int listIndex = 3; if (possibleTransitions.size() <= listIndex) { listIndex = possibleTransitions.size() - 1; } int tempListIndex = Math.min(3, possibleTransitions.size() - 1); if (tempListIndex != listIndex) { throw new IllegalStateException(); } final List<Transition> topThree = possibleTransitions.subList(0, listIndex); final double randomValue = mutation.nextDouble(); int chosenTransitionIndex = -1; if (randomValue <= ANOMALY_TYPE_TWO_P_1) { chosenTransitionIndex = 0; } else if (randomValue > ANOMALY_TYPE_TWO_P_1 && randomValue < ANOMALY_TYPE_TWO_P_2) { chosenTransitionIndex = 1; } else { chosenTransitionIndex = 2; } int indexToTake = chosenTransitionIndex; if (indexToTake >= topThree.size()) { indexToTake = topThree.size() - 1; } tempListIndex = Math.min(chosenTransitionIndex, topThree.size() - 1); if (tempListIndex != indexToTake) { throw new IllegalStateException(); } final Transition chosenTransition = topThree.get(indexToTake); if (chosenTransition.isStopTraversingTransition() || eventList.size() > MAX_SEQUENCE_LENGTH) { choseFinalState = true; } else { currentState = chosenTransition.getToState(); final Distribution d = transitionDistributions.get(chosenTransition.toZeroProbTransition()); if (d == null) { // just do it again with other random sampling return createAbnormalEventSequence(mutation); } final double timeValue = d.sample(1, mutation)[0]; eventList.add(chosenTransition.getSymbol()); timeList.add(timeValue); } } return new TimedSequence(eventList, timeList, ClassLabel.ANOMALY); } /** * Returns all outgoing probabilities from the given state * * @param currentState * the given state * @param includeStoppingTransition * whether to include final transition probabilities * @return */ protected List<Transition> getTransitions(int currentState, boolean includeStoppingTransition) { final List<Transition> result = new ArrayList<>(); for (final Transition t : transitions) { if (t.getFromState() == currentState) { result.add(t); } } if (includeStoppingTransition) { for (final int state : finalStateProbabilities.keys()) { if (state == currentState) { result.add(new Transition(currentState, currentState, Transition.STOP_TRAVERSING_SYMBOL, finalStateProbabilities.get(state))); } } } return result; } Random r = new Random(MasterSeed.nextLong()); public TimedSequence sampleSequence() { int currentState = START_STATE; final TIntList eventList = new TIntArrayList(); final TDoubleList timeList = new TDoubleArrayList(); boolean choseFinalState = false; AnomalyInsertionType anomalyType = AnomalyInsertionType.NONE; while (!choseFinalState) { final List<Transition> possibleTransitions = getTransitions(currentState, true); Collections.sort(possibleTransitions, (t1, t2) -> -Double.compare(t2.getProbability(), t1.getProbability())); final double random = r.nextDouble(); double summedProbs = 0; int index = -1; for (int i = 0; i < possibleTransitions.size(); i++) { summedProbs += possibleTransitions.get(i).getProbability(); if (random < summedProbs) { index = i; break; } } final Transition chosenTransition = possibleTransitions.get(index); // XXX What happens for sequence based anomalies if we first choose an abnormal transition and then a normal one? Should we enforce choosing the // abnormal transitions labeled with type 2 and 4 when the first of those anomalies was chosen? The problem are sequence based anomalies! if (chosenTransition.isAbnormal()) { if (anomalyType != AnomalyInsertionType.NONE && anomalyType != chosenTransition.getAnomalyInsertionType()) { // This is a conflict because the anomalyType was already set to anomaly throw new IllegalStateException("Two anomalies are mixed in this special case"); } anomalyType = chosenTransition.getAnomalyInsertionType(); // XXX what happens if one transition was normal and then the other one was abnormal or from another type? 0,1,2,0,0,5? What about the label for // the sequence? Is the label for the sequence really needed? } if (chosenTransition.isStopTraversingTransition() || eventList.size() > MAX_SEQUENCE_LENGTH) { choseFinalState = true; // TODO what happens if an abnormal stopping transiton (type 5) was chosen? } else { currentState = chosenTransition.getToState(); final Distribution d = transitionDistributions.get(chosenTransition.toZeroProbTransition()); if (d == null) { // XXX maybe this happens because the automaton is more general than the data. So not every possible path in the automaton is represented in // the training data. throw new IllegalStateException("This should never happen for transition " + chosenTransition); } double timeValue = d.sample(1, r)[0]; if (anomalyType == AnomalyInsertionType.TYPE_THREE) { timeValue = changeTimeValue(timeValue, ANOMALY_3_CHANGE_RATE); } else if (anomalyType == AnomalyInsertionType.TYPE_FOUR) { timeValue = changeTimeValue(timeValue, ANOMALY_4_CHANGE_RATE); } eventList.add(chosenTransition.getSymbol()); timeList.add(timeValue); } } if (anomalyType != AnomalyInsertionType.NONE) { return new TimedSequence(eventList, timeList, ClassLabel.ANOMALY); } else { return new TimedSequence(eventList, timeList, ClassLabel.NORMAL); } } private double changeTimeValue(double value, double factor) { if (r.nextBoolean()) { return value * -factor; } else { return value * factor; } } public Random getRandom() { return r; } public void setRandom(Random r) { this.r = r; } protected boolean isInAutomaton(TimedSequence s) { int currentState = START_STATE; for (int i = 0; i < s.length(); i++) { final int nextEvent = s.getEvent(i); final Transition t = getTransition(currentState, nextEvent); if (t == null) { return false; } currentState = t.getToState(); } if (getFinalStateProbability(currentState) > NO_TRANSITION_PROBABILITY) { return true; } else { return false; } } @Override public String toString() { return "PDTTA [alphabet=" + alphabet + "]"; } public int getStartState() { return START_STATE; } public int getStateCount() { return finalStateProbabilities.size(); } protected void removeState(int i) { finalStateProbabilities.remove(i); } }