com.linecorp.armeria.server.thrift.THttp2Client.java Source code

Java tutorial

Introduction

Here is the source code for com.linecorp.armeria.server.thrift.THttp2Client.java

Source

/*
 * Copyright 2015 LINE Corporation
 *
 * LINE Corporation licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.linecorp.armeria.server.thrift;

import static org.junit.Assert.assertTrue;

import java.net.URI;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLException;

import org.apache.thrift.transport.TMemoryBuffer;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;

import com.linecorp.armeria.internal.http.Http1ClientCodec;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientUpgradeHandler;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DelegatingDecompressorFrameListener;
import io.netty.handler.codec.http2.Http2ClientUpgradeCodec;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.HttpConversionUtil;
import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames;
import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandler;
import io.netty.handler.codec.http2.HttpToHttp2ConnectionHandlerBuilder;
import io.netty.handler.codec.http2.InboundHttp2ToHttpAdapterBuilder;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol;
import io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior;
import io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior;
import io.netty.handler.ssl.ApplicationProtocolNames;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.concurrent.Promise;

/**
 * An extremely simple Thrift-over-HTTP/2 client which sends and receives a single Thrift request/response
 * per connection.
 *
 * <p>Bytes written via {@link #write(byte[], int, int)} are buffered until {@link #flush()}, which opens a new
 * connection, establishes HTTP/2 (via ALPN for {@code https}, via an HTTP/1.1 cleartext upgrade for
 * {@code http}), POSTs the buffered bytes to the configured path and exposes the response body to the
 * {@code read*} methods. The write buffer is discarded after every flush, so the same transport can be used
 * for several consecutive calls.
 */
final class THttp2Client extends TTransport {

    /** Initial capacity of the outgoing request buffer. */
    private static final int OUT_BUFFER_INITIAL_SIZE = 128;

    /** How long {@link #flush()} waits for the HTTP/2 SETTINGS and, separately, for the response. */
    private static final long TIMEOUT_SECONDS = 5;

    private final EventLoopGroup group = new NioEventLoopGroup(1);
    /** {@code null} for {@code http}; an ALPN-enabled client context for {@code https}. */
    private final SslContext sslCtx;
    private final URI uri;
    private final String host;
    private final int port;
    private final String path;

    /** Body of the most recent response; {@code null} until the first successful {@link #flush()}. */
    private TMemoryInputTransport in;
    /** Request bytes written since the last {@link #flush()}. */
    private TMemoryBuffer out = new TMemoryBuffer(OUT_BUFFER_INITIAL_SIZE);

    /**
     * Creates a client for the given {@code http} or {@code https} URI.
     *
     * @param uriStr the endpoint URI, e.g. {@code http://127.0.0.1:8080/thrift}. The scheme's default port is
     *               used when none is given, and an empty path is sent as {@code "/"}.
     * @throws IllegalArgumentException if the scheme is missing or unsupported, or the host or path is missing
     * @throws TTransportException if the TLS context for an {@code https} URI cannot be created
     */
    THttp2Client(String uriStr) throws TTransportException {
        uri = URI.create(uriStr);

        final String scheme = uri.getScheme();
        if (scheme == null) {
            // switch on a null String would throw an uninformative NullPointerException.
            throw new IllegalArgumentException("scheme not specified: " + uriStr);
        }

        int resolvedPort = uri.getPort();
        switch (scheme) {
        case "http":
            if (resolvedPort < 0) {
                resolvedPort = 80;
            }
            sslCtx = null;
            break;
        case "https":
            if (resolvedPort < 0) {
                resolvedPort = 443;
            }
            sslCtx = newClientSslContext();
            break;
        default:
            throw new IllegalArgumentException("unknown scheme: " + scheme);
        }

        final String host = uri.getHost();
        if (host == null) {
            throw new IllegalArgumentException("host not specified: " + uriStr);
        }

        String path = uri.getPath();
        if (path == null) {
            throw new IllegalArgumentException("path not specified: " + uriStr);
        }
        if (path.isEmpty()) {
            // RFC 7230 section 5.3.1: an empty path must be sent as "/" in the request target.
            path = "/";
        }

        this.host = host;
        this.port = resolvedPort;
        this.path = path;
    }

