com.twitter.distributedlog.service.balancer.ClusterBalancer.java Source code

Java tutorial

Introduction

Here is the source code for com.twitter.distributedlog.service.balancer.ClusterBalancer.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.twitter.distributedlog.service.balancer;

import com.google.common.base.Optional;
import com.google.common.util.concurrent.RateLimiter;
import com.twitter.distributedlog.client.monitor.MonitorServiceClient;
import com.twitter.distributedlog.service.ClientUtils;
import com.twitter.distributedlog.service.DLSocketAddress;
import com.twitter.distributedlog.service.DistributedLogClient;
import com.twitter.distributedlog.service.DistributedLogClientBuilder;
import com.twitter.util.Await;
import com.twitter.util.Function;
import com.twitter.util.Future;
import com.twitter.util.FutureEventListener;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A balancer that balances stream ownerships across a cluster of targets.
 *
 * <p>The current ownership distribution is read from the {@link MonitorServiceClient} this balancer was built
 * with. A stream is moved by releasing it on the host that owns it (through a per-host
 * {@link DistributedLogClient}) and then issuing a {@code check} for it against the target host's monitor.
 */
public class ClusterBalancer implements Balancer {

    static final Logger logger = LoggerFactory.getLogger(ClusterBalancer.class);

    /**
     * Represents a single host together with the set of streams it currently owns.
     *
     * <p>Clients to the host are created lazily on first use. Hosts are ordered (by stream count, descending)
     * via {@link HostComparator}; this class itself does not define an ordering.
     */
    static class Host {

        final SocketAddress address;
        final Set<String> streams;
        final DistributedLogClientBuilder clientBuilder;
        DistributedLogClient client = null;
        MonitorServiceClient monitor = null;

        Host(SocketAddress address, Set<String> streams, DistributedLogClientBuilder clientBuilder) {
            this.address = address;
            this.streams = streams;
            this.clientBuilder = clientBuilder;
        }

        /**
         * Creates the client/monitor pair for this host if it has not been created yet.
         * Only invoked from the synchronized accessors below.
         */
        private void initializeClientsIfNeeded() {
            if (null == client) {
                Pair<DistributedLogClient, MonitorServiceClient> clientPair = createDistributedLogClient(address,
                        clientBuilder);
                client = clientPair.getLeft();
                monitor = clientPair.getRight();
            }
        }

        synchronized DistributedLogClient getClient() {
            initializeClientsIfNeeded();
            return client;
        }

        synchronized MonitorServiceClient getMonitor() {
            initializeClientsIfNeeded();
            return monitor;
        }

        /**
         * Closes the distributedlog client to this host, if one was created.
         * NOTE(review): the monitor is not closed separately; presumably it shares the client's lifecycle
         * (both come from the same {@code ClientUtils.buildClient} pair) -- confirm.
         */
        synchronized void close() {
            if (null != client) {
                client.close();
            }
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Host(").append(address).append(")");
            return sb.toString();
        }
    }

    /**
     * Orders hosts by the number of streams they own, in descending order.
     */
    static class HostComparator implements Comparator<Host>, Serializable {
        private static final long serialVersionUID = 7984973796525102538L;

        @Override
        public int compare(Host h1, Host h2) {
            // Arguments are swapped to get descending order; Integer.compare avoids subtraction-based comparison.
            return Integer.compare(h2.streams.size(), h1.streams.size());
        }
    }

    protected final DistributedLogClientBuilder clientBuilder;
    protected final DistributedLogClient client;
    protected final MonitorServiceClient monitor;

    public ClusterBalancer(DistributedLogClientBuilder clientBuilder) {
        this(clientBuilder, ClientUtils.buildClient(clientBuilder));
    }

    ClusterBalancer(DistributedLogClientBuilder clientBuilder,
            Pair<DistributedLogClient, MonitorServiceClient> clientPair) {
        this.clientBuilder = clientBuilder;
        this.client = clientPair.getLeft();
        this.monitor = clientPair.getRight();
    }

