io.airlift.drift.transport.netty.client.ThriftClientHandler.java Source code

Java tutorial

Introduction

Here is the source code for io.airlift.drift.transport.netty.client.ThriftClientHandler.java

Source

/*
 * Copyright (C) 2013 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.airlift.drift.transport.netty.client;

import com.google.common.util.concurrent.AbstractFuture;
import io.airlift.drift.TApplicationException;
import io.airlift.drift.TException;
import io.airlift.drift.codec.ThriftCodec;
import io.airlift.drift.codec.internal.ProtocolReader;
import io.airlift.drift.codec.internal.ProtocolWriter;
import io.airlift.drift.codec.metadata.ThriftType;
import io.airlift.drift.protocol.TMessage;
import io.airlift.drift.protocol.TProtocolReader;
import io.airlift.drift.protocol.TProtocolWriter;
import io.airlift.drift.protocol.TTransportException;
import io.airlift.drift.transport.MethodMetadata;
import io.airlift.drift.transport.ParameterMetadata;
import io.airlift.drift.transport.client.DriftApplicationException;
import io.airlift.drift.transport.client.MessageTooLargeException;
import io.airlift.drift.transport.client.RequestTimeoutException;
import io.airlift.drift.transport.netty.codec.FrameInfo;
import io.airlift.drift.transport.netty.codec.FrameTooLargeException;
import io.airlift.drift.transport.netty.codec.Protocol;
import io.airlift.drift.transport.netty.codec.ThriftFrame;
import io.airlift.drift.transport.netty.codec.Transport;
import io.airlift.drift.transport.netty.ssl.TChannelBufferInputTransport;
import io.airlift.drift.transport.netty.ssl.TChannelBufferOutputTransport;
import io.airlift.units.Duration;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.ScheduledFuture;

import javax.annotation.concurrent.ThreadSafe;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.drift.TApplicationException.Type.BAD_SEQUENCE_ID;
import static io.airlift.drift.TApplicationException.Type.INVALID_MESSAGE_TYPE;
import static io.airlift.drift.TApplicationException.Type.MISSING_RESULT;
import static io.airlift.drift.TApplicationException.Type.WRONG_METHOD_NAME;
import static io.airlift.drift.protocol.TMessageType.CALL;
import static io.airlift.drift.protocol.TMessageType.EXCEPTION;
import static io.airlift.drift.protocol.TMessageType.ONEWAY;
import static io.airlift.drift.protocol.TMessageType.REPLY;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

@ThreadSafe
public class ThriftClientHandler extends ChannelDuplexHandler {
    private static final int ONEWAY_SEQUENCE_ID = 0xFFFF_FFFF;

    private final Duration requestTimeout;
    private final Transport transport;
    private final Protocol protocol;

    private final ConcurrentHashMap<Integer, RequestHandler> pendingRequests = new ConcurrentHashMap<>();
    private final AtomicReference<TException> channelError = new AtomicReference<>();
    private final AtomicInteger sequenceId = new AtomicInteger(42);

    ThriftClientHandler(Duration requestTimeout, Transport transport, Protocol protocol) {
        this.requestTimeout = requireNonNull(requestTimeout, "requestTimeout is null");
        this.transport = requireNonNull(transport, "transport is null");
        this.protocol = requireNonNull(protocol, "protocol is null");
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) throws Exception {
        if (message instanceof ThriftRequest) {
            ThriftRequest thriftRequest = (ThriftRequest) message;
            sendMessage(ctx, thriftRequest, promise);
        } else {
            ctx.write(message, promise);
        }
    }

    private void sendMessage(ChannelHandlerContext context, ThriftRequest thriftRequest, ChannelPromise promise)
            throws Exception {
        // todo ONEWAY_SEQUENCE_ID is a header protocol thing... make sure this works with framed and unframed
        int sequenceId = thriftRequest.isOneway() ? ONEWAY_SEQUENCE_ID : this.sequenceId.incrementAndGet();
        RequestHandler requestHandler = new RequestHandler(thriftRequest, sequenceId);

        // register timeout
        requestHandler.registerRequestTimeout(context.executor());

        // write request
        ByteBuf requestBuffer = requestHandler.encodeRequest(context.alloc());

        // register request if we are expecting a response
        if (!thriftRequest.isOneway()) {
            if (pendingRequests.putIfAbsent(sequenceId, requestHandler) != null) {
                requestHandler.onChannelError(
                        new TTransportException("Another request with the same sequenceId is already in progress"));
                requestBuffer.release();
                return;
            }
        }

        // if this connection is failed, immediately fail the request
        TException channelError = this.channelError.get();
        if (channelError != null) {
            thriftRequest.failed(channelError);
            requestBuffer.release();
            return;
        }

        try {
            ThriftFrame thriftFrame = new ThriftFrame(sequenceId, requestBuffer, thriftRequest.getHeaders(),
                    transport, protocol, true);

            ChannelFuture sendFuture = context.write(thriftFrame, promise);
            sendFuture.addListener(future -> messageSent(context, sendFuture, requestHandler));
        } catch (Throwable t) {
            onError(context, t, Optional.of(requestHandler));
            requestBuffer.release();
        }
    }

    private void messageSent(ChannelHandlerContext context, ChannelFuture future, RequestHandler requestHandler) {
        try {
            if (!future.isSuccess()) {
                onError(context, new TTransportException("Sending request failed", future.cause()),
                        Optional.of(requestHandler));
                return;
            }

            requestHandler.onRequestSent();
        } catch (Throwable t) {
            onError(context, t, Optional.of(requestHandler));
        }
    }

    @Override
    public void channelRead(ChannelHandlerContext context, Object message) {
        if (message instanceof ThriftFrame) {
            messageReceived(context, (ThriftFrame) message);
            return;
        }
        context.fireChannelRead(message);
    }

    private void messageReceived(ChannelHandlerContext context, ThriftFrame thriftFrame) {
        RequestHandler requestHandler = null;
        try {
            requestHandler = pendingRequests.remove(thriftFrame.getSequenceId());
            if (requestHandler == null) {
                throw new TTransportException("Unknown sequence id in response: " + thriftFrame.getSequenceId());
            }

            requestHandler.onResponseReceived(thriftFrame.retain());
        } catch (Throwable t) {
            onError(context, t, Optional.ofNullable(requestHandler));
        } finally {
            thriftFrame.release();
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext context, Throwable cause) {
        onError(context, cause, Optional.empty());
    }

    @Override
    public void channelInactive(ChannelHandlerContext context) {
        onError(context, new TTransportException("Client was disconnected by server"), Optional.empty());
    }

    private void onError(ChannelHandlerContext context, Throwable throwable,
            Optional<RequestHandler> currentRequest) {
        if (throwable instanceof FrameTooLargeException) {
            checkArgument(!currentRequest.isPresent(),
                    "current request should not be set for FrameTooLargeException");
            onFrameTooLargeException(context, (FrameTooLargeException) throwable);
            return;
        }

        TException thriftException;
        if (throwable instanceof TException) {
            thriftException = (TException) throwable;
        } else {
            thriftException = new TTransportException(throwable);
        }

        // set channel error
        if (!channelError.compareAndSet(null, thriftException)) {
            // another thread is already tearing down this channel
            return;
        }

        // current request may have already been removed from pendingRequests, so notify it directly
        currentRequest.ifPresent(request -> {
            pendingRequests.remove(request.getSequenceId());
            request.onChannelError(thriftException);
        });

        // notify all pending requests of the error
        // Note while loop should not be necessary since this class should be single
        // threaded, but it is better to be safe in cleanup code
        while (!pendingRequests.isEmpty()) {
            pendingRequests.values().removeIf(request -> {
                request.onChannelError(thriftException);
                return true;
            });
        }

        context.close();
    }

    private void onFrameTooLargeException(ChannelHandlerContext context,
            FrameTooLargeException frameTooLargeException) {
        TException thriftException = new MessageTooLargeException(frameTooLargeException.getMessage(),
                frameTooLargeException);
        Optional<FrameInfo> frameInfo = frameTooLargeException.getFrameInfo();
        if (frameInfo.isPresent()) {
            RequestHandler request = pendingRequests.remove(frameInfo.get().getSequenceId());
            if (request != null) {
                request.onChannelError(thriftException);
                return;
            }
        }
        // if sequence id is missing - fail all requests on a give channel
        onError(context,
                new MessageTooLargeException("unexpected too large response happened on communication channel",
                        frameTooLargeException),
                Optional.empty());
    }

    public static class ThriftRequest extends AbstractFuture<Object> {
        private final MethodMetadata method;
        private final List<Object> parameters;
        private final Map<String, String> headers;

        public ThriftRequest(MethodMetadata method, List<Object> parameters, Map<String, String> headers) {
            this.method = method;
            this.parameters = parameters;
            this.headers = headers;
        }

        MethodMetadata getMethod() {
            return method;
        }

        List<Object> getParameters() {
            return parameters;
        }

        public Map<String, String> getHeaders() {
            return headers;
        }

        boolean isOneway() {
            return method.isOneway();
        }

        void setResponse(Object response) {
            set(response);
        }

        void failed(Throwable throwable) {
            setException(throwable);
        }
    }

    private final class RequestHandler {
        private final ThriftRequest thriftRequest;
        private final int sequenceId;

        private final AtomicBoolean finished = new AtomicBoolean();
        private final AtomicReference<ScheduledFuture<?>> timeout = new AtomicReference<>();

        public RequestHandler(ThriftRequest thriftRequest, int sequenceId) {
            this.thriftRequest = thriftRequest;
            this.sequenceId = sequenceId;
        }

        public int getSequenceId() {
            return sequenceId;
        }

        void registerRequestTimeout(EventExecutor executor) {
            try {
                timeout.set(executor.schedule(
                        () -> onChannelError(new RequestTimeoutException(
                                "Timed out waiting " + requestTimeout + " to receive response")),
                        requestTimeout.toMillis(), MILLISECONDS));
            } catch (Throwable throwable) {
                onChannelError(new TTransportException("Unable to schedule request timeout", throwable));
                throw throwable;
            }
        }

        ByteBuf encodeRequest(ByteBufAllocator allocator) throws Exception {
            TChannelBufferOutputTransport transport = new TChannelBufferOutputTransport(allocator);
            try {
                TProtocolWriter protocolWriter = protocol.createProtocol(transport);

                // Note that though setting message type to ONEWAY can be helpful when looking at packet
                // captures, some clients always send CALL and so servers are forced to rely on the "oneway"
                // attribute on thrift method in the interface definition, rather than checking the message
                // type.
                MethodMetadata method = thriftRequest.getMethod();
                protocolWriter.writeMessageBegin(
                        new TMessage(method.getName(), method.isOneway() ? ONEWAY : CALL, sequenceId));

                // write the parameters
                ProtocolWriter writer = new ProtocolWriter(protocolWriter);
                writer.writeStructBegin(method.getName() + "_args");
                List<Object> parameters = thriftRequest.getParameters();
                for (int i = 0; i < parameters.size(); i++) {
                    Object value = parameters.get(i);
                    ParameterMetadata parameter = method.getParameters().get(i);
                    writer.writeField(parameter.getName(), parameter.getFieldId(), parameter.getCodec(), value);
                }
                writer.writeStructEnd();

                protocolWriter.writeMessageEnd();
                return transport.getBuffer();
            } catch (Throwable throwable) {
                onChannelError(throwable);
                throw throwable;
            } finally {
                transport.release();
            }
        }

        void onRequestSent() {
            if (!thriftRequest.isOneway()) {
                return;
            }

            if (!finished.compareAndSet(false, true)) {
                return;
            }

            try {
                cancelRequestTimeout();
                thriftRequest.setResponse(null);
            } catch (Throwable throwable) {
                onChannelError(throwable);
            }
        }

        void onResponseReceived(ThriftFrame thriftFrame) {
            try {
                if (!finished.compareAndSet(false, true)) {
                    return;
                }

                cancelRequestTimeout();
                Object response = decodeResponse(thriftFrame.getMessage());
                thriftRequest.setResponse(response);
            } catch (Throwable throwable) {
                thriftRequest.failed(throwable);
            } finally {
                thriftFrame.release();
            }
        }

        Object decodeResponse(ByteBuf responseMessage) throws Exception {
            TChannelBufferInputTransport transport = new TChannelBufferInputTransport(responseMessage);
            try {
                TProtocolReader protocolReader = protocol.createProtocol(transport);
                MethodMetadata method = thriftRequest.getMethod();

                // validate response header
                TMessage message = protocolReader.readMessageBegin();
                if (message.getType() == EXCEPTION) {
                    TApplicationException exception = ExceptionReader.readTApplicationException(protocolReader);
                    protocolReader.readMessageEnd();
                    throw exception;
                }
                if (message.getType() != REPLY) {
                    throw new TApplicationException(INVALID_MESSAGE_TYPE,
                            format("Received invalid message type %s from server", message.getType()));
                }
                if (!message.getName().equals(method.getName())) {
                    throw new TApplicationException(WRONG_METHOD_NAME,
                            format("Wrong method name in reply: expected %s but received %s", method.getName(),
                                    message.getName()));
                }
                if (message.getSequenceId() != sequenceId) {
                    throw new TApplicationException(BAD_SEQUENCE_ID,
                            format("%s failed: out of sequence response", method.getName()));
                }

                // read response struct
                ProtocolReader reader = new ProtocolReader(protocolReader);
                reader.readStructBegin();

                Object results = null;
                Exception exception = null;
                while (reader.nextField()) {
                    if (reader.getFieldId() == 0) {
                        results = reader.readField(method.getResultCodec());
                    } else {
                        ThriftCodec<Object> exceptionCodec = method.getExceptionCodecs().get(reader.getFieldId());
                        if (exceptionCodec != null) {
                            exception = (Exception) reader.readField(exceptionCodec);
                        } else {
                            reader.skipFieldData();
                        }
                    }
                }
                reader.readStructEnd();
                protocolReader.readMessageEnd();

                if (exception != null) {
                    throw new DriftApplicationException(exception);
                }

                if (method.getResultCodec().getType() == ThriftType.VOID) {
                    return null;
                }

                if (results == null) {
                    throw new TApplicationException(MISSING_RESULT,
                            format("%s failed: unknown result", method.getName()));
                }
                return results;
            } finally {
                transport.release();
            }
        }

        void onChannelError(Throwable requestException) {
            if (!finished.compareAndSet(false, true)) {
                return;
            }

            try {
                cancelRequestTimeout();
            } finally {
                thriftRequest.failed(requestException);
            }
        }

        private void cancelRequestTimeout() {
            ScheduledFuture<?> timeout = this.timeout.get();
            if (timeout != null) {
                timeout.cancel(false);
            }
        }
    }
}