org.apache.cassandra.tools.Shuffle.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.cassandra.tools.Shuffle.java

Source

package org.apache.cassandra.tools;
/*
 * 
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 * 
 */

import java.io.Closeable;
import java.io.IOException;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import javax.management.JMX;
import javax.management.MBeanServerConnection;
import javax.management.MalformedObjectNameException;
import javax.management.ObjectName;

import org.apache.cassandra.serializers.TimestampSerializer;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.locator.EndpointSnitchInfoMBean;
import org.apache.cassandra.service.StorageServiceMBean;
import org.apache.cassandra.thrift.Cassandra;
import org.apache.cassandra.thrift.Compression;
import org.apache.cassandra.thrift.ConsistencyLevel;
import org.apache.cassandra.thrift.CqlResult;
import org.apache.cassandra.thrift.CqlRow;
import org.apache.cassandra.thrift.InvalidRequestException;
import org.apache.cassandra.thrift.TimedOutException;
import org.apache.cassandra.thrift.TokenRange;
import org.apache.cassandra.thrift.UnavailableException;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.MissingArgumentException;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TFastFramedTransport;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;

/**
 * Command-line tool that builds a randomized endpoint-to-token mapping, records the
 * resulting pending relocations in {@code system.range_xfers} on each target node
 * (CQL over Thrift), and switches scheduled range transfers on or off through each
 * node's StorageService MBean (JMX).
 */
public class Shuffle extends AbstractJmxClient {
    private static final String ssObjName = "org.apache.cassandra.db:type=StorageService";
    private static final String epSnitchObjName = "org.apache.cassandra.db:type=EndpointSnitchInfo";

    private final StorageServiceMBean ssProxy;
    private final Random rand = new Random(System.currentTimeMillis());
    private final String thriftHost;
    private final int thriftPort;
    private final boolean thriftFramed;

    static {
        addCmdOption("th", "thrift-host", true, "Thrift hostname or IP address (Default: JMX host)");
        addCmdOption("tp", "thrift-port", true, "Thrift port number (Default: 9160)");
        addCmdOption("tf", "thrift-framed", false, "Enable framed transport for Thrift (Default: false)");
        // NOTE(review): the third argument appears to mean "takes a value" (thrift-framed passes
        // false and is only tested with hasOption). If so, --and-enable swallows the token that
        // follows it even though main() only checks its presence. Left as-is because existing
        // invocations may already pass a value; confirm before changing.
        addCmdOption("en", "and-enable", true, "Immediately enable shuffling (create only)");
        addCmdOption("dc", "only-dc", true, "Apply only to named DC (create only)");
    }

    /**
     * Create a shuffler that talks Thrift to the JMX host on port 9160, unframed flag, no credentials.
     *
     * @param host JMX hostname or address
     * @param port JMX port number
     * @throws IOException if the JMX connection cannot be established
     */
    public Shuffle(String host, int port) throws IOException {
        this(host, port, host, 9160, false, null, null);
    }

    /**
     * Create a shuffler.
     *
     * @param host JMX hostname or address
     * @param port JMX port number
     * @param thriftHost Thrift hostname or address
     * @param thriftPort Thrift port number
     * @param thriftFramed value of the framed-transport flag handed to every Thrift client
     * @param username JMX username, may be null
     * @param password JMX password, may be null
     * @throws IOException if the JMX connection cannot be established
     */
    public Shuffle(String host, int port, String thriftHost, int thriftPort, boolean thriftFramed, String username,
            String password) throws IOException {
        super(host, port, username, password);

        this.thriftHost = thriftHost;
        this.thriftPort = thriftPort;
        this.thriftFramed = thriftFramed;

        // Setup the StorageService proxy.
        ssProxy = getSSProxy(jmxConn.getMbeanServerConn());
    }

    /**
     * Build a StorageService MBean proxy on top of an existing MBean server connection.
     *
     * @param mbeanConn connection to create the proxy on
     * @return the proxy
     */
    public StorageServiceMBean getSSProxy(MBeanServerConnection mbeanConn) {
        try {
            return JMX.newMBeanProxy(mbeanConn, new ObjectName(ssObjName), StorageServiceMBean.class);
        } catch (MalformedObjectNameException e) {
            // The object name is a compile-time constant, so this indicates a programming error.
            throw new RuntimeException(e);
        }
    }

