io.druid.server.coordinator.cost.SegmentsCostCache.java Source code

Java tutorial

Introduction

Here is the source code for io.druid.server.coordinator.cost.SegmentsCostCache.java

Source

/*
 * Licensed to Metamarkets Group Inc. (Metamarkets) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. Metamarkets licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package io.druid.server.coordinator.cost;

import com.google.common.base.Preconditions;
import com.google.common.collect.Ordering;
import io.druid.java.util.common.ISE;
import io.druid.java.util.common.Intervals;
import io.druid.java.util.common.granularity.DurationGranularity;
import io.druid.java.util.common.guava.Comparators;
import io.druid.server.coordinator.CostBalancerStrategy;
import io.druid.timeline.DataSegment;
import org.apache.commons.math3.util.FastMath;
import org.joda.time.Interval;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.ListIterator;
import java.util.NavigableMap;
import java.util.NavigableSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * SegmentsCostCache provides a faster way to calculate the cost function proposed in {@link CostBalancerStrategy}.
 * See https://github.com/druid-io/druid/pull/2972 for more details about the cost function.
 *
 * Joint cost for two segments (you can make formulas below readable by copy-pasting to
 * https://www.codecogs.com/latex/eqneditor.php):
 *
 *        cost(X, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy
 * or
 *        cost(X, Y) = e^{-(y_0 + y_1)} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1})  (*)
 *                                                                          if x_0 <= x_1 <= y_0 <= y_1
 * (*) lambda coefficient is omitted for simplicity.
 *
 * For a group of segments {S_xi}, i = 0..n, the total joint cost with segment S_y can be calculated as:
 *
 *        cost(X, Y) = \sum cost(X_i, Y) =  e^{-(y_0 + y_1)} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1})
 *                                                                          if xi_0 <= xi_1 <= y_0 <= y_1
 * and
 *        cost(X, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{-(xi_0 + xi_1)} (e^{xi_0} - e^{xi_1})
 *                                                                          if y_0 <= y_1 <= xi_0 <= xi_1
 *
 * SegmentsCostCache stores pre-computed sums for a group of segments {S_xi}:
 *
 *      1) \sum (e^{xi_0} - e^{xi_1})                      ->  leftSum
 *      2) \sum e^{-(xi_0 + xi_1)} (e^{xi_0} - e^{xi_1})   ->  rightSum
 *
 * so that calculation of the joint cost function for segment S_y becomes an O(1 + m) complexity task, where m
 * is the number of segments in {S_xi} that overlap S_y.
 *
 * Segments are stored in buckets. A bucket is a subset of the segments contained in SegmentsCostCache, such that
 * the start times of all segments inside a bucket fall in the same time interval (with some granularity):
 *
 *  |------------------------|--------------------------|-----------------------|--------  ....
 *  t_0                    t_0+D                     t_0 + 2D                t0 + 3D       ....
 *      S_x1  S_x2  S_x3          S_x4  S_x5  S_x6          S_x7  S_x8  S_x9
 *         bucket1                  bucket2                    bucket3
 *
 * Reasons to store segments in Buckets:
 *
 *     1) The cost function tends to 0 as the distance between segments' intervals increases; buckets
 *        are used to avoid redundantly computing near-zero costs thousands of times
 *     2) To reduce number of calculations when segment is added or removed from SegmentsCostCache
 *     3) To avoid infinite values during exponents calculations
 *
 */
public class SegmentsCostCache {
    /**
     * HALF_LIFE_DAYS defines how fast the joint cost function tends to 0 as the distance between segments' intervals
     * increases. The value of 1 day means that the cost of co-locating two segments which have 1 day between their
     * intervals is 0.5 of the cost when the intervals are adjacent. If the distance is 2 days, then 0.25, etc.
     */
    private static final double HALF_LIFE_DAYS = 1.0;
    private static final double LAMBDA = Math.log(2) / HALF_LIFE_DAYS;

    /**
     * Number of milliseconds per unit of the dimensionless time axis used by the cost formulas: a distance of
     * {@code d} days maps to {@code d * LAMBDA} units (see {@code Bucket.toLocalInterval}).
     */
    private static final double MILLIS_FACTOR = TimeUnit.DAYS.toMillis(1) / LAMBDA;