    /**
     * Builds the client-side TLS context used for {@code https} URIs.
     *
     * @return an ALPN context that negotiates HTTP/2 and trusts any server certificate
     * @throws TTransportException if the context cannot be built
     */
    private static SslContext newClientSslContext() throws TTransportException {
        try {
            return SslContextBuilder.forClient()
                    .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE)
                    // Accepts every server certificate without verification.
                    .trustManager(InsecureTrustManagerFactory.INSTANCE)
                    .applicationProtocolConfig(new ApplicationProtocolConfig(
                            Protocol.ALPN,
                            // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers.
                            SelectorFailureBehavior.NO_ADVERTISE,
                            // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers.
                            SelectedListenerFailureBehavior.ACCEPT,
                            ApplicationProtocolNames.HTTP_2))
                    .build();
        } catch (SSLException e) {
            throw new TTransportException(TTransportException.UNKNOWN, e);
        }
    }

    /** Always {@code true}: connections are opened on demand by {@link #flush()}. */
    @Override
    public boolean isOpen() {
        return true;
    }

    /** No-op: connections are opened on demand by {@link #flush()}. */
    @Override
    public void open() {
    }

    /** Shuts down the event loop group; the returned shutdown future is not awaited. */
    @Override
    public void close() {
        group.shutdownGracefully();
    }

    @Override
    public int read(byte[] buf, int off, int len) throws TTransportException {
        return in.read(buf, off, len);
    }

    @Override
    public int readAll(byte[] buf, int off, int len) throws TTransportException {
        return in.readAll(buf, off, len);
    }

    @Override
    public byte[] getBuffer() {
        return in.getBuffer();
    }

    @Override
    public int getBufferPosition() {
        return in.getBufferPosition();
    }

    @Override
    public int getBytesRemainingInBuffer() {
        return in.getBytesRemainingInBuffer();
    }

    @Override
    public void consumeBuffer(int len) {
        in.consumeBuffer(len);
    }

    /** Appends request bytes to the buffer that the next {@link #flush()} sends. */
    @Override
    public void write(byte[] buf, int off, int len) {
        out.write(buf, off, len);
    }

    /**
     * Sends the buffered request over a new HTTP/2 connection and waits for the response.
     *
     * <p>The connection is closed and the request buffer discarded when this method returns, whether or not
     * it succeeded.
     *
     * @throws TTransportException if connecting, upgrading, sending or receiving fails
     */
    @Override
    public void flush() throws TTransportException {
        final THttp2ClientInitializer initHandler = new THttp2ClientInitializer();

        final Bootstrap b = new Bootstrap();
        b.group(group);
        b.channel(NioSocketChannel.class);
        b.handler(initHandler);

        Channel ch = null;
        try {
            ch = b.connect(host, port).syncUninterruptibly().channel();
            // initChannel() runs when the channel is registered, which precedes completion of connect().
            final THttp2ClientHandler handler = initHandler.clientHandler;

            // Wait until HTTP/2 upgrade is finished.
            assertTrue(handler.settingsPromise.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
            handler.settingsPromise.get();

            // Send a Thrift request.
            final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path,
                    Unpooled.wrappedBuffer(out.getArray(), 0, out.length()));
            request.headers().add(HttpHeaderNames.HOST, host);
            request.headers().set(ExtensionHeaderNames.SCHEME.text(), uri.getScheme());
            ch.writeAndFlush(request).sync();

            // Wait until the Thrift response is received.
            assertTrue(handler.responsePromise.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
            final ByteBuf response = handler.responsePromise.get();

            // Pass the received Thrift response to the Thrift client.
            try {
                final byte[] array = new byte[response.readableBytes()];
                response.readBytes(array);
                in = new TMemoryInputTransport(array);
            } finally {
                response.release();
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt status for the caller.
            Thread.currentThread().interrupt();
            throw new TTransportException(TTransportException.UNKNOWN, e);
        } catch (Exception e) {
            throw new TTransportException(TTransportException.UNKNOWN, e);
        } finally {
            // Start the next call from an empty buffer; otherwise the previous request's bytes would be resent.
            // A fresh buffer (rather than a reset) also leaves the array wrapped by the request untouched.
            out = new TMemoryBuffer(OUT_BUFFER_INITIAL_SIZE);
            if (ch != null) {
                ch.close();
            }
        }
    }

    /**
     * Configures a new channel for HTTP/2: directly over TLS when {@link #sslCtx} is set, otherwise through an
     * HTTP/1.1 cleartext upgrade triggered by {@link UpgradeRequestHandler}.
     */
    private final class THttp2ClientInitializer extends ChannelInitializer<SocketChannel> {

        /** Collects the SETTINGS and the response; assigned in {@link #initChannel(SocketChannel)}. */
        THttp2ClientHandler clientHandler;

        @Override
        public void initChannel(SocketChannel ch) throws Exception {
            final ChannelPipeline p = ch.pipeline();
            final Http2Connection conn = new DefaultHttp2Connection(false);
            final HttpToHttp2ConnectionHandler connHandler = new HttpToHttp2ConnectionHandlerBuilder()
                    .connection(conn)
                    .frameListener(
                            new DelegatingDecompressorFrameListener(conn, new InboundHttp2ToHttpAdapterBuilder(conn)
                                    .maxContentLength(Integer.MAX_VALUE).propagateSettings(true).build()))
                    .build();

            clientHandler = new THttp2ClientHandler(ch.eventLoop());

            if (sslCtx != null) {
                p.addLast(sslCtx.newHandler(p.channel().alloc()));
                p.addLast(connHandler);
                configureEndOfPipeline(p);
            } else {
                final Http1ClientCodec sourceCodec = new Http1ClientCodec();
                final HttpClientUpgradeHandler upgradeHandler = new HttpClientUpgradeHandler(sourceCodec,
                        new Http2ClientUpgradeCodec(connHandler), 65536);

                p.addLast(sourceCodec, upgradeHandler, new UpgradeRequestHandler());
            }
        }

        /** Appends the response-collecting handler to the end of the pipeline. */
        private void configureEndOfPipeline(ChannelPipeline p) {
            p.addLast(clientHandler);
        }

        /**
         * A handler that triggers the cleartext upgrade to HTTP/2 by sending an initial HTTP request.
         */
        private final class UpgradeRequestHandler extends ChannelInboundHandlerAdapter {
            @Override
            public void channelActive(ChannelHandlerContext ctx) throws Exception {
                final DefaultFullHttpRequest upgradeRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
                        HttpMethod.HEAD, "/");
                ctx.writeAndFlush(upgradeRequest);

                ctx.fireChannelActive();

                // Done with this handler, remove it from the pipeline.
                ctx.pipeline().remove(this);

                configureEndOfPipeline(ctx.pipeline());
            }
        }
    }

    /**
     * Completes {@link #settingsPromise} when the server's HTTP/2 SETTINGS arrive and {@link #responsePromise}
     * with the body of the response on stream 3. The response on stream 1 answers the cleartext upgrade request
     * and is ignored.
     */
    static final class THttp2ClientHandler extends SimpleChannelInboundHandler<Object> {

        final Promise<Void> settingsPromise;
        final Promise<ByteBuf> responsePromise;

        THttp2ClientHandler(EventLoop loop) {
            settingsPromise = loop.newPromise();
            responsePromise = loop.newPromise();
        }

        @Override
        protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof Http2Settings) {
                // A server may send more than one SETTINGS frame; only the first completes the promise.
                settingsPromise.trySuccess(null);
                return;
            }

            if (msg instanceof FullHttpResponse) {
                final FullHttpResponse res = (FullHttpResponse) msg;
                final Integer streamId =
                        res.headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text());
                if (streamId == null) {
                    responsePromise.tryFailure(new AssertionError("message without stream ID: " + msg));
                    return;
                }

                if (streamId == 1) {
                    // Response to the upgrade request, which is OK to ignore.
                    return;
                }

                if (streamId != 3) {
                    responsePromise.tryFailure(new AssertionError("unexpected stream ID: " + msg));
                    return;
                }

                // SimpleChannelInboundHandler releases msg after this method returns, so keep our own reference.
                final ByteBuf content = res.content().retain();
                if (!responsePromise.trySuccess(content)) {
                    // The promise was already failed (e.g. by exceptionCaught); don't leak the retained buffer.
                    content.release();
                }
                return;
            }

            throw new IllegalStateException("unexpected message type: " + msg.getClass().getName());
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            // Fail both promises so that flush() reports the cause immediately, even while it is still waiting
            // for the HTTP/2 SETTINGS.
            settingsPromise.tryFailure(cause);
            responsePromise.tryFailure(cause);
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            // If the connection goes away early, fail fast instead of letting flush() wait for its timeout.
            // Both calls are no-ops for promises that have already completed.
            final TTransportException cause = new TTransportException(
                    TTransportException.END_OF_FILE, "connection closed before a response was received");
            settingsPromise.tryFailure(cause);
            responsePromise.tryFailure(cause);
            super.channelInactive(ctx);
        }
    }
}