    /**
     * Build an EndpointSnitchInfo MBean proxy on top of an existing MBean server connection.
     *
     * @param mbeanConn connection to create the proxy on
     * @return the proxy
     */
    public EndpointSnitchInfoMBean getEpSnitchProxy(MBeanServerConnection mbeanConn) {
        try {
            return JMX.newMBeanProxy(mbeanConn, new ObjectName(epSnitchObjName), EndpointSnitchInfoMBean.class);
        } catch (MalformedObjectNameException e) {
            // The object name is a compile-time constant, so this indicates a programming error.
            throw new RuntimeException(e);
        }
    }

    /**
     * Given a Multimap of endpoint to tokens, return a new randomized mapping.
     *
     * <p>Endpoints are visited round-robin, one token per pass. Each token is offered to the
     * other endpoints in random order and goes to the first one whose number of assigned
     * relocations is still below the token count reported for it by the StorageService
     * proxy; if none qualifies the token stays with its current endpoint.
     *
     * @param endpointMap current mapping of endpoint to tokens
     * @return a new mapping of endpoint to tokens
     */
    public Multimap<String, String> calculateRelocations(Multimap<String, String> endpointMap) {
        Multimap<String, String> relocations = HashMultimap.create();
        Set<String> endpoints = new HashSet<String>(endpointMap.keySet());
        Map<String, Integer> endpointToNumTokens = new HashMap<String, Integer>(endpoints.size());
        Map<String, Iterator<String>> iterMap = new HashMap<String, Iterator<String>>(endpoints.size());

        // Create maps of endpoint to token iterators, and endpoint to number of tokens.
        for (String endpoint : endpoints) {
            try {
                endpointToNumTokens.put(endpoint, ssProxy.getTokens(endpoint).size());
            } catch (UnknownHostException e) {
                throw new RuntimeException("What that...?", e);
            }

            iterMap.put(endpoint, endpointMap.get(endpoint).iterator());
        }

        int epsToComplete = endpoints.size();
        Set<String> endpointsCompleted = new HashSet<String>();

        while (true) {
            // Exhausted endpoints are dropped between passes (never while iterating the set).
            endpoints.removeAll(endpointsCompleted);

            for (String endpoint : endpoints) {
                Iterator<String> tokens = iterMap.get(endpoint);
                if (!tokens.hasNext()) {
                    endpointsCompleted.add(endpoint);
                    continue;
                }

                String token = tokens.next();

                List<String> subSet = new ArrayList<String>(endpoints);
                subSet.remove(endpoint);
                Collections.shuffle(subSet, rand);

                boolean choiceMade = false;
                for (String choice : subSet) {
                    if (relocations.get(choice).size() < endpointToNumTokens.get(choice)) {
                        relocations.put(choice, token);
                        choiceMade = true;
                        break;
                    }
                }

                // Nobody else has room: the token stays where it is.
                if (!choiceMade) {
                    relocations.put(endpoint, token);
                }
            }

            // We're done when we've exhausted all of the token iterators
            if (endpointsCompleted.size() == epsToComplete) {
                break;
            }
        }

        return relocations;
    }

    /**
     * Enable relocations.
     *
     * @param endpoints sequence of hostname or IP strings
     */
    public void enableRelocations(String... endpoints) {
        enableRelocations(Arrays.asList(endpoints));
    }

    /**
     * Enable relocations.
     *
     * @param endpoints Collection of hostname or IP strings
     */
    public void enableRelocations(Collection<String> endpoints) {
        setScheduledRangeXfers(endpoints, true);
    }

    /**
     * Disable relocations.
     *
     * @param endpoints sequence of hostname or IP strings
     */
    public void disableRelocations(String... endpoints) {
        disableRelocations(Arrays.asList(endpoints));
    }

    /**
     * Disable relocations.
     *
     * @param endpoints Collection of hostname or IP strings
     */
    public void disableRelocations(Collection<String> endpoints) {
        setScheduledRangeXfers(endpoints, false);
    }

