sadl.detectors.AnomalyDetector.java Source code

Java tutorial

Introduction

Here is the source code for sadl.detectors.AnomalyDetector.java

Source

/**
 * This file is part of SADL, a library for learning all sorts of (timed) automata and performing sequence-based anomaly detection.
 * Copyright (C) 2013-2016  the original author or authors.
 *
 * SADL is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 *
 * SADL is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with SADL.  If not, see <http://www.gnu.org/licenses/>.
 */
package sadl.detectors;

import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;

import org.apache.commons.math3.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import sadl.constants.ProbabilityAggregationMethod;
import sadl.input.TimedInput;
import sadl.input.TimedWord;
import sadl.interfaces.ProbabilisticModel;
import sadl.utils.Settings;

/**
 * 
 * @author Timo Klerx
 *
 */
public abstract class AnomalyDetector {
    private static Logger logger = LoggerFactory.getLogger(AnomalyDetector.class);
    public static final int ILLEGAL_VALUE = -1;

    protected ProbabilityAggregationMethod aggType;
    ProbabilisticModel model;

    public boolean isAnomaly(ProbabilisticModel newModel, TimedWord s) {
        setModel(newModel);
        return isAnomaly(s);
    }

    public boolean[] areAnomalies(ProbabilisticModel newModel, TimedInput testSequences) {
        setModel(newModel);
        return areAnomalies(testSequences);

    }

    public AnomalyDetector(ProbabilityAggregationMethod aggType) {
        super();
        this.aggType = aggType;
    }

    public AnomalyDetector(ProbabilityAggregationMethod aggType, ProbabilisticModel model) {
        super();
        this.aggType = aggType;
        this.model = model;
    }

    /**
     * returns two double values for every timed sequence. The first value is the event likelihood, the second the time likelihood
     * 
     * @param testTimedSequences
     */
    public List<double[]> computeAggregatedLikelihoods(TimedInput testTimedSequences) {
        final List<double[]> result = new ArrayList<>();
        for (final TimedWord ts : testTimedSequences) {
            final Pair<TDoubleList, TDoubleList> p = model.calculateProbabilities(ts);
            final double eventProb = aggregate(p.getKey(), aggType);
            final double timeProb = aggregate(p.getValue(), aggType);
            result.add(new double[] { eventProb, timeProb });
        }
        return result;
    }

    public Pair<TDoubleList, TDoubleList> computeAggregatedTrendLikelihood(TimedWord ts) {
        final Pair<TDoubleList, TDoubleList> p = model.calculateProbabilities(ts);
        return computeAggregatedTrendLikelihood(p.getKey(), p.getValue());
    }

    public Pair<TDoubleList, TDoubleList> computeAggregatedTrendLikelihood(TDoubleList eventLHs,
            TDoubleList timeLHs) {
        final TDoubleList partialEventLHs = new TDoubleArrayList();
        final TDoubleList partialTimeLHs = new TDoubleArrayList();
        for (int i = 1; i <= eventLHs.size(); i++) {
            final TDoubleList subList = eventLHs.subList(0, i);
            partialEventLHs.add(aggregate(subList, aggType));
        }
        for (int i = 1; i <= timeLHs.size(); i++) {
            final TDoubleList subList = timeLHs.subList(0, i);
            partialTimeLHs.add(aggregate(subList, aggType));
        }
        return Pair.create(partialEventLHs, partialTimeLHs);
    }

    public boolean isAnomaly(TimedWord s) {
        final Pair<TDoubleList, TDoubleList> p = model.calculateProbabilities(s);
        final TDoubleList eventLikelihoods = p.getKey();
        final TDoubleList timeLikelihoods = p.getValue();
        if (eventLikelihoods.size() < timeLikelihoods.size()) {
            throw new IllegalStateException(
                    "There must be at least as many event likelihoods as time likelihoods, but there are not: "
                            + eventLikelihoods.size() + "(events) vs. " + timeLikelihoods.size() + "(time values)");
        }
        return decide(eventLikelihoods, timeLikelihoods);
    }

    /**
     * Decides whether the likelihoods indicate an anomaly
     * 
     * @param eventLikelihoods
     * @param timeLikelihoods
     * @return true for anomaly, false otherwise
     */
    protected abstract boolean decide(TDoubleList eventLikelihoods, TDoubleList timeLikelihoods);

    public boolean[] areAnomalies(TimedInput testSequences) {
        if (Settings.isDebug()) {
            final Path testLabelFile = Paths.get("testLabels.csv");
            try {
                Files.deleteIfExists(testLabelFile);
                Files.createFile(testLabelFile);
            } catch (final IOException e1) {
                logger.error("Unexpected exception occured", e1);
            }
            try (BufferedWriter bw = Files.newBufferedWriter(testLabelFile, StandardCharsets.UTF_8)) {
                for (final TimedWord s : testSequences) {
                    bw.append(s.getLabel().toString());
                    bw.append('\n');
                }
            } catch (final IOException e) {
                logger.error("Unexpected exception occured", e);
            }
        }
        final boolean[] result = new boolean[testSequences.size()];

        // parallelism does not destroy determinism
        final IntConsumer f = (i -> {
            final TimedWord s = testSequences.get(i);
            result[i] = isAnomaly(s);
        });
        if (Settings.isParallel()) {
            IntStream.range(0, testSequences.size()).parallel().forEach(f);
        } else {
            IntStream.range(0, testSequences.size()).forEach(f);
        }
        return result;
    }

    public void setModel(ProbabilisticModel model) {
        this.model = model;
    }

    public static double aggregate(TDoubleList list, ProbabilityAggregationMethod aggType) {
        if (list.isEmpty()) {
            return ILLEGAL_VALUE;
        }
        double result = -1;
        if (aggType == ProbabilityAggregationMethod.MULTIPLY) {
            result = 0;
            for (int i = 0; i < list.size(); i++) {
                result += Math.log(list.get(i));
            }
        } else if (aggType == ProbabilityAggregationMethod.LUK_T) {
            result = list.get(0);
            for (int i = 1; i < list.size(); i++) {
                result = Math.max(0, list.get(i) + result - 1);
            }
        } else if (aggType == ProbabilityAggregationMethod.LUK_STRONG_DISJUNCTION) {
            result = list.get(0);
            for (int i = 1; i < list.size(); i++) {
                result = Math.min(1, list.get(i) + result);
            }
        } else if (aggType == ProbabilityAggregationMethod.NORMALIZED_MULTIPLY) {
            result = 0;
            for (int i = 0; i < list.size(); i++) {
                if (list.get(i) < 0) {
                    throw new IllegalStateException("Probability for index " + i + " is negative.");
                }
                result += Math.log(list.get(i));
            }
            result /= list.size();
            result = Math.exp(result);
        } else if (aggType == ProbabilityAggregationMethod.NORMALIZED_MULTIPLY_UNSTABLE) {
            result = 1;
            for (int i = 0; i < list.size(); i++) {
                result *= list.get(i);
            }
            result = Math.pow(result, 1.0 / list.size());
        }
        if (Double.isNaN(result)) {
            throw new IllegalStateException("Result of probability aggregation must not be NaN");
        }
        return result;
    }

    /**
     * Computes the product of the probabilities in log space. @param probabilities the probabilities @return the product of the probabilities in log space
     */
    public static double aggregate(TDoubleList probabilities) {
        return aggregate(probabilities, ProbabilityAggregationMethod.NORMALIZED_MULTIPLY);
    }
}