io.grpc.netty.ProtocolNegotiatorsTest.java Source code

Java tutorial

Introduction

Here is the source code for io.grpc.netty.ProtocolNegotiatorsTest.java

Source

/*
 * Copyright 2015 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.netty;

import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.InternalChannelz.Security;
import io.grpc.SecurityLevel;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator;
import io.grpc.netty.ProtocolNegotiators.HostPort;
import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpServerUpgradeHandler;
import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodec;
import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodecFactory;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2ServerUpgradeCodec;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.proxy.ProxyConnectException;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.DisableOnDebug;
import org.junit.rules.ExpectedException;
import org.junit.rules.TestRule;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class ProtocolNegotiatorsTest {
    private static final Runnable NOOP_RUNNABLE = new Runnable() {
        @Override
        public void run() {
        }
    };

    private static final int TIMEOUT_SECONDS = 60;
    @Rule
    public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS));
    @Rule
    public final ExpectedException thrown = ExpectedException.none();

    private final EventLoopGroup group = new DefaultEventLoop();
    private Channel chan;
    private Channel server;

    private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler();

    private EmbeddedChannel channel = new EmbeddedChannel();
    private ChannelPipeline pipeline = channel.pipeline();
    private SslContext sslContext;
    private SSLEngine engine;
    private ChannelHandlerContext channelHandlerCtx;

    @Before
    public void setUp() throws Exception {
        File serverCert = TestUtils.loadCert("server1.pem");
        File key = TestUtils.loadCert("server1.key");
        sslContext = GrpcSslContexts.forServer(serverCert, key)
                .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
        engine = SSLContext.getDefault().createSSLEngine();
        engine.setUseClientMode(true);
    }

    @After
    public void tearDown() {
        if (server != null) {
            server.close();
        }
        if (chan != null) {
            chan.close();
        }
        group.shutdownGracefully();
    }

    @Test
    public void waitUntilActiveHandler_handlerAdded() throws Exception {
        final CountDownLatch latch = new CountDownLatch(1);

        final WaitUntilActiveHandler handler = new WaitUntilActiveHandler(new ChannelHandlerAdapter() {
            @Override
            public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
                assertTrue(ctx.channel().isActive());
                latch.countDown();
                super.handlerAdded(ctx);
            }
        });

        ChannelHandler lateAddingHandler = new ChannelInboundHandlerAdapter() {
            @Override
            public void channelActive(ChannelHandlerContext ctx) throws Exception {
                ctx.pipeline().addLast(handler);
                ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
                // do not propagate channelActive().
            }
        };

        LocalAddress addr = new LocalAddress("local");
        ChannelFuture cf = new Bootstrap().channel(LocalChannel.class).handler(lateAddingHandler).group(group)
                .register();
        chan = cf.channel();
        ChannelFuture sf = new ServerBootstrap().channel(LocalServerChannel.class)
                .childHandler(new ChannelHandlerAdapter() {
                }).group(group).bind(addr);
        server = sf.channel();
        sf.sync();

        assertEquals(1, latch.getCount());

        chan.connect(addr).sync();
        assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
        assertNull(chan.pipeline().context(WaitUntilActiveHandler.class));
    }

    @Test
    public void waitUntilActiveHandler_channelActive() throws Exception {
        final CountDownLatch latch = new CountDownLatch(1);
        WaitUntilActiveHandler handler = new WaitUntilActiveHandler(new ChannelHandlerAdapter() {
            @Override
            public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
                assertTrue(ctx.channel().isActive());
                latch.countDown();
                super.handlerAdded(ctx);
            }
        });

        LocalAddress addr = new LocalAddress("local");
        ChannelFuture cf = new Bootstrap().channel(LocalChannel.class).handler(handler).group(group).register();
        chan = cf.channel();
        ChannelFuture sf = new ServerBootstrap().channel(LocalServerChannel.class)
                .childHandler(new ChannelHandlerAdapter() {
                }).group(group).bind(addr);
        server = sf.channel();
        sf.sync();

        assertEquals(1, latch.getCount());

        chan.connect(addr).sync();
        chan.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
        assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
        assertNull(chan.pipeline().context(WaitUntilActiveHandler.class));
    }

    @Test
    public void tlsHandler_failsOnNullEngine() throws Exception {
        thrown.expect(NullPointerException.class);
        thrown.expectMessage("ssl");

        Object unused = ProtocolNegotiators.serverTls(null);
    }

    @Test
    public void tlsHandler_handlerAddedAddsSslHandler() throws Exception {
        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);

        pipeline.addLast(handler);

        assertTrue(pipeline.first() instanceof SslHandler);
    }

    @Test
    public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception {
        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
        pipeline.addLast(handler);
        channelHandlerCtx = pipeline.context(handler);
        Object nonSslEvent = new Object();

        pipeline.fireUserEventTriggered(nonSslEvent);

        // A non ssl event should not cause the grpcHandler to be in the pipeline yet.
        ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
        assertNull(grpcHandlerCtx);
    }

    @Test
    public void tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
        SslHandler badSslHandler = new SslHandler(engine, false) {
            @Override
            public String applicationProtocol() {
                return "badprotocol";
            }
        };

        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
        pipeline.addLast(handler);

        final AtomicReference<Throwable> error = new AtomicReference<>();
        ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
            @Override
            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                error.set(cause);
            }
        };

        pipeline.addLast(errorCapture);

        pipeline.replace(SslHandler.class, null, badSslHandler);
        channelHandlerCtx = pipeline.context(handler);
        Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;

        pipeline.fireUserEventTriggered(sslEvent);

        // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
        assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
        ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
        assertNull(grpcHandlerCtx);
    }

    @Test
    public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception {
        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
        pipeline.addLast(handler);
        channelHandlerCtx = pipeline.context(handler);
        Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad"));

        final AtomicReference<Throwable> error = new AtomicReference<>();
        ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
            @Override
            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                error.set(cause);
            }
        };

        pipeline.addLast(errorCapture);

        pipeline.fireUserEventTriggered(sslEvent);

        // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
        assertThat(error.get()).hasMessageThat().contains("bad");
        ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
        assertNull(grpcHandlerCtx);
    }

    @Test
    public void tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception {
        SslHandler goodSslHandler = new SslHandler(engine, false) {
            @Override
            public String applicationProtocol() {
                return "h2";
            }
        };

        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
        pipeline.addLast(handler);

        pipeline.replace(SslHandler.class, null, goodSslHandler);
        channelHandlerCtx = pipeline.context(handler);
        Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;

        pipeline.fireUserEventTriggered(sslEvent);

        assertTrue(channel.isOpen());
        ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
        assertNotNull(grpcHandlerCtx);
    }

    @Test
    public void tlsHandler_userEventTriggeredSslEvent_supportedProtocolGrpcExp() throws Exception {
        SslHandler goodSslHandler = new SslHandler(engine, false) {
            @Override
            public String applicationProtocol() {
                return "grpc-exp";
            }
        };

        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
        pipeline.addLast(handler);

        pipeline.replace(SslHandler.class, null, goodSslHandler);
        channelHandlerCtx = pipeline.context(handler);
        Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;

        pipeline.fireUserEventTriggered(sslEvent);

        assertTrue(channel.isOpen());
        ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
        assertNotNull(grpcHandlerCtx);
    }

    @Test
    public void engineLog() {
        ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext);
        pipeline.addLast(handler);
        channelHandlerCtx = pipeline.context(handler);

        Logger logger = Logger.getLogger(ProtocolNegotiators.class.getName());
        Filter oldFilter = logger.getFilter();
        try {
            logger.setFilter(new Filter() {
                @Override
                public boolean isLoggable(LogRecord record) {
                    // We still want to the log method to be exercised, just not printed to stderr.
                    return false;
                }
            });

            ProtocolNegotiators.logSslEngineDetails(Level.INFO, channelHandlerCtx, "message", new Exception("bad"));
        } finally {
            logger.setFilter(oldFilter);
        }
    }

    @Test
    public void tls_failsOnNullSslContext() {
        thrown.expect(NullPointerException.class);

        Object unused = ProtocolNegotiators.tls(null);
    }

    @Test
    public void tls_hostAndPort() {
        HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234");

        assertEquals("authority", hostPort.host);
        assertEquals(1234, hostPort.port);
    }

    @Test
    public void tls_host() {
        HostPort hostPort = ProtocolNegotiators.parseAuthority("[::1]");

        assertEquals("[::1]", hostPort.host);
        assertEquals(-1, hostPort.port);
    }

    @Test
    public void tls_invalidHost() throws SSLException {
        HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_host:1234");

        // Even though it looks like a port, we treat it as part of the authority, since the host is
        // invalid.
        assertEquals("bad_host:1234", hostPort.host);
        assertEquals(-1, hostPort.port);
    }

    @Test
    public void httpProxy_nullAddressNpe() throws Exception {
        thrown.expect(NullPointerException.class);
        Object unused = ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext());
    }

    @Test
    public void httpProxy_nullNegotiatorNpe() throws Exception {
        thrown.expect(NullPointerException.class);
        Object unused = ProtocolNegotiators.httpProxy(InetSocketAddress.createUnresolved("localhost", 80), "user",
                "pass", null);
    }

    @Test
    public void httpProxy_nullUserPassNoException() throws Exception {
        assertNotNull(ProtocolNegotiators.httpProxy(InetSocketAddress.createUnresolved("localhost", 80), null, null,
                ProtocolNegotiators.plaintext()));
    }

    @Test
    public void httpProxy_completes() throws Exception {
        DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
        // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
        // the channel is already active.
        LocalAddress proxy = new LocalAddress("httpProxy_completes");
        SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);

        ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
        Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
                .childHandler(mockHandler).bind(proxy).sync().channel();

        ProtocolNegotiator nego = ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
        // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation,
        // mocking the behavior using KickStartHandler.
        ChannelHandler handler = new KickStartHandler(
                nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()));
        Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler).register().sync()
                .channel();
        pipeline = channel.pipeline();
        // Wait for initialization to complete
        channel.eventLoop().submit(NOOP_RUNNABLE).sync();
        channel.connect(host).sync();
        serverChannel.close();
        ArgumentCaptor<ChannelHandlerContext> contextCaptor = ArgumentCaptor.forClass(ChannelHandlerContext.class);
        Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
        ChannelHandlerContext serverContext = contextCaptor.getValue();

        final String golden = "isThisThingOn?";
        ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));

        // Wait for sending initial request to complete
        channel.eventLoop().submit(NOOP_RUNNABLE).sync();
        ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
        Mockito.verify(mockHandler).channelRead(ArgumentMatchers.<ChannelHandlerContext>any(),
                objectCaptor.capture());
        ByteBuf b = (ByteBuf) objectCaptor.getValue();
        String request = b.toString(UTF_8);
        b.release();
        assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n"));
        assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 "));
        assertTrue("No host header: " + request, request.contains("host: specialHost:314"));

        assertFalse(negotiationFuture.isDone());
        serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync();
        negotiationFuture.sync();

        channel.eventLoop().submit(NOOP_RUNNABLE).sync();
        objectCaptor = ArgumentCaptor.forClass(Object.class);
        Mockito.verify(mockHandler, times(2)).channelRead(ArgumentMatchers.<ChannelHandlerContext>any(),
                objectCaptor.capture());
        b = (ByteBuf) objectCaptor.getAllValues().get(1);
        // If we were using the real grpcHandler, this would have been the HTTP/2 preface
        String preface = b.toString(UTF_8);
        b.release();
        assertEquals(golden, preface);

        channel.close();
    }

    @Test
    public void httpProxy_500() throws Exception {
        DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
        // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
        // the channel is already active.
        LocalAddress proxy = new LocalAddress("httpProxy_500");
        SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);

        ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
        Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
                .childHandler(mockHandler).bind(proxy).sync().channel();

        ProtocolNegotiator nego = ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
        // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation,
        // mocking the behavior using KickStartHandler.
        ChannelHandler handler = new KickStartHandler(
                nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()));
        Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler).register().sync()
                .channel();
        pipeline = channel.pipeline();
        // Wait for initialization to complete
        channel.eventLoop().submit(NOOP_RUNNABLE).sync();
        channel.connect(host).sync();
        serverChannel.close();
        ArgumentCaptor<ChannelHandlerContext> contextCaptor = ArgumentCaptor.forClass(ChannelHandlerContext.class);
        Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
        ChannelHandlerContext serverContext = contextCaptor.getValue();

        final String golden = "isThisThingOn?";
        ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));

        // Wait for sending initial request to complete
        channel.eventLoop().submit(NOOP_RUNNABLE).sync();
        ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
        Mockito.verify(mockHandler).channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
        ByteBuf request = (ByteBuf) objectCaptor.getValue();
        request.release();

        assertFalse(negotiationFuture.isDone());
        String response = "HTTP/1.1 500 OMG\r\nContent-Length: 4\r\n\r\noops";
        serverContext.writeAndFlush(bb(response, serverContext.channel())).sync();
        thrown.expect(ProxyConnectException.class);
        try {
            negotiationFuture.sync();
        } finally {
            channel.close();
        }
    }

    @Test
    public void waitUntilActiveHandler_firesNegotiation() throws Exception {
        EventLoopGroup elg = new DefaultEventLoopGroup(1);
        SocketAddress addr = new LocalAddress("addr");
        final AtomicReference<Object> event = new AtomicReference<>();
        ChannelHandler next = new ChannelInboundHandlerAdapter() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                event.set(evt);
                ctx.close();
            }
        };
        Channel s = new ServerBootstrap().childHandler(new ChannelInboundHandlerAdapter()).group(elg)
                .channel(LocalServerChannel.class).bind(addr).sync().channel();
        Channel c = new Bootstrap().handler(new WaitUntilActiveHandler(next)).channel(LocalChannel.class)
                .group(group).connect(addr).sync().channel();
        c.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
        SocketAddress localAddr = c.localAddress();
        ProtocolNegotiationEvent expectedEvent = ProtocolNegotiationEvent.DEFAULT
                .withAttributes(Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddr)
                        .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, addr)
                        .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE).build());

        c.closeFuture().sync();
        assertThat(event.get()).isInstanceOf(ProtocolNegotiationEvent.class);
        ProtocolNegotiationEvent actual = (ProtocolNegotiationEvent) event.get();
        assertThat(actual).isEqualTo(expectedEvent);

        s.close();
        elg.shutdownGracefully();
    }

    @Test
    public void clientTlsHandler_firesNegotiation() throws Exception {
        SelfSignedCertificate cert = new SelfSignedCertificate("authority");
        SslContext clientSslContext = GrpcSslContexts
                .configure(SslContextBuilder.forClient().trustManager(cert.cert())).build();
        SslContext serverSslContext = GrpcSslContexts
                .configure(SslContextBuilder.forServer(cert.key(), cert.cert())).build();
        FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();

        ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext);
        WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh));

        SocketAddress addr = new LocalAddress("addr");

        ChannelHandler sh = ProtocolNegotiators.serverTls(serverSslContext)
                .newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler());
        WriteBufferingAndExceptionHandler serverWbaeh = new WriteBufferingAndExceptionHandler(sh);
        Channel s = new ServerBootstrap().childHandler(serverWbaeh).group(group).channel(LocalServerChannel.class)
                .bind(addr).sync().channel();
        Channel c = new Bootstrap().handler(clientWbaeh).channel(LocalChannel.class).group(group).register().sync()
                .channel();
        ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
        c.connect(addr).sync();
        write.sync();

        boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        if (!completed) {
            assertTrue("failed to negotiated", write.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
            // sync should fail if we are in this block.
            write.sync();
            throw new AssertionError("neither wrote nor negotiated");
        }
        c.close();
        s.close();

        assertThat(gh.securityInfo).isNotNull();
        assertThat(gh.securityInfo.tls).isNotNull();
        assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL)).isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
        assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_SSL_SESSION)).isInstanceOf(SSLSession.class);
        // This is not part of the ClientTls negotiation, but shows that the negotiation event happens
        // in the right order.
        assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr);
    }

    @Test
    public void plaintextUpgradeNegotiator() throws Exception {
        LocalAddress addr = new LocalAddress("plaintextUpgradeNegotiator");
        UpgradeCodecFactory ucf = new UpgradeCodecFactory() {

            @Override
            public UpgradeCodec newUpgradeCodec(CharSequence protocol) {
                return new Http2ServerUpgradeCodec(FakeGrpcHttp2ConnectionHandler.newHandler());
            }
        };
        final HttpServerCodec serverCodec = new HttpServerCodec();
        final HttpServerUpgradeHandler serverUpgradeHandler = new HttpServerUpgradeHandler(serverCodec, ucf);
        Channel serverChannel = new ServerBootstrap().group(group).channel(LocalServerChannel.class)
                .childHandler(new ChannelInitializer<Channel>() {

                    @Override
                    protected void initChannel(Channel ch) throws Exception {
                        ch.pipeline().addLast(serverCodec, serverUpgradeHandler);
                    }
                }).bind(addr).sync().channel();

        FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
        ProtocolNegotiator nego = ProtocolNegotiators.plaintextUpgrade();
        ChannelHandler ch = nego.newHandler(gh);
        WriteBufferingAndExceptionHandler wbaeh = new WriteBufferingAndExceptionHandler(ch);

        Channel channel = new Bootstrap().group(group).channel(LocalChannel.class).handler(wbaeh).register().sync()
                .channel();

        ChannelFuture write = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
        channel.connect(serverChannel.localAddress());

        boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS);
        if (!completed) {
            assertTrue("failed to negotiated", write.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
            // sync should fail if we are in this block.
            write.sync();
            throw new AssertionError("neither wrote nor negotiated");
        }

        channel.close().sync();
        serverChannel.close();

        assertThat(gh.securityInfo).isNull();
        assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL)).isEqualTo(SecurityLevel.NONE);
        assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr);
    }

    private static class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {

        static FakeGrpcHttp2ConnectionHandler noopHandler() {
            return newHandler(true);
        }

        static FakeGrpcHttp2ConnectionHandler newHandler() {
            return newHandler(false);
        }

        private static FakeGrpcHttp2ConnectionHandler newHandler(boolean noop) {
            DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
            DefaultHttp2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(conn,
                    new DefaultHttp2FrameWriter());
            DefaultHttp2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(conn, encoder,
                    new DefaultHttp2FrameReader());
            Http2Settings settings = new Http2Settings();
            return new FakeGrpcHttp2ConnectionHandler(/*channelUnused=*/ null, decoder, encoder, settings, noop);
        }

        private final boolean noop;
        private Attributes attrs;
        private Security securityInfo;
        private final CountDownLatch negotiated = new CountDownLatch(1);
        private ChannelHandlerContext ctx;

        FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused, Http2ConnectionDecoder decoder,
                Http2ConnectionEncoder encoder, Http2Settings initialSettings, boolean noop) {
            super(channelUnused, decoder, encoder, initialSettings);
            this.noop = noop;
        }

        @Override
        public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) {
            checkNotNull(ctx, "handleProtocolNegotiationCompleted cannot be called before handlerAdded");
            super.handleProtocolNegotiationCompleted(attrs, securityInfo);
            this.attrs = attrs;
            this.securityInfo = securityInfo;
            // Add a temp handler that verifies first message is a NOOP_MESSAGE
            ctx.pipeline().addBefore(ctx.name(), null, new ChannelOutboundHandlerAdapter() {
                @Override
                public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
                    checkState(msg == NettyClientHandler.NOOP_MESSAGE, "First message should be NOOP_MESSAGE");
                    promise.trySuccess();
                    ctx.pipeline().remove(this);
                }
            });
            NettyClientHandler.writeBufferingAndRemove(ctx.channel());
            negotiated.countDown();
        }

        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            if (noop) {
                ctx.pipeline().remove(ctx.name());
            } else {
                super.handlerAdded(ctx);
            }
            this.ctx = ctx;
        }

        @Override
        public String getAuthority() {
            return "authority";
        }
    }

    private static ByteBuf bb(String s, Channel c) {
        return ByteBufUtil.writeUtf8(c.alloc(), s);
    }

    private static final class KickStartHandler extends ChannelDuplexHandler {

        private final ChannelHandler next;

        public KickStartHandler(ChannelHandler next) {
            this.next = checkNotNull(next, "next");
        }

        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            ctx.pipeline().replace(ctx.name(), null, next);
            ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
        }
    }
}