    /**
     * Open a JMX connection to each endpoint and switch scheduled range transfers on or off.
     * An I/O failure on one endpoint is reported and does not stop the remaining endpoints.
     */
    private void setScheduledRangeXfers(Collection<String> endpoints, boolean enabled) {
        for (String endpoint : endpoints) {
            try {
                JMXConnection conn = new JMXConnection(endpoint, port, username, password);
                try {
                    StorageServiceMBean proxy = getSSProxy(conn.getMbeanServerConn());
                    if (enabled) {
                        proxy.enableScheduledRangeXfers();
                    } else {
                        proxy.disableScheduledRangeXfers();
                    }
                } finally {
                    // Release the connection even when the MBean call fails.
                    conn.close();
                }
            } catch (IOException e) {
                writeln("Failed to %s shuffling on %s!", enabled ? "enable" : "disable", endpoint);
            }
        }
    }

    /**
     * Return a list of the live nodes (using JMX).
     *
     * @return String endpoint names
     * @throws ShuffleError if the JMX connection cannot be established
     */
    public Collection<String> getLiveNodes() throws ShuffleError {
        try {
            JMXConnection conn = new JMXConnection(host, port, username, password);
            try {
                return getSSProxy(conn.getMbeanServerConn()).getLiveNodes();
            } finally {
                closeQuietly(conn);
            }
        } catch (IOException e) {
            throw newShuffleError("Error retrieving list of nodes from JMX interface", e);
        }
    }

    /**
     * Create and distribute a new, randomized token to endpoint mapping.
     *
     * @param enable when true, enable relocations on every target endpoint afterwards
     * @param onlyDc when non-null, only endpoints the snitch places in this datacenter take part
     * @throws ShuffleError on handled exceptions
     */
    public void shuffle(boolean enable, String onlyDc) throws ShuffleError {
        Map<String, String> tokenMap;
        Multimap<String, String> endpointMap = HashMultimap.create();
        EndpointSnitchInfoMBean epSnitchProxy = getEpSnitchProxy(jmxConn.getMbeanServerConn());

        CassandraClient seedClient = getThriftClient(thriftHost, thriftPort, thriftFramed);
        try {
            tokenMap = seedClient.describe_token_map();

            // The token map is keyed by token; invert it into endpoint -> tokens.
            for (Map.Entry<String, String> entry : tokenMap.entrySet()) {
                String endpoint = entry.getValue(), token = entry.getKey();
                try {
                    if (onlyDc == null || onlyDc.equals(epSnitchProxy.getDatacenter(endpoint))) {
                        endpointMap.put(endpoint, token);
                    }
                } catch (UnknownHostException e) {
                    writeln("Warning: %s unknown to EndpointSnitch!", endpoint);
                }
            }
        } catch (InvalidRequestException ire) {
            throw new RuntimeException("What that...?", ire);
        } catch (TException e) {
            throw newShuffleError(
                    String.format("Thrift request to %s:%d failed: %s", thriftHost, thriftPort, e.getMessage()), e);
        } finally {
            seedClient.close();
        }

        IPartitioner<?> partitioner = getPartitioner(thriftHost, thriftPort, thriftFramed);

        Multimap<String, String> relocations = calculateRelocations(endpointMap);

        writeln("%-42s %-15s %-15s", "Token", "From", "To");
        writeln("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~");

        // Store relocations on remote nodes.
        for (String endpoint : relocations.keySet()) {
            for (String tok : relocations.get(endpoint)) {
                writeln("%-42s %-15s %-15s", tok, tokenMap.get(tok), endpoint);
            }

            String cqlQuery = createShuffleBatchInsert(relocations.get(endpoint), partitioner);
            executeCqlQuery(endpoint, thriftPort, thriftFramed, cqlQuery);
        }

        if (enable) {
            enableRelocations(relocations.keySet());
        }
    }

    /**
     * Print a list of pending token relocations for all nodes.
     *
     * @throws ShuffleError if a node cannot be queried
     */
    public void ls() throws ShuffleError {
        Map<String, List<CqlRow>> queuedRelocations = listRelocations();
        IPartitioner<?> partitioner = getPartitioner(thriftHost, thriftPort, thriftFramed);
        boolean headerPrinted = false;

        for (String host : queuedRelocations.keySet()) {
            for (CqlRow row : queuedRelocations.get(host)) {
                assert row.getColumns().size() == 2;

                // The header is only printed once there is at least one row to show.
                if (!headerPrinted) {
                    writeln("%-42s %-15s %s", "Token", "Endpoint", "Requested at");
                    writeln("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~~~~~~~~~~~~~~");
                    headerPrinted = true;
                }

                ByteBuffer tokenBytes = ByteBuffer.wrap(row.getColumns().get(0).getValue());
                ByteBuffer requestedAt = ByteBuffer.wrap(row.getColumns().get(1).getValue());
                Date time = TimestampSerializer.instance.deserialize(requestedAt);
                Token<?> token = partitioner.getTokenFactory().fromByteArray(tokenBytes);

                writeln("%-42s %-15s %s", token.toString(), host, time.toString());
            }
        }
    }

