com.jbrisbin.riak.async.RiakAsyncClient.java Source code

Java tutorial

Introduction

Here is the source code for com.jbrisbin.riak.async.RiakAsyncClient.java

Source

/*
 * This file is provided to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.jbrisbin.riak.async;

import static com.google.protobuf.ByteString.*;
import static com.jbrisbin.riak.pbc.RiakMessageCodes.*;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

import com.basho.riak.client.IRiakObject;
import com.basho.riak.client.bucket.BucketProperties;
import com.basho.riak.client.builders.BucketPropertiesBuilder;
import com.basho.riak.client.query.MapReduceResult;
import com.basho.riak.client.query.WalkResult;
import com.basho.riak.client.raw.RiakResponse;
import com.basho.riak.client.raw.StoreMeta;
import com.basho.riak.client.raw.query.LinkWalkSpec;
import com.basho.riak.client.raw.query.MapReduceTimeoutException;
import com.basho.riak.client.util.CharsetUtils;
import com.basho.riak.pbc.RPB;
import com.google.protobuf.ByteString;
import com.jbrisbin.riak.async.raw.RawAsyncClient;
import com.jbrisbin.riak.async.raw.ServerInfo;
import com.jbrisbin.riak.async.util.DelegatingErrorHandler;
import com.jbrisbin.riak.pbc.RiakObject;
import com.jbrisbin.riak.pbc.RpbFilter;
import com.jbrisbin.riak.pbc.RpbMessage;
import org.apache.commons.codec.binary.Base64;
import org.codehaus.jackson.map.ObjectMapper;
import org.glassfish.grizzly.Connection;
import org.glassfish.grizzly.EmptyCompletionHandler;
import org.glassfish.grizzly.filterchain.BaseFilter;
import org.glassfish.grizzly.filterchain.FilterChain;
import org.glassfish.grizzly.filterchain.FilterChainBuilder;
import org.glassfish.grizzly.filterchain.FilterChainContext;
import org.glassfish.grizzly.filterchain.NextAction;
import org.glassfish.grizzly.filterchain.TransportFilter;
import org.glassfish.grizzly.memory.HeapMemoryManager;
import org.glassfish.grizzly.nio.transport.TCPNIOTransport;
import org.glassfish.grizzly.nio.transport.TCPNIOTransportBuilder;
import org.glassfish.grizzly.strategies.WorkerThreadIOStrategy;
import org.glassfish.grizzly.threadpool.ThreadPoolConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author Jon Brisbin <jon@jbrisbin.com>
 */
public class RiakAsyncClient implements RawAsyncClient {

    // Logger named after the runtime class of this instance.
    private final Logger log = LoggerFactory.getLogger(getClass());
    // Processor count sampled once per instance; sizes both the default worker pool below and
    // the transport's thread pool configured in start().
    private final int PROCESSORS = Runtime.getRuntime().availableProcessors();

    // Host and port passed to the transport when getConnection() opens a connection.
    private String host = "localhost";
    private Integer port = 8087;
    // Timeout in seconds: used for the blocking waits in getConnection() and as the thread
    // keep-alive time of the transport's pool in start().
    private Long timeout = 30L;
    // Receives every error this client catches instead of propagating it.
    private DelegatingErrorHandler errorHandler = new DelegatingErrorHandler();
    // Exposed through getWorkerPool()/setWorkerPool(); not otherwise used inside this class.
    private ExecutorService workerPool = Executors.newFixedThreadPool(PROCESSORS);
    // Memory manager handed to the RpbFilter in start().
    private HeapMemoryManager heap = new HeapMemoryManager();
    private TCPNIOTransport transport = TCPNIOTransportBuilder.newInstance().build();
    // Built in start(): TransportFilter -> RpbFilter -> PendingRequestFilter.
    private FilterChain filterChain;
    // Current connection; (re)opened by getConnection() when null or no longer open.
    private Connection connection;
    // Upper bound compared against the retries counter below.
    private int maxConnectionRetries = 5;
    // Incremented on connect and store-write failures, decremented on a successful connect.
    private AtomicInteger retries = new AtomicInteger(0);
    // FIFO of requests that have been written and are awaiting a response; filled by
    // PendingRequestFilter.handleWrite and drained by PendingRequestFilter.handleRead.
    private LinkedBlockingQueue<RpbRequest> pendingRequests = new LinkedBlockingQueue<RpbRequest>();
    // Not referenced anywhere else in this class.
    private ObjectMapper mapper = new ObjectMapper();