    /**
     * Build a new distributedlog client to a single host <i>host</i>.
     *
     * @param host
     *          host to access
     * @param clientBuilder
     *          builder whose settings are copied for the new client
     * @return distributedlog clients
     */
    static Pair<DistributedLogClient, MonitorServiceClient> createDistributedLogClient(SocketAddress host,
            DistributedLogClientBuilder clientBuilder) {
        DistributedLogClientBuilder newBuilder = DistributedLogClientBuilder.newBuilder(clientBuilder).host(host);
        return ClientUtils.buildClient(newBuilder);
    }

    /**
     * Moves all streams off {@code source}, spreading them over the other hosts.
     */
    @Override
    public void balanceAll(String source, int rebalanceConcurrency, /* unused */
            Optional<RateLimiter> rebalanceRateLimiter) {
        balance(0, 0.0f, rebalanceConcurrency, Optional.of(source), rebalanceRateLimiter);
    }

    /**
     * Balances streams across the whole cluster, without a designated source host.
     */
    @Override
    public void balance(int rebalanceWaterMark, double rebalanceTolerancePercentage,
            int rebalanceConcurrency, /* unused */
            Optional<RateLimiter> rebalanceRateLimiter) {
        Optional<String> source = Optional.absent();
        balance(rebalanceWaterMark, rebalanceTolerancePercentage, rebalanceConcurrency, source,
                rebalanceRateLimiter);
    }

    /**
     * Balances stream ownerships.
     *
     * <p>Targets accept streams until they own {@code max(1, (int) (avg * (1 + tolerance / 100)))} streams.
     * <ul>
     *   <li>With a {@code source}: streams above {@code max(0, rebalanceWaterMark)} are first moved from the
     *   source to less-loaded hosts. Any streams still on the source are then distributed round-robin over
     *   the other hosts. Here {@code avg} excludes the source host.</li>
     *   <li>Without a source: every host above {@code max(ceil(avg), rebalanceWaterMark)} sheds streams to
     *   the least-loaded hosts.</li>
     * </ul>
     * Nothing happens if the cluster has at most one host, or if {@code source} is not a known owner.
     *
     * @param rebalanceWaterMark minimum number of streams a host may keep before it stops shedding
     * @param rebalanceTolerancePercentage percentage above the average load that target hosts may reach
     * @param rebalanceConcurrency unused
     * @param source optional {@code host:port} of the host to drain
     * @param rebalanceRateLimiter optional limiter; one permit is acquired per attempted stream move
     */
    public void balance(int rebalanceWaterMark, double rebalanceTolerancePercentage, int rebalanceConcurrency,
            Optional<String> source, Optional<RateLimiter> rebalanceRateLimiter) {
        Map<SocketAddress, Set<String>> distribution = monitor.getStreamOwnershipDistribution();
        if (distribution.size() <= 1) {
            return;
        }
        SocketAddress sourceAddr = null;
        if (source.isPresent()) {
            sourceAddr = DLSocketAddress.parseSocketAddress(source.get());
            logger.info("Balancer source is {}", sourceAddr);
            if (!distribution.containsKey(sourceAddr)) {
                return;
            }
        }
        // Get the list of hosts ordered by number of streams in DESC order
        List<Host> hosts = new ArrayList<Host>(distribution.size());
        for (Map.Entry<SocketAddress, Set<String>> entry : distribution.entrySet()) {
            // Copy the stream set: the balancer mutates it as streams move, and the set handed out by the
            // monitor is not owned by this method (it may be shared or unmodifiable).
            Set<String> streams = new HashSet<String>(entry.getValue());
            hosts.add(new Host(entry.getKey(), streams, clientBuilder));
        }
        Collections.sort(hosts, new HostComparator());
        try {

            // find the host to move streams from (-1 when balancing the whole cluster).
            int hostIdxMoveFrom = indexOfHost(hosts, sourceAddr);

            // compute the average load.
            int totalStream = 0;
            for (Host host : hosts) {
                totalStream += host.streams.size();
            }
            double averageLoad;
            if (hostIdxMoveFrom >= 0) {
                // the source host is being drained, so it does not count towards the target capacity.
                averageLoad = ((double) totalStream / (hosts.size() - 1));
            } else {
                averageLoad = ((double) totalStream / hosts.size());
            }

            int moveFromLowWaterMark;
            int moveToHighWaterMark = Math.max(1,
                    (int) (averageLoad + averageLoad * rebalanceTolerancePercentage / 100.0f));

            if (hostIdxMoveFrom >= 0) {
                moveFromLowWaterMark = Math.max(0, rebalanceWaterMark);
                moveStreams(hosts, new AtomicInteger(hostIdxMoveFrom), moveFromLowWaterMark,
                        new AtomicInteger(hosts.size() - 1), moveToHighWaterMark, rebalanceRateLimiter);
                moveRemainingStreamsFromSource(hosts.get(hostIdxMoveFrom), hosts, rebalanceRateLimiter);
            } else {
                moveFromLowWaterMark = Math.max((int) Math.ceil(averageLoad), rebalanceWaterMark);
                AtomicInteger moveFrom = new AtomicInteger(0);
                AtomicInteger moveTo = new AtomicInteger(hosts.size() - 1);
                // moveStreams may advance moveTo towards moveFrom as target hosts fill up.
                while (moveFrom.get() < moveTo.get()) {
                    moveStreams(hosts, moveFrom, moveFromLowWaterMark, moveTo, moveToHighWaterMark,
                            rebalanceRateLimiter);
                    moveFrom.incrementAndGet();
                }
            }
        } finally {
            for (Host host : hosts) {
                host.close();
            }
        }
    }