    /**
     * List pending token relocations for all nodes.
     *
     * @return map of live node name to the rows of its {@code system.range_xfers} table
     * @throws ShuffleError if the node list or a node's rows cannot be retrieved
     */
    private Map<String, List<CqlRow>> listRelocations() throws ShuffleError {
        String cqlQuery = "SELECT token_bytes,requested_at FROM system.range_xfers";
        Map<String, List<CqlRow>> results = new HashMap<String, List<CqlRow>>();

        for (String host : getLiveNodes()) {
            CqlResult result = executeCqlQuery(host, thriftPort, thriftFramed, cqlQuery);
            results.put(host, result.getRows());
        }

        return results;
    }

    /**
     * Clear pending token relocations on all nodes.
     *
     * @throws ShuffleError if a node cannot be queried or updated
     */
    public void clear() throws ShuffleError {
        Map<String, List<CqlRow>> queuedRelocations = listRelocations();

        for (String host : queuedRelocations.keySet()) {
            for (CqlRow row : queuedRelocations.get(host)) {
                assert row.getColumns().size() == 2;

                ByteBuffer tokenBytes = ByteBuffer.wrap(row.getColumns().get(0).getValue());
                String query = String.format("DELETE FROM system.range_xfers WHERE token_bytes = 0x%s",
                        ByteBufferUtil.bytesToHex(tokenBytes));
                executeCqlQuery(host, thriftPort, thriftFramed, query);
            }
        }
    }

    /**
     * Enable shuffling on all nodes in the cluster.
     *
     * @throws ShuffleError if the list of live nodes cannot be retrieved
     */
    public void enable() throws ShuffleError {
        enableRelocations(getLiveNodes());
    }

    /**
     * Disable shuffling on all nodes in the cluster.
     *
     * @throws ShuffleError if the list of live nodes cannot be retrieved
     */
    public void disable() throws ShuffleError {
        disableRelocations(getLiveNodes());
    }

    /**
     * Setup and return a new Thrift RPC connection.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param framed framed-transport flag passed to the client
     * @return a CassandraClient instance; the caller is responsible for closing it
     * @throws ShuffleError if the transport cannot be opened
     */
    public static CassandraClient getThriftClient(String hostName, int port, boolean framed) throws ShuffleError {
        try {
            return new CassandraClient(hostName, port, framed);
        } catch (TTransportException e) {
            throw newShuffleError(
                    String.format("Unable to connect to %s/%d: %s", hostName, port, e.getMessage()), e);
        }
    }