    /**
     * Creates a client using the field defaults ({@code localhost}, port {@code 8087}, a
     * timeout of {@code 30} seconds) and starts it straight away.
     */
    public RiakAsyncClient() {
        this.start();
    }

    /**
     * Creates a client for the given endpoint, keeping the default timeout, and starts it
     * straight away.
     *
     * @param host host name or address of the Riak node
     * @param port port to connect to
     */
    public RiakAsyncClient(String host, Integer port) {
        this.port = port;
        this.host = host;
        this.start();
    }

    /**
     * Creates a client for the given endpoint with a custom timeout and starts it straight
     * away.
     *
     * @param host    host name or address of the Riak node
     * @param port    port to connect to
     * @param timeout timeout in seconds, used for connection waits and thread keep-alive
     */
    public RiakAsyncClient(String host, Integer port, Long timeout) {
        this.host = host;
        this.port = port;
        this.timeout = timeout;
        // The other constructors start the client; without this call the filter chain is never
        // built and the transport never started, and start() is private so callers cannot do
        // it themselves.
        start();
    }

    /**
     * Registers a handler for the given throwable type on this client's delegating error
     * handler.
     *
     * @param t            throwable type the handler is registered for
     * @param errorHandler handler to register
     * @return this client, for chaining
     */
    public RiakAsyncClient registerErrorHandler(Class<? extends Throwable> t, ErrorHandler errorHandler) {
        final DelegatingErrorHandler delegate = this.errorHandler;
        delegate.registerErrorHandler(t, errorHandler);
        return this;
    }

    /**
     * @return the host this client connects to
     */
    public String getHost() {
        return this.host;
    }

    /**
     * Changes the host used the next time a connection is opened; a connection that is
     * already open is left as it is.
     *
     * @param host host name or address of the Riak node
     * @return this client, for chaining
     */
    public RiakAsyncClient setHost(String host) {
        final RiakAsyncClient self = this;
        self.host = host;
        return self;
    }

    /**
     * @return the port this client connects to
     */
    public Integer getPort() {
        return this.port;
    }

    /**
     * Changes the port used the next time a connection is opened; a connection that is
     * already open is left as it is.
     *
     * @param port port to connect to
     * @return this client, for chaining
     */
    public RiakAsyncClient setPort(Integer port) {
        final RiakAsyncClient self = this;
        self.port = port;
        return self;
    }

    /**
     * @return the timeout, in seconds
     */
    public Long getTimeout() {
        return this.timeout;
    }

    /**
     * Changes the timeout applied to later blocking waits while connecting.
     *
     * @param timeout timeout in seconds
     * @return this client, for chaining
     */
    public RiakAsyncClient setTimeout(Long timeout) {
        final RiakAsyncClient self = this;
        self.timeout = timeout;
        return self;
    }

    /**
     * @return the executor service currently held by this client
     */
    public ExecutorService getWorkerPool() {
        return this.workerPool;
    }

    /**
     * Replaces the executor service held by this client. The previous executor is neither
     * shut down nor otherwise touched here.
     *
     * @param workerPool executor service to hold
     * @return this client, for chaining
     */
    public RiakAsyncClient setWorkerPool(ExecutorService workerPool) {
        final RiakAsyncClient self = this;
        self.workerPool = workerPool;
        return self;
    }

    /**
     * Replaces the delegating error handler that receives the errors this client catches.
     *
     * @param errorHandler new error handler
     * @return this client, for chaining
     */
    public RiakAsyncClient setErrorHandler(DelegatingErrorHandler errorHandler) {
        final RiakAsyncClient self = this;
        self.errorHandler = errorHandler;
        return self;
    }

