com.palantir.atlasdb.keyvalue.cassandra.CassandraClientFactory.java Source code

Java tutorial

Introduction

Here is the source code for com.palantir.atlasdb.keyvalue.cassandra.CassandraClientFactory.java

Source

/**
 * Copyright 2015 Palantir Technologies
 *
 * Licensed under the BSD-3 License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.palantir.atlasdb.keyvalue.cassandra;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketException;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import org.apache.cassandra.thrift.Cassandra;
import org.apache.cassandra.thrift.Cassandra.Client;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

/**
 * Commons-pool2 object factory that creates, validates and destroys Thrift {@link Cassandra.Client} instances
 * connected to a single Cassandra host, optionally over SSL.
 */
public class CassandraClientFactory extends BasePooledObjectFactory<Client> {
    private static final Logger log = LoggerFactory.getLogger(CassandraClientFactory.class);

    /**
     * Lazily built, unbounded cache holding one {@link SSLSocketFactory} per host address.
     */
    private static final LoadingCache<InetSocketAddress, SSLSocketFactory> sslSocketFactories = CacheBuilder
            .newBuilder().build(new CacheLoader<InetSocketAddress, SSLSocketFactory>() {
                @Override
                public SSLSocketFactory load(InetSocketAddress host) throws Exception {
                    /*
                     * Use a separate SSLSocketFactory per host to reduce contention on the synchronized method
                     * SecureRandom.nextBytes. Otherwise, this is identical to SSLSocketFactory.getDefault()
                     */
                    return SSLContext.getInstance("Default").getSocketFactory();
                }
            });

    private final InetSocketAddress addr;
    private final String keyspace;
    private final boolean isSsl;
    private final int socketTimeoutMillis;
    private final int socketQueryTimeoutMillis;

    /**
     * Creates a factory for clients connected to one Cassandra host.
     *
     * @param addr host and port every client produced by this factory connects to
     * @param keyspace keyspace selected on each newly created client
     * @param isSsl whether the connection is wrapped in SSL
     * @param socketTimeoutMillis timeout given to the {@link TSocket} constructor
     * @param socketQueryTimeoutMillis socket read timeout (SO_TIMEOUT) applied once the connection is open
     */
    public CassandraClientFactory(InetSocketAddress addr, String keyspace, boolean isSsl, int socketTimeoutMillis,
            int socketQueryTimeoutMillis) {
        this.addr = addr;
        this.keyspace = keyspace;
        this.isSsl = isSsl;
        this.socketTimeoutMillis = socketTimeoutMillis;
        this.socketQueryTimeoutMillis = socketQueryTimeoutMillis;
    }

    /**
     * Opens a new client to this factory's host and selects the configured keyspace.
     *
     * @return a connected client with its keyspace set
     * @throws ClientCreationFailedException wrapping any failure; the message names the host, the keyspace and
     *         whether SSL was in use
     */
    @Override
    public Client create() throws Exception {
        try {
            return getClient(addr, keyspace, isSsl, socketTimeoutMillis, socketQueryTimeoutMillis);
        } catch (Exception e) {
            String message = String.format("Failed to construct client for %s/%s", addr, keyspace);
            if (isSsl) {
                message += " over SSL";
            }
            throw new ClientCreationFailedException(message, e);
        }
    }

    /**
     * Opens a client and calls {@code set_keyspace} on it. If selecting the keyspace fails, the client's
     * transport is closed before the exception is rethrown, so the socket is not leaked.
     */
    private static Cassandra.Client getClient(InetSocketAddress addr, String keyspace, boolean isSsl,
            int socketTimeoutMillis, int socketQueryTimeoutMillis) throws Exception {
        Client ret = getClientInternal(addr, isSsl, socketTimeoutMillis, socketQueryTimeoutMillis);
        try {
            ret.set_keyspace(keyspace);
            log.info("Created new client for {}/{} {}", addr, keyspace, (isSsl ? "over SSL" : ""));
            return ret;
        } catch (Exception e) {
            ret.getOutputProtocol().getTransport().close();
            throw e;
        }
    }

    /**
     * Opens a Thrift client (framed transport, binary protocol) to {@code addr} without selecting a keyspace.
     *
     * <p>Keep-alive and the query read timeout are applied on a best-effort basis. Each is attempted
     * independently, and a failure of either is logged without aborting client creation.
     *
     * @param addr host and port to connect to
     * @param isSsl whether to layer an SSL socket over the opened plain socket
     * @param socketTimeoutMillis timeout given to the {@link TSocket} constructor
     * @param socketQueryTimeoutMillis SO_TIMEOUT set on the socket after it is opened
     * @return a client whose transport is open
     * @throws TTransportException if the socket cannot be opened or SSL wrapping fails with an I/O error
     */
    public static Cassandra.Client getClientInternal(InetSocketAddress addr, boolean isSsl, int socketTimeoutMillis,
            int socketQueryTimeoutMillis) throws TTransportException {
        TSocket tSocket = new TSocket(addr.getHostString(), addr.getPort(), socketTimeoutMillis);
        tSocket.open();

        // Separate try blocks: a keep-alive failure must not prevent the query timeout from being applied.
        try {
            tSocket.getSocket().setKeepAlive(true);
        } catch (SocketException e) {
            log.error("Couldn't set socket keep alive for {}", addr, e);
        }
        try {
            tSocket.getSocket().setSoTimeout(socketQueryTimeoutMillis);
        } catch (SocketException e) {
            log.error("Couldn't set socket query timeout for {}", addr, e);
        }

        if (isSsl) {
            boolean success = false;
            try {
                SSLSocketFactory factory = sslSocketFactories.getUnchecked(addr);
                // autoClose=true: closing the SSL socket also closes the underlying plain socket.
                SSLSocket socket = (SSLSocket) factory.createSocket(tSocket.getSocket(), addr.getHostString(),
                        addr.getPort(), true);
                tSocket = new TSocket(socket);
                success = true;
            } catch (IOException e) {
                throw new TTransportException(e);
            } finally {
                // On any failure (checked or unchecked), close the already-open plain socket so it is not leaked.
                if (!success) {
                    tSocket.close();
                }
            }
        }
        TTransport tFramedTransport = new TFramedTransport(tSocket,
                CassandraConstants.CLIENT_MAX_THRIFT_FRAME_SIZE_BYTES);
        TProtocol protocol = new TBinaryProtocol(tFramedTransport);
        return new Cassandra.Client(protocol);
    }

    /**
     * Treats a pooled client as valid exactly when its output transport reports itself open.
     */
    @Override
    public boolean validateObject(PooledObject<Client> client) {
        return client.getObject().getOutputProtocol().getTransport().isOpen();
    }

    /**
     * Wraps a client in a {@link DefaultPooledObject} for the pool.
     */
    @Override
    public PooledObject<Client> wrap(Client client) {
        return new DefaultPooledObject<>(client);
    }

    /**
     * Closes the pooled client's output transport and logs the closure.
     */
    @Override
    public void destroyObject(PooledObject<Client> client) {
        client.getObject().getOutputProtocol().getTransport().close();
        log.info("Closed transport for client {}", client.getObject());
    }

    /**
     * Unchecked exception thrown by {@link #create()} when a client cannot be constructed.
     * {@link #getCause()} is narrowed to {@link Exception} because the only constructor accepts an
     * {@code Exception} as the cause.
     */
    static class ClientCreationFailedException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        public ClientCreationFailedException(String message, Exception cause) {
            super(message, cause);
        }

        @Override
        public Exception getCause() {
            return (Exception) super.getCause();
        }
    }
}