    /**
     * Execute a CQL v3 query.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param isFramed framed-transport flag passed to the client
     * @param cqlQuery CQL query string
     * @return a Thrift CqlResult instance
     * @throws ShuffleError if the connection cannot be opened, or the request is unavailable or times out
     */
    public static CqlResult executeCqlQuery(String hostName, int port, boolean isFramed, String cqlQuery)
            throws ShuffleError {
        // Connect outside the try block so that a connection failure surfaces as the declared
        // ShuffleError instead of being wrapped by the catch-all below.
        CassandraClient client = getThriftClient(hostName, port, isFramed);

        try {
            // ByteBufferUtil.bytes encodes as UTF-8 rather than the platform default charset.
            return client.execute_cql_query(ByteBufferUtil.bytes(cqlQuery), Compression.NONE);
        } catch (UnavailableException e) {
            throw newShuffleError(
                    String.format("Unable to write shuffle entries to %s. Reason: UnavailableException", hostName),
                    e);
        } catch (TimedOutException e) {
            throw newShuffleError(
                    String.format("Unable to write shuffle entries to %s. Reason: TimedOutException", hostName), e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            client.close();
        }
    }

    /**
     * Return a partitioner instance for remote host.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param framed framed-transport flag passed to the client
     * @return an IPartitioner instance
     * @throws ShuffleError if the Thrift request fails or the partitioner class is not on the classpath
     */
    public static IPartitioner<?> getPartitioner(String hostName, int port, boolean framed) throws ShuffleError {
        String partitionerName;
        CassandraClient client = getThriftClient(hostName, port, framed);
        try {
            partitionerName = client.describe_partitioner();
        } catch (InvalidRequestException e) {
            throw new RuntimeException("Error calling describe_partitioner() defies explanation", e);
        } catch (TException e) {
            throw newShuffleError(
                    String.format("Thrift request to %s:%d failed: %s", hostName, port, e.getMessage()), e);
        } finally {
            // Only the name is needed; release the connection straight away.
            client.close();
        }

        try {
            Class<?> partitionerClass = Class.forName(partitionerName);
            return (IPartitioner<?>) partitionerClass.newInstance();
        } catch (ClassNotFoundException e) {
            throw newShuffleError("Unable to locate class for partitioner: " + partitionerName, e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create and return a CQL batch insert statement for a set of token relocations.
     *
     * @param tokens tokens to be relocated
     * @param partitioner an instance of the IPartitioner in use
     * @return a query string
     */
    public static String createShuffleBatchInsert(Collection<String> tokens, IPartitioner<?> partitioner) {
        StringBuilder query = new StringBuilder();
        query.append("BEGIN BATCH").append("\n");

        for (String tokenStr : tokens) {
            // Tokens arrive in string form; the table stores them as hex-encoded bytes.
            Token<?> token = partitioner.getTokenFactory().fromString(tokenStr);
            String hexToken = ByteBufferUtil.bytesToHex(partitioner.getTokenFactory().toByteArray(token));
            query.append("INSERT INTO system.range_xfers (token_bytes, requested_at) ").append("VALUES (")
                    .append("0x").append(hexToken).append(", 'now');").append("\n");
        }

        query.append("APPLY BATCH").append("\n");
        return query.toString();
    }

    /** Build a ShuffleError that keeps the underlying failure as its cause. */
    private static ShuffleError newShuffleError(String message, Throwable cause) {
        ShuffleError err = new ShuffleError(message);
        err.initCause(cause);
        return err;
    }

    /** Close a JMX connection whose result has already been obtained, ignoring close failures. */
    private static void closeQuietly(JMXConnection conn) {
        try {
            conn.close();
        } catch (Exception ignored) {
            // The caller already has its answer; a failed close must not turn success into failure.
        }
    }

    /** Return the value of a command-line option, or the fallback when the option is absent. */
    private static String optionOrDefault(CommandLine cmd, String option, String fallback) {
        String value = cmd.getOptionValue(option);
        return (value != null) ? value : fallback;
    }

    /** Print usage information. */
    private static void printShuffleHelp() {
        StringBuilder sb = new StringBuilder();
        sb.append("Sub-commands:").append(String.format("%n"));
        sb.append(" create           Initialize a new shuffle operation").append(String.format("%n"));
        sb.append(" ls               List pending relocations").append(String.format("%n"));
        sb.append(" clear            Clear pending relocations").append(String.format("%n"));
        sb.append(" en[able]         Enable shuffling").append(String.format("%n"));
        sb.append(" dis[able]        Disable shuffling").append(String.format("%n%n"));

        printHelp("shuffle [options] <sub-command>", sb.toString());
    }

    /**
     * Execute.
     *
     * @param args arguments passed on the command line
     * @throws Exception when face meets palm
     */
    public static void main(String[] args) throws Exception {
        CommandLine cmd = null;
        try {
            cmd = processArguments(args);
        } catch (MissingArgumentException e) {
            System.err.println(e.getMessage());
            System.exit(1);
        }

        // Sub command argument.
        if (cmd.getArgList().size() < 1) {
            System.err.println("Missing sub-command argument.");
            printShuffleHelp();
            System.exit(1);
        }
        String subCommand = (String) (cmd.getArgList()).get(0);

        String hostName = optionOrDefault(cmd, "host", DEFAULT_HOST);
        String port = optionOrDefault(cmd, "port", Integer.toString(DEFAULT_JMX_PORT));
        String username = cmd.getOptionValue("username");
        String password = cmd.getOptionValue("password");
        String thriftHost = optionOrDefault(cmd, "thrift-host", hostName);
        String thriftPort = optionOrDefault(cmd, "thrift-port", "9160");
        String onlyDc = cmd.getOptionValue("only-dc");
        boolean thriftFramed = cmd.hasOption("thrift-framed");
        boolean andEnable = cmd.hasOption("and-enable");
        int portNum = -1, thriftPortNum = -1;

        // Parse JMX port number (never null here: it has already been defaulted above).
        try {
            portNum = Integer.parseInt(port);
        } catch (NumberFormatException ferr) {
            System.err.printf("%s is not a valid JMX port number.%n", port);
            System.exit(1);
        }

        // Parse Thrift port number (never null here: it has already been defaulted above).
        try {
            thriftPortNum = Integer.parseInt(thriftPort);
        } catch (NumberFormatException ferr) {
            System.err.printf("%s is not a valid port number.%n", thriftPort);
            System.exit(1);
        }

        Shuffle shuffler = new Shuffle(hostName, portNum, thriftHost, thriftPortNum, thriftFramed, username,
                password);

        // Record the exit status instead of exiting inside the try block: System.exit() there
        // would skip the finally clause and leave the JMX connection open.
        int exitCode = 0;
        try {
            if (subCommand.equals("create")) {
                shuffler.shuffle(andEnable, onlyDc);
            } else if (subCommand.equals("ls")) {
                shuffler.ls();
            } else if (subCommand.startsWith("en")) {
                shuffler.enable();
            } else if (subCommand.startsWith("dis")) {
                shuffler.disable();
            } else if (subCommand.equals("clear")) {
                shuffler.clear();
            } else {
                System.err.println("Unknown subcommand: " + subCommand);
                printShuffleHelp();
                exitCode = 1;
            }
        } catch (ShuffleError err) {
            shuffler.writeln(err);
            exitCode = 1;
        } finally {
            shuffler.close();
        }

        System.exit(exitCode);
    }
}

/**
 * A self-contained Cassandra.Client; Closeable.
 *
 * <p>Owns the Thrift transport it opens; callers must invoke {@link #close()} when done.
 */
class CassandraClient implements Closeable {
    TTransport transport;
    Cassandra.Client client;

    /**
     * Open a transport to the given host and select CQL version 3.0.0 on it.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param framed transport selection flag (see the review note in the body)
     * @throws TTransportException if the transport cannot be opened
     */
    CassandraClient(String hostName, int port, boolean framed) throws TTransportException {
        TSocket socket = new TSocket(hostName, port);
        // NOTE(review): this looks inverted relative to the flag's name: framed == true yields
        // the bare socket, framed == false yields the framed transport. Left unchanged because
        // the default invocation (flag unset) may rely on getting a framed transport; confirm
        // against the server's transport before flipping it.
        transport = (framed) ? socket : new TFastFramedTransport(socket);
        transport.open();
        client = new Cassandra.Client(new TBinaryProtocol(transport));

        try {
            client.set_cql_version("3.0.0");
        } catch (Exception e) {
            // The constructor is about to fail, so nobody will get a reference to close():
            // release the already-opened transport here instead of leaking it.
            transport.close();
            throw new RuntimeException(e);
        }
    }

    /** Execute a CQL3 query at consistency level ONE. */
    CqlResult execute_cql_query(ByteBuffer cqlQuery, Compression compression) throws Exception {
        return client.execute_cql3_query(cqlQuery, compression, ConsistencyLevel.ONE);
    }

    /** Delegate to the underlying client's {@code describe_partitioner}. */
    String describe_partitioner() throws TException, InvalidRequestException {
        return client.describe_partitioner();
    }

    /** Delegate to the underlying client's {@code describe_ring} for the given keyspace. */
    List<TokenRange> describe_ring(String keyspace) throws InvalidRequestException, TException {
        return client.describe_ring(keyspace);
    }

    /** Delegate to the underlying client's {@code describe_token_map}. */
    Map<String, String> describe_token_map() throws InvalidRequestException, TException {
        return client.describe_token_map();
    }

    /** Close the underlying transport. */
    @Override
    public void close() {
        transport.close();
    }
}

/**
 * Checked exception for anticipated shuffle failures; {@code Shuffle.main} catches it,
 * prints it and exits with a non-zero status.
 */
@SuppressWarnings("serial")
class ShuffleError extends Exception {
    /**
     * @param msg description of the failure
     */
    ShuffleError(String msg) {
        super(msg);
    }

    /**
     * @param msg description of the failure
     * @param cause underlying exception, kept so the original stack trace is not lost
     */
    ShuffleError(String msg, Throwable cause) {
        super(msg, cause);
    }
}