    /**
     * @return the delegating error handler that receives the errors this client catches
     */
    public DelegatingErrorHandler getErrorHandler() {
        return this.errorHandler;
    }

    /**
     * Closes the connection and, once the close has completed, stops the transport.
     * <p>
     * Errors are passed to the error handler rather than thrown. When no connection was ever
     * established the transport is stopped directly.
     */
    @SuppressWarnings({ "unchecked" })
    public void close() {
        if (null == connection) {
            // Connecting may have failed (start() only reports the error), so there is nothing
            // to close; the transport was still started and must be stopped.
            try {
                transport.stop();
            } catch (IOException e) {
                errorHandler.handleError(e);
            }
            return;
        }
        try {
            connection.close(new EmptyCompletionHandler<Connection>() {
                @Override
                public void failed(Throwable throwable) {
                    errorHandler.handleError(throwable);
                }

                @Override
                public void completed(Connection result) {
                    try {
                        transport.stop();
                    } catch (IOException e) {
                        failed(e);
                    }
                    if (log.isDebugEnabled())
                        log.debug(String.format("Disconnected from %s:%s", host, port));
                }
            });
        } catch (IOException e) {
            errorHandler.handleError(e);
        }
    }

    /**
     * Builds the filter chain, configures and starts the transport, and opens the initial
     * connection. Any failure is handed to the error handler instead of being thrown.
     */
    private void start() {
        // Chain order: raw transport -> protobuf encode/decode -> request/response pairing.
        FilterChainBuilder clientChainBuilder = FilterChainBuilder.stateless();
        clientChainBuilder.add(new TransportFilter());
        clientChainBuilder.add(new RpbFilter(heap));
        clientChainBuilder.add(new PendingRequestFilter());
        filterChain = clientChainBuilder.build();

        transport.setKeepAlive(true);
        transport.setTcpNoDelay(true);
        ThreadPoolConfig config = ThreadPoolConfig.defaultConfig().setPoolName("riak-async-client")
                .setCorePoolSize(PROCESSORS).setMaxPoolSize(PROCESSORS).setKeepAliveTime(timeout, TimeUnit.SECONDS);
        transport.setWorkerThreadPoolConfig(config);
        transport.setIOStrategy(WorkerThreadIOStrategy.getInstance());
        transport.setProcessor(filterChain);
        try {
            transport.start();
            connection = getConnection();
        } catch (IOException e) {
            errorHandler.handleError(e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            errorHandler.handleError(e);
        } catch (ExecutionException e) {
            errorHandler.handleError(e);
        } catch (TimeoutException e) {
            errorHandler.handleError(e);
        }
    }

    /**
     * Returns the current connection, opening a new one first when there is none or the
     * existing one is no longer open.
     * <p>
     * Opening a connection blocks for up to the configured timeout (seconds) on the connect
     * and again on setting a freshly generated client id.
     *
     * @return an open connection
     * @throws IOException          if there is no open connection and the retry limit has been
     *                              reached, or the transport reports an I/O failure
     * @throws ExecutionException   if connecting or setting the client id fails
     * @throws TimeoutException     if either wait exceeds the timeout
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private Connection getConnection()
            throws IOException, ExecutionException, TimeoutException, InterruptedException {
        if (null != connection && connection.isOpen()) {
            return connection;
        }
        if (retries.get() >= maxConnectionRetries) {
            // Returning the null or closed connection here would only make every caller fail
            // later with a far less helpful error.
            throw new IOException(String.format(
                    "Not connected to %s:%s and the retry limit (%s) has been reached", host, port,
                    maxConnectionRetries));
        }
        connection = transport
                .connect(new InetSocketAddress(host, port), new EmptyCompletionHandler<Connection>() {
                    @Override
                    public void failed(Throwable throwable) {
                        errorHandler.handleError(throwable);
                        retries.incrementAndGet();
                    }

                    @Override
                    public void completed(Connection result) {
                        if (retries.get() > 0)
                            retries.decrementAndGet();
                    }
                }).get(timeout, TimeUnit.SECONDS);
        if (log.isDebugEnabled())
            log.debug(String.format("Connected to %s:%s", host, port));
        // generateAndSetClientId() calls back into getConnection(); the field was assigned
        // above, so that call is expected to take the early-return path.
        generateAndSetClientId().get(timeout, TimeUnit.SECONDS);
        return connection;
    }

    /**
     * Fetches an object without specifying a read quorum.
     *
     * @param bucket bucket name
     * @param key    object key
     * @return promise of the fetch response
     */
    @Override
    public Promise<RiakResponse> fetch(String bucket, String key) throws IOException {
        // Any non-positive quorum makes the three-argument overload leave R unset.
        final int unsetQuorum = -1;
        return fetch(bucket, key, unsetQuorum);
    }

    /**
     * Sends a get request for the given bucket/key.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @param bucket     bucket name
     * @param key        object key
     * @param readQuorum R value; only sent when greater than zero
     * @return promise of the fetch response
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<RiakResponse> fetch(String bucket, String key, int readQuorum) throws IOException {
        RPB.RpbGetReq.Builder request = RPB.RpbGetReq.newBuilder();
        request.setBucket(copyFromUtf8(bucket));
        request.setKey(copyFromUtf8(key));
        if (readQuorum >= 1) {
            request.setR(readQuorum);
        }
        RpbMessage<RPB.RpbGetReq> envelope = new RpbMessage<RPB.RpbGetReq>(MSG_GetReq, request.build());
        final Promise<RiakResponse> result = new Promise<RiakResponse>();
        try {
            Connection conn = getConnection();
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a put request for the given object.
     * <p>
     * The vector clock is sent when the object has one. Content is only set when the object is
     * this project's {@code RiakObject}. Return-body, DW and W are taken from {@code storeMeta}
     * when present.
     * <p>
     * A failed write is retried while the shared retry counter is below the limit. Note that a
     * retry goes through a fresh call to this method whose promise is discarded, so the promise
     * returned here is not completed by a successful retry. Once the limit is reached the
     * failure is reported to the error handler.
     *
     * @param object    object to store
     * @param storeMeta optional store parameters; may be {@code null}
     * @return promise of the store response
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<RiakResponse> store(final IRiakObject object, final StoreMeta storeMeta) {
        RPB.RpbPutReq.Builder b = RPB.RpbPutReq.newBuilder().setBucket(copyFromUtf8(object.getBucket()))
                .setKey(copyFromUtf8(object.getKey()));
        if (null != object.getVClock())
            b.setVclock(copyFrom(object.getVClock().getBytes()));
        if (object instanceof RiakObject) {
            b.setContent(((RiakObject) object).build());
        }
        if (null != storeMeta) {
            if (null != storeMeta.getReturnBody())
                b.setReturnBody(storeMeta.getReturnBody());

            if (null != storeMeta.getDw())
                b.setDw(storeMeta.getDw());

            if (null != storeMeta.getW())
                b.setW(storeMeta.getW());
        }
        RpbMessage<RPB.RpbPutReq> msg = new RpbMessage<RPB.RpbPutReq>(MSG_PutReq, b.build());
        final Promise<RiakResponse> promise = new Promise<RiakResponse>();
        try {
            getConnection().write(new RpbRequest(msg, promise), new EmptyCompletionHandler() {
                @Override
                public void failed(Throwable throwable) {
                    // Decide on the value returned by the atomic increment; a separate get()
                    // could observe another thread's update.
                    if (retries.incrementAndGet() < maxConnectionRetries) {
                        if (log.isDebugEnabled())
                            log.debug("Retrying request: " + object);
                        store(object, storeMeta);
                    } else {
                        // Out of retries: report the failure instead of dropping it silently.
                        errorHandler.handleError(throwable);
                    }
                }
            });
        } catch (Exception e) {
            errorHandler.handleError(e);
        }
        return promise;
    }

    /**
     * Stores an object without any store parameters.
     *
     * @param object object to store
     * @return promise of the store response
     */
    @Override
    public Promise<RiakResponse> store(IRiakObject object) throws IOException {
        final StoreMeta noMeta = null;
        return store(object, noMeta);
    }

    /**
     * Deletes an object without specifying a delete quorum.
     *
     * @param bucket bucket name
     * @param key    object key
     * @return promise completed when the delete is acknowledged
     */
    @Override
    public Promise<Void> delete(String bucket, String key) throws IOException {
        // A negative quorum makes the three-argument overload leave RW unset.
        final int unsetQuorum = -1;
        return delete(bucket, key, unsetQuorum);
    }

    /**
     * Sends a delete request for the given bucket/key.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @param bucket       bucket name
     * @param key          object key
     * @param deleteQuorum RW value; only sent when zero or greater
     * @return promise completed when the delete is acknowledged
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<Void> delete(String bucket, String key, int deleteQuorum) {
        RPB.RpbDelReq.Builder request = RPB.RpbDelReq.newBuilder();
        request.setBucket(copyFromUtf8(bucket));
        request.setKey(copyFromUtf8(key));
        if (deleteQuorum >= 0) {
            request.setRw(deleteQuorum);
        }
        RpbMessage<RPB.RpbDelReq> envelope = new RpbMessage<RPB.RpbDelReq>(MSG_DelReq, request.build());
        final Promise<Void> result = new Promise<Void>();
        try {
            Connection conn = getConnection();
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a list-buckets request, which carries no message body.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @return promise of the bucket names
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<Set<String>> listBuckets() {
        final Promise<Set<String>> result = new Promise<Set<String>>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_ListBucketsReq, null);
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a get-bucket request for the named bucket.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @param bucketName bucket whose properties are requested
     * @return promise of the bucket properties
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<BucketProperties> fetchBucket(String bucketName) {
        RPB.RpbGetBucketReq.Builder request = RPB.RpbGetBucketReq.newBuilder();
        request.setBucket(copyFromUtf8(bucketName));
        final Promise<BucketProperties> result = new Promise<BucketProperties>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_GetBucketReq, request.build());
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a set-bucket request carrying the allow-siblings and n-val properties, each only
     * when it is non-null on {@code bucketProperties}.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @param name             bucket name
     * @param bucketProperties properties to apply
     * @return promise completed when the update is acknowledged
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<Void> updateBucket(String name, BucketProperties bucketProperties) throws IOException {
        final RPB.RpbBucketProps.Builder props = RPB.RpbBucketProps.newBuilder();
        if (bucketProperties.getAllowSiblings() != null) {
            props.setAllowMult(bucketProperties.getAllowSiblings());
        }
        if (bucketProperties.getNVal() != null) {
            props.setNVal(bucketProperties.getNVal());
        }

        RPB.RpbSetBucketReq.Builder request = RPB.RpbSetBucketReq.newBuilder();
        request.setBucket(copyFromUtf8(name));
        request.setProps(props);
        final Promise<Void> result = new Promise<Void>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_SetBucketReq, request.build());
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a list-keys request for the named bucket.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @param bucketName bucket whose keys are requested
     * @return promise of the keys
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<Iterable<String>> listKeys(String bucketName) throws IOException {
        RPB.RpbListKeysReq.Builder request = RPB.RpbListKeysReq.newBuilder();
        request.setBucket(copyFromUtf8(bucketName));
        final Promise<Iterable<String>> result = new Promise<Iterable<String>>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_ListKeysReq, request.build());
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Link walking is not implemented by this client.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Promise<WalkResult> linkWalk(LinkWalkSpec linkWalkSpec) throws IOException {
        final String reason = "Link walking not yet supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Sends a map/reduce request whose body is the given JSON, with content type
     * {@code application/json}.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @param json map/reduce job description
     * @return promise of the map/reduce result
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<MapReduceResult> mapReduce(String json) throws IOException, MapReduceTimeoutException {
        RPB.RpbMapRedReq.Builder request = RPB.RpbMapRedReq.newBuilder();
        request.setContentType(copyFromUtf8("application/json"));
        request.setRequest(copyFromUtf8(json));
        final Promise<MapReduceResult> result = new Promise<MapReduceResult>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_MapRedReq, request.build());
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Generates a client id from six random bytes, Base64-encoded, and sends it in a
     * set-client-id request.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @return promise of the client id bytes
     * @throws RuntimeException if the {@code SHA1PRNG} algorithm is unavailable
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<byte[]> generateAndSetClientId() throws IOException {
        final SecureRandom random;
        try {
            random = SecureRandom.getInstance("SHA1PRNG");
        } catch (NoSuchAlgorithmException missingAlgorithm) {
            throw new RuntimeException(missingAlgorithm);
        }
        final byte[] seed = new byte[6];
        random.nextBytes(seed);
        final byte[] encoded = Base64.encodeBase64Chunked(seed);
        final String clientId = CharsetUtils.asString(encoded, CharsetUtils.ISO_8859_1);

        RPB.RpbSetClientIdReq.Builder request = RPB.RpbSetClientIdReq.newBuilder();
        request.setClientId(copyFromUtf8(clientId));
        final Promise<byte[]> result = new Promise<byte[]>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_SetClientIdReq, request.build());
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a set-client-id request with the given id bytes.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     * <p>
     * NOTE(review): the response handling in this file completes set-client-id requests with
     * the client id bytes taken from the request rather than with a {@code Void} value, which
     * does not match this method's declared promise type.
     *
     * @param clientId client id to set
     * @return promise completed when the server acknowledges the id
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<Void> setClientId(byte[] clientId) throws IOException {
        RPB.RpbSetClientIdReq.Builder request = RPB.RpbSetClientIdReq.newBuilder();
        request.setClientId(copyFrom(clientId));
        final Promise<Void> result = new Promise<Void>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_SetClientIdReq, request.build());
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a get-client-id request, which carries no message body.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @return promise of the client id bytes
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<byte[]> getClientId() throws IOException {
        final Promise<byte[]> result = new Promise<byte[]>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_GetClientIdReq, null);
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Sends a get-server-info request, which carries no message body.
     * <p>
     * Failures while obtaining the connection or writing are passed to the error handler; the
     * returned promise is not completed in that case.
     *
     * @return promise of the server's node name and version
     */
    @SuppressWarnings({ "unchecked" })
    @Override
    public Promise<ServerInfo> getServerInfo() throws IOException {
        final Promise<ServerInfo> result = new Promise<ServerInfo>();
        try {
            Connection conn = getConnection();
            RpbMessage envelope = new RpbMessage(MSG_GetServerInfoReq, null);
            conn.write(new RpbRequest(envelope, result));
        } catch (Exception failure) {
            errorHandler.handleError(failure);
        }
        return result;
    }

    /**
     * Converts protobuf content entries into {@code RiakObject}s, in the same order.
     * <p>
     * Only vtag, content type, content encoding, last-modified (seconds and microseconds) and
     * the value are copied.
     *
     * @param contents content entries from a get or put response
     * @return one object per entry
     */
    private List<RiakObject> createRiakObjectsFromContent(List<RPB.RpbContent> contents) {
        final int count = contents.size();
        final List<RiakObject> converted = new ArrayList<RiakObject>(count);
        for (int i = 0; i < count; i++) {
            final RPB.RpbContent source = contents.get(i);
            final RiakObject target = new RiakObject();
            target.setVtag(source.getVtag());
            target.setContentType(source.getContentType());
            target.setContentEncoding(source.getContentEncoding());
            target.setLastModified(source.getLastMod());
            target.setLastModifiedUsec(source.getLastModUsecs());
            target.setValue(source.getValue());
            converted.add(target);
        }
        return converted;
    }

    /**
     * Last filter of the client chain. It records every outgoing request in the pending queue
     * and pairs each incoming message with the oldest queued request, completing or failing
     * that request's promise.
     * <p>
     * NOTE(review): one pending request is consumed per incoming message. Riak's protocol
     * documentation describes list-keys and map/reduce responses as streamed over several
     * messages with a {@code done} flag; if so, the follow-up messages would be paired with
     * unrelated requests. Confirm against the server behaviour before relying on those calls.
     */
    private class PendingRequestFilter extends BaseFilter {
        /** Drops whatever is still queued when a connection is established. */
        @Override
        public NextAction handleConnect(FilterChainContext ctx) throws IOException {
            pendingRequests.clear();
            return ctx.getInvokeAction();
        }

        /**
         * Translates the incoming message into the result type expected by the oldest pending
         * request and hands it to that request's promise. Error responses and unknown message
         * codes fail the promise. Messages arriving with no pending request are ignored.
         * <p>
         * A get-bucket response without properties and a ping response produce no result, so
         * the promise is left untouched in those cases.
         */
        @SuppressWarnings({ "unchecked" })
        @Override
        public NextAction handleRead(final FilterChainContext ctx) throws IOException {
            final RpbMessage<?> msg = ctx.getMessage();
            final RpbRequest pending = pendingRequests.poll();
            if (null == pending) {
                return ctx.getStopAction();
            }
            if (log.isDebugEnabled()) {
                // Log by concatenation: the message text must not be used as a format string,
                // since a '%' inside it would make String.format throw.
                log.debug("Incoming message: " + msg);
                log.debug("Pending request: " + pending);
            }

            Object response = null;
            switch (msg.getCode()) {
            case MSG_ErrorResp:
                RPB.RpbErrorResp errorResp = (RPB.RpbErrorResp) msg.getMessage();
                pending.getPromise().setFailure(new RiakException(errorResp.getErrmsg().toStringUtf8()));
                break;
            case MSG_DelResp:
                response = Void.INSTANCE;
                break;
            case MSG_GetBucketResp:
                RPB.RpbGetBucketResp bucketResp = (RPB.RpbGetBucketResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("GetBucket response: %s", bucketResp));
                if (bucketResp.hasProps()) {
                    RPB.RpbBucketProps bucketProps = bucketResp.getProps();
                    BucketProperties props = new BucketPropertiesBuilder().allowSiblings(bucketProps.getAllowMult())
                            .nVal(bucketProps.getNVal()).build();
                    response = props;
                }
                break;
            case MSG_GetClientIdResp:
                RPB.RpbGetClientIdResp clientIdResp = (RPB.RpbGetClientIdResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("GetClientId response: %s", clientIdResp));
                response = clientIdResp.getClientId().toByteArray();
                break;
            case MSG_GetResp:
                RPB.RpbGetResp getResp = ((RPB.RpbGetResp) msg.getMessage());
                if (log.isDebugEnabled())
                    log.debug(String.format("Get response: %s", getResp));
                List<RiakObject> robjs = createRiakObjectsFromContent(getResp.getContentList());
                response = new RiakResponse(getResp.getVclock().toByteArray(),
                        robjs.toArray(new RiakObject[robjs.size()]));
                break;
            case MSG_GetServerInfoResp:
                RPB.RpbGetServerInfoResp serverInfoResp = (RPB.RpbGetServerInfoResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("GetServerInfo response: %s", serverInfoResp));
                response = new ServerInfo(serverInfoResp.getNode().toStringUtf8(),
                        serverInfoResp.getServerVersion().toStringUtf8());
                break;
            case MSG_ListBucketsResp:
                RPB.RpbListBucketsResp listBucketsResp = (RPB.RpbListBucketsResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("ListBuckets response: %s", listBucketsResp));
                Set<String> buckets = new HashSet<String>();
                for (ByteString s : listBucketsResp.getBucketsList()) {
                    buckets.add(s.toStringUtf8());
                }
                response = buckets;
                break;
            case MSG_ListKeysResp:
                RPB.RpbListKeysResp listKeysResp = (RPB.RpbListKeysResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("ListKeys response: %s", listKeysResp));
                Set<String> keys = new HashSet<String>();
                for (ByteString s : listKeysResp.getKeysList()) {
                    keys.add(s.toStringUtf8());
                }
                response = keys;
                break;
            case MSG_MapRedResp:
                RPB.RpbMapRedResp mapRedResp = (RPB.RpbMapRedResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("MapRed response: %s", mapRedResp));
                String json = mapRedResp.getResponse().toStringUtf8();
                RiakAsyncMapReduceResult mapRedResult = new RiakAsyncMapReduceResult();
                mapRedResult.setResultRaw(json);
                response = mapRedResult;
                break;
            case MSG_PingResp:
                break;
            case MSG_PutResp:
                RPB.RpbPutResp putResp = (RPB.RpbPutResp) msg.getMessage();
                if (log.isDebugEnabled())
                    log.debug(String.format("Put response: %s", putResp));
                robjs = createRiakObjectsFromContent(putResp.getContentsList());
                response = new RiakResponse(putResp.getVclock().toByteArray(),
                        robjs.toArray(new RiakObject[robjs.size()]));
                break;
            case MSG_SetBucketResp:
                response = Void.INSTANCE;
                break;
            case MSG_SetClientIdResp:
                // The response has no body of interest; echo back the id that was requested.
                RPB.RpbSetClientIdReq clientIdReq = (RPB.RpbSetClientIdReq) pending.getMessage().getMessage();
                response = clientIdReq.getClientId().toByteArray();
                break;
            default:
                // The request has already been removed from the queue, so fail it rather than
                // leaving its promise uncompleted.
                pending.getPromise().setFailure(new RiakException("Unexpected response code: " + msg.getCode()));
                break;
            }
            if (null != response)
                pending.getPromise().setResult(response);

            if (log.isDebugEnabled()) {
                long elapsed = System.currentTimeMillis() - pending.getStart();
                String msgClazz = "null";
                if (null != pending.getMessage().getMessage()) {
                    msgClazz = pending.getMessage().getMessage().getClass().getSimpleName();
                }
                log.debug(String.format("Round-trip time for %s: %s", msgClazz, elapsed));
            }

            return ctx.getStopAction();
        }

        /**
         * Queues the outgoing request so its response can be matched later, then passes only
         * the wrapped message on to the next filter.
         */
        @Override
        public NextAction handleWrite(FilterChainContext ctx) throws IOException {
            RpbRequest pending = ctx.getMessage();
            pendingRequests.add(pending);
            ctx.setMessage(pending.getMessage());
            return ctx.getInvokeAction();
        }

        /** Forwards filter-chain errors to the client's error handler. */
        @Override
        public void exceptionOccurred(FilterChainContext ctx, Throwable error) {
            errorHandler.handleError(error);
        }
    }

    /**
     * An outgoing message together with the promise to complete when its response arrives
     * and the time (milliseconds) at which the request was created.
     *
     * @param <T> result type of the promise
     */
    private class RpbRequest<T> {
        private Long start = System.currentTimeMillis();
        private RpbMessage message;
        private Promise<T> promise;

        private RpbRequest(RpbMessage message, Promise<T> promise) {
            this.promise = promise;
            this.message = message;
        }

        /** Resets the start time to the current time. */
        public void setStart() {
            this.start = System.currentTimeMillis();
        }

        /** @return the start time in milliseconds */
        public Long getStart() {
            return this.start;
        }

        /** @return the message to send */
        public RpbMessage getMessage() {
            return this.message;
        }

        /** @return the promise tied to this request */
        public Promise<T> getPromise() {
            return this.promise;
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("RpbRequest{");
            text.append("start=").append(start);
            text.append(", message=").append(message);
            text.append(", promise=").append(promise);
            text.append("}");
            return text.toString();
        }
    }

}