    /**
     * Returns the index of the host whose address equals {@code address}, or -1 if {@code address} is null
     * or no host matches.
     */
    private static int indexOfHost(List<Host> hosts, SocketAddress address) {
        if (null == address) {
            return -1;
        }
        for (int i = 0; i < hosts.size(); i++) {
            if (address.equals(hosts.get(i).address)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Moves streams from {@code hosts[hostIdxMoveFrom]} down to {@code moveFromLowWaterMark} streams.
     * Targets start at {@code hosts[hostIdxMoveTo]}, and a target accepts streams while it owns fewer
     * than {@code moveToHighWaterMark}.
     *
     * <p>When the current target is full, {@code hostIdxMoveTo} is decremented in place. The method stops
     * once it would reach {@code hostIdxMoveFrom}. Out-of-range indexes make this a no-op.
     */
    void moveStreams(List<Host> hosts, AtomicInteger hostIdxMoveFrom, int moveFromLowWaterMark,
            AtomicInteger hostIdxMoveTo, int moveToHighWaterMark, Optional<RateLimiter> rateLimiter) {
        if (hostIdxMoveFrom.get() < 0 || hostIdxMoveFrom.get() >= hosts.size() || hostIdxMoveTo.get() < 0
                || hostIdxMoveTo.get() >= hosts.size() || hostIdxMoveFrom.get() >= hostIdxMoveTo.get()) {
            return;
        }

        if (logger.isDebugEnabled()) {
            logger.debug(
                    "Moving streams : hosts = {}, from = {}, to = {} : from_low_water_mark = {}, to_high_water_mark = {}",
                    new Object[] { hosts, hostIdxMoveFrom.get(), hostIdxMoveTo.get(), moveFromLowWaterMark,
                            moveToHighWaterMark });
        }

        Host hostMoveFrom = hosts.get(hostIdxMoveFrom.get());
        int numStreamsOnFromHost = hostMoveFrom.streams.size();
        if (numStreamsOnFromHost <= moveFromLowWaterMark) {
            // do nothing
            return;
        }

        int numStreamsToMove = numStreamsOnFromHost - moveFromLowWaterMark;
        LinkedList<String> streamsToMove = new LinkedList<String>(hostMoveFrom.streams);
        Collections.shuffle(streamsToMove);

        if (logger.isDebugEnabled()) {
            logger.debug("Try to move {} streams from host {} : streams = {}",
                    new Object[] { numStreamsToMove, hostMoveFrom.address, streamsToMove });
        }

        while (numStreamsToMove-- > 0 && !streamsToMove.isEmpty()) {
            // pick a host to move
            Host hostMoveTo = hosts.get(hostIdxMoveTo.get());
            while (hostMoveTo.streams.size() >= moveToHighWaterMark) {
                int hostIdx = hostIdxMoveTo.decrementAndGet();
                logger.info("move to host : {}, from {}", hostIdx, hostIdxMoveFrom.get());
                if (hostIdx <= hostIdxMoveFrom.get()) {
                    return;
                } else {
                    hostMoveTo = hosts.get(hostIdx);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Target host to move moved to host {} @ {}", hostIdx, hostMoveTo);
                    }
                }
            }

            // throttle only once a target is known, so no permit is spent when there is nowhere to move.
            if (rateLimiter.isPresent()) {
                rateLimiter.get().acquire();
            }

            // pick a stream
            String stream = streamsToMove.remove();

            // move the stream
            if (moveStream(stream, hostMoveFrom, hostMoveTo)) {
                hostMoveFrom.streams.remove(stream);
                hostMoveTo.streams.add(stream);
            }
        }

    }

    /**
     * Moves every stream still owned by {@code source} onto the other hosts, round-robin from the end of
     * {@code hosts}. Each stream is attempted once; a failed move is not retried.
     */
    void moveRemainingStreamsFromSource(Host source, List<Host> hosts, Optional<RateLimiter> rateLimiter) {
        LinkedList<String> streamsToMove = new LinkedList<String>(source.streams);
        Collections.shuffle(streamsToMove);

        if (logger.isDebugEnabled()) {
            logger.debug("Try to move remaining streams from {} : {}", source, streamsToMove);
        }

        int hostIdx = hosts.size() - 1;

        while (!streamsToMove.isEmpty()) {
            Host target = hosts.get(hostIdx);
            if (!target.address.equals(source.address)) {
                // throttle only actual moves; skipping the source host must not consume a permit.
                if (rateLimiter.isPresent()) {
                    rateLimiter.get().acquire();
                }
                String stream = streamsToMove.remove();
                // move the stream
                if (moveStream(stream, source, target)) {
                    source.streams.remove(stream);
                    target.streams.add(stream);
                }
            }
            --hostIdx;
            if (hostIdx < 0) {
                hostIdx = hosts.size() - 1;
            }
        }
    }

    /**
     * Attempts to move {@code stream} from {@code from} to {@code to}.
     *
     * @return true if the move completed, false if it failed (the failure is logged)
     */
    private boolean moveStream(String stream, Host from, Host to) {
        try {
            doMoveStream(stream, from, to);
            return true;
        } catch (InterruptedException ie) {
            // preserve the interrupt so callers further up can observe it.
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while moving stream {} from {} to {} : ",
                    new Object[] { stream, from.address, to.address, ie });
            return false;
        } catch (Exception e) {
            logger.warn("Failed to move stream {} from {} to {} : ",
                    new Object[] { stream, from.address, to.address, e });
            return false;
        }
    }

    /**
     * Releases {@code stream} on {@code from}, then issues a {@code check} for it on {@code to}'s monitor.
     * Blocks until both steps complete, and rethrows any failure.
     */
    private void doMoveStream(final String stream, final Host from, final Host to) throws Exception {
        logger.info("Moving stream {} from {} to {}.", new Object[] { stream, from.address, to.address });
        Await.result(from.getClient().release(stream).flatMap(new Function<Void, Future<Void>>() {
            @Override
            public Future<Void> apply(Void result) {
                logger.info("Released stream {} from {}.", stream, from.address);
                return to.getMonitor().check(stream).addEventListener(new FutureEventListener<Void>() {

                    @Override
                    public void onSuccess(Void value) {
                        logger.info("Moved stream {} from {} to {}.",
                                new Object[] { stream, from.address, to.address });
                    }

                    @Override
                    public void onFailure(Throwable cause) {
                        logger.info("Failed to move stream {} from {} to {} : ",
                                new Object[] { stream, from.address, to.address, cause });
                    }
                });
            }
        }));
    }

    @Override
    public void close() {
        client.close();
    }
}