    /**
     * LIFE_THRESHOLD is used to avoid calculations for segments that are "far"
     * from each other and thus cost(X,Y) ~ 0 for these segments.
     */
    private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30);

    /**
     * Bucket interval defines duration granularity for segment buckets. The number of buckets controls the trade-off
     * between updates (add/remove segment operations) and joint cost calculation:
     *        1) update complexity increases as the number of buckets decreases (buckets contain more segments)
     *        2) joint cost calculation complexity increases as the number of buckets increases
     */
    private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15);
    private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0);

    private static final Comparator<DataSegment> SEGMENT_INTERVAL_COMPARATOR = Comparator
            .comparing(DataSegment::getInterval, Comparators.intervalsByStartThenEnd());

    private static final Comparator<Bucket> BUCKET_INTERVAL_COMPARATOR = Comparator.comparing(Bucket::getInterval,
            Comparators.intervalsByStartThenEnd());

    private static final Ordering<DataSegment> SEGMENT_ORDERING = Ordering.from(SEGMENT_INTERVAL_COMPARATOR);
    private static final Ordering<Bucket> BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR);

    /** Buckets ordered by interval (start, then end). */
    private final ArrayList<Bucket> sortedBuckets;

    /** {@code intervals.get(i)} is the interval of {@code sortedBuckets.get(i)}; used for binary search. */
    private final ArrayList<Interval> intervals;

    /**
     * {@code maxBucketEndMillis[i]} is the largest interval end (in millis) among {@code sortedBuckets[0..i]}.
     * As built by {@link Bucket.Builder#build()}, a bucket's end is the end of its latest-ending segment, so bucket
     * ends are not monotonic in bucket order. This prefix maximum lets the backward scan in
     * {@link #cost(DataSegment)} stop early without skipping an earlier bucket that holds a long segment.
     */
    private final long[] maxBucketEndMillis;

    /**
     * @param sortedBuckets buckets ordered by interval (start, then end)
     * @throws NullPointerException     if {@code sortedBuckets} is null
     * @throws IllegalArgumentException if the buckets are not ordered by interval
     */
    SegmentsCostCache(ArrayList<Bucket> sortedBuckets) {
        this.sortedBuckets = Preconditions.checkNotNull(sortedBuckets, "buckets should not be null");
        // validate before deriving any state from the list
        Preconditions.checkArgument(BUCKET_ORDERING.isOrdered(sortedBuckets),
                "buckets must be ordered by interval");
        this.intervals = sortedBuckets.stream().map(Bucket::getInterval)
                .collect(Collectors.toCollection(ArrayList::new));

        this.maxBucketEndMillis = new long[sortedBuckets.size()];
        long runningMax = Long.MIN_VALUE;
        for (int i = 0; i < maxBucketEndMillis.length; i++) {
            runningMax = Math.max(runningMax, sortedBuckets.get(i).getInterval().getEndMillis());
            maxBucketEndMillis[i] = runningMax;
        }
    }

    /**
     * Computes the joint cost of {@code segment} with all segments in this cache whose bucket's calculation
     * interval overlaps the segment's interval.
     *
     * @param segment segment to compute the joint cost for
     * @return the summed cost over all buckets within range of {@code segment}
     */
    public double cost(DataSegment segment) {
        Interval segmentInterval = segment.getInterval();
        int index = Collections.binarySearch(intervals, segmentInterval, Comparators.intervalsByStartThenEnd());
        // a negative result encodes the insertion point as -(insertionPoint) - 1
        index = (index >= 0) ? index : -index - 1;

        double cost = 0.0;

        // Buckets at or after `index` start no earlier than the segment and their starts are non-decreasing, so once
        // one bucket is out of range, every following bucket is out of range as well.
        for (ListIterator<Bucket> it = sortedBuckets.listIterator(index); it.hasNext();) {
            Bucket bucket = it.next();
            if (!bucket.inCalculationInterval(segment)) {
                break;
            }
            cost += bucket.cost(segment);
        }

        // Buckets before `index` start no later than the segment, so a bucket is in range iff the segment starts
        // before (bucket end + LIFE_THRESHOLD). Bucket ends are not monotonic, so an out-of-range bucket does not imply
        // that earlier buckets are out of range: stop only when no bucket at or before `i` can reach the segment.
        long segmentStartMillis = segmentInterval.getStartMillis();
        for (int i = index - 1; i >= 0; i--) {
            if (maxBucketEndMillis[i] + LIFE_THRESHOLD <= segmentStartMillis) {
                break;
            }
            Bucket bucket = sortedBuckets.get(i);
            if (bucket.inCalculationInterval(segment)) {
                cost += bucket.cost(segment);
            }
        }

        return cost;
    }

    /** Creates an empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Mutable builder that groups segments into buckets keyed by the {@code BUCKET_GRANULARITY} bucket containing the
     * segment's start.
     */
    public static class Builder {
        private final NavigableMap<Interval, Bucket.Builder> buckets = new TreeMap<>(
                Comparators.intervalsByStartThenEnd());

        /** Adds {@code segment} to the bucket covering its start, creating the bucket if needed. */
        public Builder addSegment(DataSegment segment) {
            Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder);
            builder.addSegment(segment);
            return this;
        }

        /** Removes {@code segment} if present; a bucket left without segments is dropped. */
        public Builder removeSegment(DataSegment segment) {
            Interval interval = getBucketInterval(segment);
            buckets.computeIfPresent(interval,
                    // If there are no more segments, returning null in computeIfPresent() removes the interval from the
                    // buckets map
                    (i, builder) -> builder.removeSegment(segment).isEmpty() ? null : builder);
            return this;
        }

        /** @return {@code true} if no bucket currently holds any segment */
        public boolean isEmpty() {
            return buckets.isEmpty();
        }

        /** Builds an immutable {@link SegmentsCostCache}; buckets are emitted in key (interval) order. */
        public SegmentsCostCache build() {
            return new SegmentsCostCache(buckets.values().stream().map(Bucket.Builder::build)
                    .collect(Collectors.toCollection(ArrayList::new)));
        }

        private static Interval getBucketInterval(DataSegment segment) {
            return BUCKET_GRANULARITY.bucket(segment.getInterval().getStart());
        }
    }

    /**
     * A group of segments with pre-computed prefix ({@code leftSum}) and suffix ({@code rightSum}) sums that allow
     * the cost against all non-overlapping segments on either side to be computed in O(1).
     */
    static class Bucket {
        private final Interval interval;
        /** {@link #interval} widened by {@code LIFE_THRESHOLD} on both sides. */
        private final Interval calculationInterval;
        private final ArrayList<DataSegment> sortedSegments;
        /**
         * As produced by {@link Builder}: {@code leftSum[i]} = sum over j <= i of (e^{start_j} - e^{end_j}), with
         * times relative to the bucket start.
         */
        private final double[] leftSum;
        /**
         * As produced by {@link Builder}: {@code rightSum[i]} = sum over j >= i of (e^{-end_j} - e^{-start_j}), with
         * times relative to the bucket start.
         */
        private final double[] rightSum;

        Bucket(Interval interval, ArrayList<DataSegment> sortedSegments, double[] leftSum, double[] rightSum) {
            this.interval = Preconditions.checkNotNull(interval, "interval");
            this.sortedSegments = Preconditions.checkNotNull(sortedSegments, "sortedSegments");
            this.leftSum = Preconditions.checkNotNull(leftSum, "leftSum");
            this.rightSum = Preconditions.checkNotNull(rightSum, "rightSum");
            Preconditions.checkArgument(
                    sortedSegments.size() == leftSum.length && sortedSegments.size() == rightSum.length);
            Preconditions.checkArgument(SEGMENT_ORDERING.isOrdered(sortedSegments));
            this.calculationInterval = new Interval(interval.getStart().minus(LIFE_THRESHOLD),
                    interval.getEnd().plus(LIFE_THRESHOLD));
        }

        Interval getInterval() {
            return interval;
        }

        /** @return whether {@code dataSegment} overlaps this bucket's interval widened by LIFE_THRESHOLD */
        boolean inCalculationInterval(DataSegment dataSegment) {
            return calculationInterval.overlaps(dataSegment.getInterval());
        }

        /**
         * Computes the joint cost of {@code dataSegment} with every segment in this bucket.
         *
         * @throws ISE if {@code dataSegment} is outside this bucket's calculation interval
         */
        double cost(DataSegment dataSegment) {
            // avoid calculation for segments outside of LIFE_THRESHOLD
            if (!inCalculationInterval(dataSegment)) {
                throw new ISE("Segment is not within calculation interval");
            }

            // cost is calculated relatively to bucket start (which is considered as 0)
            double t0 = convertStart(dataSegment, interval);
            double t1 = convertEnd(dataSegment, interval);

            int index = Collections.binarySearch(sortedSegments, dataSegment, SEGMENT_INTERVAL_COMPARATOR);
            index = (index >= 0) ? index : -index - 1;
            return addLeftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index);
        }

        /**
         * Cost against segments ordered before {@code index}: overlapping neighbours are computed exactly, the rest via
         * the {@code leftSum} prefix sum.
         *
         * NOTE(review): the scan stops at the first non-overlapping segment. Because segments are ordered by start then
         * end, a longer segment further left may still overlap the query segment, yet it is then folded into the
         * closed-form leftSum term, which assumes non-overlapping intervals. This appears to be an accepted
         * approximation; confirm before relying on exact values.
         */
        private double addLeftCost(DataSegment dataSegment, double t0, double t1, int index) {
            Interval queryInterval = dataSegment.getInterval();
            double leftCost = 0.0;
            // add to cost all left-overlapping segments
            int leftIndex = index - 1;
            while (leftIndex >= 0 && sortedSegments.get(leftIndex).getInterval().overlaps(queryInterval)) {
                DataSegment left = sortedSegments.get(leftIndex);
                double start = convertStart(left, interval);
                double end = convertEnd(left, interval);
                leftCost += CostBalancerStrategy.intervalCost(end - start, t0 - start, t1 - start);
                --leftIndex;
            }
            // add left-non-overlapping segments
            if (leftIndex >= 0) {
                leftCost += leftSum[leftIndex] * (FastMath.exp(-t1) - FastMath.exp(-t0));
            }
            return leftCost;
        }

        /**
         * Cost against segments ordered at or after {@code index}: overlapping neighbours are computed exactly, the rest
         * via the {@code rightSum} suffix sum.
         */
        private double rightCost(DataSegment dataSegment, double t0, double t1, int index) {
            Interval queryInterval = dataSegment.getInterval();
            double rightCost = 0.0;
            // add all right-overlapping segments
            int rightIndex = index;
            while (rightIndex < sortedSegments.size()
                    && sortedSegments.get(rightIndex).getInterval().overlaps(queryInterval)) {
                DataSegment right = sortedSegments.get(rightIndex);
                double start = convertStart(right, interval);
                double end = convertEnd(right, interval);
                rightCost += CostBalancerStrategy.intervalCost(t1 - t0, start - t0, end - t0);
                ++rightIndex;
            }
            // add right-non-overlapping segments
            if (rightIndex < sortedSegments.size()) {
                rightCost += rightSum[rightIndex] * (FastMath.exp(t0) - FastMath.exp(t1));
            }
            return rightCost;
        }

        private static double convertStart(DataSegment dataSegment, Interval interval) {
            return toLocalInterval(dataSegment.getInterval().getStartMillis(), interval);
        }

        private static double convertEnd(DataSegment dataSegment, Interval interval) {
            return toLocalInterval(dataSegment.getInterval().getEndMillis(), interval);
        }

        /** Converts an absolute millis timestamp to cost-formula units relative to {@code interval}'s start. */
        private static double toLocalInterval(long millis, Interval interval) {
            return (millis - interval.getStartMillis()) / MILLIS_FACTOR;
        }

        public static Builder builder(Interval interval) {
            return new Builder(interval);
        }

        /** Mutable builder that keeps prefix/suffix sums up to date on every add/remove. */
        static class Builder {
            private final Interval interval;
            private final NavigableSet<SegmentAndSum> segments = new TreeSet<>();

            public Builder(Interval interval) {
                this.interval = interval;
            }

            /**
             * Adds {@code dataSegment} and updates the prefix/suffix sums of all other segments.
             *
             * @throws ISE if the segment's start is not within this bucket's interval, or if the segment is already
             *             present (in which case the builder is left unchanged)
             */
            public Builder addSegment(DataSegment dataSegment) {
                if (!interval.contains(dataSegment.getInterval().getStartMillis())) {
                    throw new ISE("Failed to add segment to bucket: interval is not covered by this bucket");
                }

                // all values are pre-computed relatively to bucket start (which is considered as 0)
                double t0 = convertStart(dataSegment, interval);
                double t1 = convertEnd(dataSegment, interval);

                double leftValue = FastMath.exp(t0) - FastMath.exp(t1);
                double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0);

                SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, leftValue, rightValue);

                // Reject duplicates BEFORE touching the neighbours' sums; otherwise the existing entries would already
                // have been modified when the exception is thrown, leaving the builder in a corrupted state.
                // (TreeSet.contains uses compareTo, not the unsupported equals.)
                if (segments.contains(segmentAndSum)) {
                    throw new ISE("expect new segment");
                }

                // left/right value should be added to left/right sums for elements greater/lower than current segment
                segments.tailSet(segmentAndSum).forEach(v -> v.leftSum += leftValue);
                segments.headSet(segmentAndSum).forEach(v -> v.rightSum += rightValue);

                // leftSum_i = leftValue_i + \sum leftValue_j = leftValue_i + leftSum_{i-1} , j < i
                SegmentAndSum lower = segments.lower(segmentAndSum);
                if (lower != null) {
                    segmentAndSum.leftSum = leftValue + lower.leftSum;
                }

                // rightSum_i = rightValue_i + \sum rightValue_j = rightValue_i + rightSum_{i+1} , j > i
                SegmentAndSum higher = segments.higher(segmentAndSum);
                if (higher != null) {
                    segmentAndSum.rightSum = rightValue + higher.rightSum;
                }

                segments.add(segmentAndSum);
                return this;
            }

            /**
             * Removes {@code dataSegment} if present and subtracts its contribution from the other segments' sums;
             * a no-op if the segment is absent.
             */
            public Builder removeSegment(DataSegment dataSegment) {
                SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, 0.0, 0.0);

                if (!segments.remove(segmentAndSum)) {
                    return this;
                }

                double t0 = convertStart(dataSegment, interval);
                double t1 = convertEnd(dataSegment, interval);

                double leftValue = FastMath.exp(t0) - FastMath.exp(t1);
                double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0);

                // undo exactly what addSegment added to the greater/lower neighbours
                segments.tailSet(segmentAndSum).forEach(v -> v.leftSum -= leftValue);
                segments.headSet(segmentAndSum).forEach(v -> v.rightSum -= rightValue);
                return this;
            }

            public boolean isEmpty() {
                return segments.isEmpty();
            }

            /**
             * Snapshots the current segments and sums into an immutable {@link Bucket} whose interval runs from this
             * builder's interval start to the latest segment end (or this builder's interval end if empty).
             */
            public Bucket build() {
                ArrayList<DataSegment> segmentsList = new ArrayList<>(segments.size());
                double[] leftSum = new double[segments.size()];
                double[] rightSum = new double[segments.size()];
                int i = 0;
                for (SegmentAndSum segmentAndSum : segments) {
                    segmentsList.add(segmentAndSum.dataSegment);
                    leftSum[i] = segmentAndSum.leftSum;
                    rightSum[i] = segmentAndSum.rightSum;
                    ++i;
                }
                long bucketEndMillis = segmentsList.stream().mapToLong(s -> s.getInterval().getEndMillis()).max()
                        .orElseGet(interval::getEndMillis);
                return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), segmentsList, leftSum,
                        rightSum);
            }
        }
    }

    /**
     * A segment paired with its running prefix ({@code leftSum}) and suffix ({@code rightSum}) sums. Ordered by
     * interval (start, then end), with ties broken by {@link DataSegment#compareTo}. {@code equals}/{@code hashCode}
     * intentionally throw: instances are only meant to live in sorted sets that rely on {@code compareTo}.
     */
    static class SegmentAndSum implements Comparable<SegmentAndSum> {
        private final DataSegment dataSegment;
        private double leftSum;
        private double rightSum;

        SegmentAndSum(DataSegment dataSegment, double leftSum, double rightSum) {
            this.dataSegment = dataSegment;
            this.leftSum = leftSum;
            this.rightSum = rightSum;
        }

        @Override
        public int compareTo(SegmentAndSum o) {
            int c = Comparators.intervalsByStartThenEnd().compare(dataSegment.getInterval(),
                    o.dataSegment.getInterval());
            return (c != 0) ? c : dataSegment.compareTo(o.dataSegment);
        }

        @Override
        public boolean equals(Object obj) {
            throw new UnsupportedOperationException("Use SegmentAndSum.compareTo()");
        }

        @Override
        public int hashCode() {
            throw new UnsupportedOperationException();
        }
    }
}