// Java tutorial
/* * Copyright (C) 2013 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.airlift.drift.transport.netty.server; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Primitives; import com.google.common.util.concurrent.FluentFuture; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.drift.TApplicationException; import io.airlift.drift.codec.ThriftCodec; import io.airlift.drift.codec.internal.ProtocolReader; import io.airlift.drift.codec.internal.ProtocolWriter; import io.airlift.drift.protocol.TMessage; import io.airlift.drift.protocol.TMessageType; import io.airlift.drift.protocol.TProtocolReader; import io.airlift.drift.protocol.TProtocolWriter; import io.airlift.drift.protocol.TTransport; import io.airlift.drift.transport.MethodMetadata; import io.airlift.drift.transport.ParameterMetadata; import io.airlift.drift.transport.netty.codec.FrameInfo; import io.airlift.drift.transport.netty.codec.FrameTooLargeException; import io.airlift.drift.transport.netty.codec.Protocol; import io.airlift.drift.transport.netty.codec.ThriftFrame; import io.airlift.drift.transport.netty.codec.Transport; import io.airlift.drift.transport.netty.ssl.TChannelBufferInputTransport; import io.airlift.drift.transport.netty.ssl.TChannelBufferOutputTransport; import io.airlift.drift.transport.server.ServerInvokeRequest; 
import io.airlift.drift.transport.server.ServerMethodInvoker; import io.airlift.log.Logger; import io.airlift.units.Duration; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import java.io.IOException; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.OptionalDouble; import java.util.OptionalInt; import java.util.OptionalLong; import java.util.concurrent.ScheduledExecutorService; import java.util.regex.Pattern; import static com.google.common.base.Defaults.defaultValue; import static com.google.common.base.Strings.nullToEmpty; import static com.google.common.util.concurrent.Futures.immediateFailedFuture; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.drift.TApplicationException.Type.INTERNAL_ERROR; import static io.airlift.drift.TApplicationException.Type.INVALID_MESSAGE_TYPE; import static io.airlift.drift.TApplicationException.Type.PROTOCOL_ERROR; import static io.airlift.drift.TApplicationException.Type.UNKNOWN_METHOD; import static io.airlift.drift.protocol.TMessageType.EXCEPTION; import static io.airlift.drift.protocol.TMessageType.REPLY; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.regex.Pattern.CASE_INSENSITIVE; public class ThriftServerHandler extends ChannelDuplexHandler { private static final Logger log = Logger.get(ThriftServerHandler.class); private static final Pattern CONNECTION_CLOSED_MESSAGE = Pattern .compile("^.*(?:connection.*(?:reset|closed|abort|broken)|broken.*pipe).*$", CASE_INSENSITIVE); private final ServerMethodInvoker methodInvoker; private final ScheduledExecutorService timeoutExecutor; private final Duration requestTimeout; public 
ThriftServerHandler(ServerMethodInvoker methodInvoker, Duration requestTimeout, ScheduledExecutorService timeoutExecutor) { this.methodInvoker = requireNonNull(methodInvoker, "methodInvoker is null"); this.requestTimeout = requireNonNull(requestTimeout, "requestTimeout is null"); this.timeoutExecutor = requireNonNull(timeoutExecutor, "timeoutExecutor is null"); } @Override public void channelRead(ChannelHandlerContext context, Object message) { if (message instanceof ThriftFrame) { messageReceived(context, (ThriftFrame) message); return; } context.fireChannelRead(message); } @Override public void exceptionCaught(ChannelHandlerContext context, Throwable cause) { // if possible, try to reply with an exception in case of a too large request if (cause instanceof FrameTooLargeException) { FrameTooLargeException e = (FrameTooLargeException) cause; // frame info may be missing in case of a large, but invalid request if (e.getFrameInfo().isPresent()) { FrameInfo frameInfo = e.getFrameInfo().get(); try { context.writeAndFlush(writeApplicationException(context, frameInfo.getMethodName(), frameInfo.getTransport(), frameInfo.getProtocol(), frameInfo.getSequenceId(), frameInfo.isSupportOutOfOrderResponse(), PROTOCOL_ERROR, e.getMessage(), e)); } catch (Throwable t) { context.close(); log.error(t, "Failed to write frame info"); } return; } } context.close(); // Don't log connection closed exceptions if (!isConnectionClosed(cause)) { log.error(cause); } } private void messageReceived(ChannelHandlerContext context, ThriftFrame frame) { TChannelBufferInputTransport inputTransport = new TChannelBufferInputTransport(frame.getMessage()); try { ListenableFuture<ThriftFrame> response = decodeMessage(context, inputTransport, frame.getTransport(), frame.getProtocol(), frame.getHeaders(), frame.isSupportOutOfOrderResponse()); Futures.addCallback(response, new FutureCallback<ThriftFrame>() { @Override public void onSuccess(ThriftFrame result) { context.writeAndFlush(result); } @Override 
public void onFailure(Throwable t) { context.disconnect(); } }, directExecutor()); } catch (Exception e) { log.error(e, "Exception processing request"); context.disconnect(); } catch (Throwable e) { log.error(e, "Error processing request"); context.disconnect(); throw e; } finally { inputTransport.release(); frame.release(); } } private ListenableFuture<ThriftFrame> decodeMessage(ChannelHandlerContext context, TTransport messageData, Transport transport, Protocol protocol, Map<String, String> headers, boolean supportOutOfOrderResponse) throws Exception { long start = System.nanoTime(); TProtocolReader protocolReader = protocol.createProtocol(messageData); TMessage message = protocolReader.readMessageBegin(); Optional<MethodMetadata> methodMetadata = methodInvoker.getMethodMetadata(message.getName()); if (!methodMetadata.isPresent()) { return immediateFuture(writeApplicationException(context, message.getName(), transport, protocol, message.getSequenceId(), supportOutOfOrderResponse, UNKNOWN_METHOD, "Invalid method name: '" + message.getName() + "'", null)); } MethodMetadata method = methodMetadata.get(); if (message.getType() != TMessageType.CALL && message.getType() != TMessageType.ONEWAY) { return immediateFuture(writeApplicationException(context, message.getName(), transport, protocol, message.getSequenceId(), supportOutOfOrderResponse, INVALID_MESSAGE_TYPE, "Invalid method message type: '" + message.getType() + "'", null)); } Map<Short, Object> parameters = readArguments(method, protocolReader); ListenableFuture<Object> result = methodInvoker .invoke(new ServerInvokeRequest(method, headers, parameters)); methodInvoker.recordResult(message.getName(), start, result); return FluentFuture.from(result).transformAsync(value -> { try { return immediateFuture(writeSuccessResponse(context, method, transport, protocol, message.getSequenceId(), supportOutOfOrderResponse, value)); } catch (Exception e) { return immediateFailedFuture(e); } }, 
directExecutor()).withTimeout(requestTimeout.toMillis(), MILLISECONDS, timeoutExecutor) .catchingAsync(Exception.class, exception -> { try { return immediateFuture(writeExceptionResponse(context, method, transport, protocol, message.getSequenceId(), supportOutOfOrderResponse, exception)); } catch (Exception e) { return immediateFailedFuture(e); } }, directExecutor()); } private static Map<Short, Object> readArguments(MethodMetadata method, TProtocolReader protocol) throws Exception { Map<Short, Object> arguments = new HashMap<>(method.getParameters().size()); ProtocolReader reader = new ProtocolReader(protocol); reader.readStructBegin(); while (reader.nextField()) { short fieldId = reader.getFieldId(); ParameterMetadata parameter = method.getParameterByFieldId(fieldId); if (parameter == null) { reader.skipFieldData(); } else { arguments.put(fieldId, reader.readField(parameter.getCodec())); } } reader.readStructEnd(); // set defaults for missing arguments for (ParameterMetadata parameter : method.getParameters()) { if (!arguments.containsKey(parameter.getFieldId())) { Type argumentType = parameter.getCodec().getType().getJavaType(); Object defaultValue = null; if (argumentType instanceof Class) { Class<?> argumentClass = (Class<?>) argumentType; if (argumentClass.isPrimitive()) { defaultValue = defaultValue(Primitives.unwrap(argumentClass)); } else if (argumentClass == OptionalInt.class) { defaultValue = OptionalInt.empty(); } else if (argumentClass == OptionalLong.class) { defaultValue = OptionalLong.empty(); } else if (argumentClass == OptionalDouble.class) { defaultValue = OptionalDouble.empty(); } } else if ((argumentType instanceof ParameterizedType) && (((ParameterizedType) argumentType).getRawType().equals(Optional.class))) { defaultValue = Optional.empty(); } arguments.put(parameter.getFieldId(), defaultValue); } } return arguments; } private static ThriftFrame writeSuccessResponse(ChannelHandlerContext context, MethodMetadata methodMetadata, Transport 
transport, Protocol protocol, int sequenceId, boolean supportOutOfOrderResponse, Object result) throws Exception { TChannelBufferOutputTransport outputTransport = new TChannelBufferOutputTransport(context.alloc()); try { writeResponse(methodMetadata.getName(), protocol.createProtocol(outputTransport), sequenceId, "success", (short) 0, methodMetadata.getResultCodec(), result); return new ThriftFrame(sequenceId, outputTransport.getBuffer(), ImmutableMap.of(), transport, protocol, supportOutOfOrderResponse); } finally { outputTransport.release(); } } private static ThriftFrame writeExceptionResponse(ChannelHandlerContext context, MethodMetadata methodMetadata, Transport transport, Protocol protocol, int sequenceId, boolean supportOutOfOrderResponse, Throwable exception) throws Exception { Optional<Short> exceptionId = methodMetadata.getExceptionId(exception.getClass()); if (exceptionId.isPresent()) { TChannelBufferOutputTransport outputTransport = new TChannelBufferOutputTransport(context.alloc()); try { TProtocolWriter protocolWriter = protocol.createProtocol(outputTransport); writeResponse(methodMetadata.getName(), protocolWriter, sequenceId, "exception", exceptionId.get(), methodMetadata.getExceptionCodecs().get(exceptionId.get()), exception); return new ThriftFrame(sequenceId, outputTransport.getBuffer(), ImmutableMap.of(), transport, protocol, supportOutOfOrderResponse); } finally { outputTransport.release(); } } TApplicationException.Type type = INTERNAL_ERROR; if (exception instanceof TApplicationException) { type = ((TApplicationException) exception).getType().orElse(INTERNAL_ERROR); } return writeApplicationException(context, methodMetadata.getName(), transport, protocol, sequenceId, supportOutOfOrderResponse, type, "Internal error processing " + methodMetadata.getName() + ": " + exception.getMessage(), exception); } private static ThriftFrame writeApplicationException(ChannelHandlerContext context, String methodName, Transport transport, Protocol protocol, 
int sequenceId, boolean supportOutOfOrderResponse, TApplicationException.Type errorCode, String errorMessage, Throwable cause) throws Exception { TApplicationException applicationException = new TApplicationException(errorCode, errorMessage); if (cause != null) { applicationException.initCause(cause); } TChannelBufferOutputTransport outputTransport = new TChannelBufferOutputTransport(context.alloc()); try { TProtocolWriter protocolWriter = protocol.createProtocol(outputTransport); protocolWriter.writeMessageBegin(new TMessage(methodName, EXCEPTION, sequenceId)); ExceptionWriter.writeTApplicationException(applicationException, protocolWriter); protocolWriter.writeMessageEnd(); return new ThriftFrame(sequenceId, outputTransport.getBuffer(), ImmutableMap.of(), transport, protocol, supportOutOfOrderResponse); } finally { outputTransport.release(); } } private static void writeResponse(String methodName, TProtocolWriter protocolWriter, int sequenceId, String responseFieldName, short responseFieldId, ThriftCodec<Object> responseCodec, Object result) throws Exception { protocolWriter.writeMessageBegin(new TMessage(methodName, REPLY, sequenceId)); ProtocolWriter writer = new ProtocolWriter(protocolWriter); writer.writeStructBegin(methodName + "_result"); writer.writeField(responseFieldName, responseFieldId, responseCodec, result); writer.writeStructEnd(); protocolWriter.writeMessageEnd(); } /* * There is no good way of detecting connection closed exception * * This implementation is a simplified version of the implementation proposed * in Netty: io.netty.handler.ssl.SslHandler#exceptionCaught * * This implementation ony checks a message with the regex, and doesn't do any * more sophisticated matching, as the regex works in most of the cases. */ private boolean isConnectionClosed(Throwable t) { if (t instanceof IOException) { return CONNECTION_CLOSED_MESSAGE.matcher(nullToEmpty(t.getMessage())).matches(); } return false; } }