Java tutorial
/* * Copyright 2018 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.grpc.alts.internal; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import io.grpc.Attributes; import io.grpc.Channel; import io.grpc.Grpc; import io.grpc.InternalChannelz; import io.grpc.ManagedChannel; import io.grpc.SecurityLevel; import io.grpc.alts.internal.AltsProtocolNegotiator.LazyChannel; import io.grpc.alts.internal.AltsProtocolNegotiator.ServerAltsProtocolNegotiator; import io.grpc.alts.internal.TsiFrameProtector.Consumer; import io.grpc.alts.internal.TsiPeer.Property; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.NettyChannelBuilder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http2.DefaultHttp2Connection; import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionDecoder; import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2FrameReader; import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Settings; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Tests for {@link AltsProtocolNegotiator}. */ @RunWith(JUnit4.class) @SuppressWarnings("FutureReturnValueIgnored") public class AltsProtocolNegotiatorTest { private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler(); private final List<ReferenceCounted> references = new ArrayList<>(); private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>(); private EmbeddedChannel channel; private Throwable caughtException; private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList()); private AltsAuthContext mockedAltsContext = new AltsAuthContext(HandshakerResult.newBuilder() .setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()).build()); private final TsiHandshaker mockHandshaker = new DelegatingTsiHandshaker( FakeTsiHandshaker.newFakeHandshakerServer()) { @Override public TsiPeer extractPeer() { return mockedTsiPeer; } @Override public Object extractPeerObject() { return mockedAltsContext; } }; private final NettyTsiHandshaker serverHandshaker = new NettyTsiHandshaker(mockHandshaker); @Before public void setup() throws Exception { ChannelHandler uncaughtExceptionHandler = new ChannelDuplexHandler() { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { caughtException = cause; super.exceptionCaught(ctx, cause); ctx.close(); } }; TsiHandshakerFactory handshakerFactory = new DelegatingTsiHandshakerFactory( FakeTsiHandshaker.clientHandshakerFactory()) { @Override public TsiHandshaker newHandshaker(String authority) { return new DelegatingTsiHandshaker(super.newHandshaker(authority)) { @Override public TsiPeer extractPeer() throws GeneralSecurityException { return mockedTsiPeer; } @Override public Object extractPeerObject() throws GeneralSecurityException { return mockedAltsContext; } }; } }; ManagedChannel fakeChannel = NettyChannelBuilder.forTarget("localhost:8080").build(); ObjectPool<Channel> fakeChannelPool = new FixedObjectPool<Channel>(fakeChannel); LazyChannel lazyFakeChannel = new LazyChannel(fakeChannelPool); ChannelHandler altsServerHandler = new ServerAltsProtocolNegotiator(handshakerFactory, lazyFakeChannel) .newHandler(grpcHandler); // On real server, WBAEH fires default ProtocolNegotiationEvent. KickNH provides this behavior. ChannelHandler handler = new KickNegotiationHandler(altsServerHandler); channel = new EmbeddedChannel(uncaughtExceptionHandler, handler); } @After public void teardown() throws Exception { if (channel != null) { @SuppressWarnings("unused") // go/futurereturn-lsc Future<?> possiblyIgnoredError = channel.close(); } for (ReferenceCounted reference : references) { ReferenceCountUtil.safeRelease(reference); } } @Test public void handshakeShouldBeSuccessful() throws Exception { doHandshake(); } @Test @SuppressWarnings("unchecked") // List cast public void protectShouldRoundtrip() throws Exception { doHandshake(); // Write the message 1 character at a time. The message should be buffered // and not interfere with the handshake. final AtomicInteger writeCount = new AtomicInteger(); String message = "hello"; for (int ix = 0; ix < message.length(); ++ix) { ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8); @SuppressWarnings("unused") // go/futurereturn-lsc Future<?> possiblyIgnoredError = channel.write(in).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { writeCount.incrementAndGet(); } } }); } channel.flush(); // Capture the protected data written to the wire. assertEquals(1, channel.outboundMessages().size()); ByteBuf protectedData = channel.readOutbound(); assertEquals(message.length(), writeCount.get()); // Read the protected message at the server and verify it matches the original message. TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc()); List<ByteBuf> unprotected = new ArrayList<>(); serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc()); // We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it // a settings frame is written (and an HTTP2 preface). This is hard coded into Netty, so we // have to remove it here. See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}. int settingsFrameLength = 9; CompositeByteBuf unprotectedAll = new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected); ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length()); assertEquals(message, unprotectedData.toString(UTF_8)); // Protect the same message at the server. final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>(); serverProtector.protectFlush(Collections.singletonList(unprotectedData), new Consumer<ByteBuf>() { @Override public void accept(ByteBuf buf) { newlyProtectedData.set(buf); } }, channel.alloc()); // Read the protected message at the client and verify that it matches the original message. channel.writeInbound(newlyProtectedData.get()); assertEquals(1, channel.inboundMessages().size()); assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8)); } @Test public void unprotectLargeIncomingFrame() throws Exception { // We use a server frameprotector with twice the standard frame size. int serverFrameSize = 4096 * 2; // This should fit into one frame. byte[] unprotectedBytes = new byte[serverFrameSize - 500]; Arrays.fill(unprotectedBytes, (byte) 7); ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes); unprotectedData.writerIndex(unprotectedBytes.length); // Perform handshake. doHandshake(); // Protect the message on the server. TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc()); serverProtector.protectFlush(Collections.singletonList(unprotectedData), new Consumer<ByteBuf>() { @Override public void accept(ByteBuf buf) { channel.writeInbound(buf); } }, channel.alloc()); channel.flushInbound(); // Read the protected message at the client and verify that it matches the original message. assertEquals(1, channel.inboundMessages().size()); ByteBuf receivedData1 = channel.readInbound(); int receivedLen1 = receivedData1.readableBytes(); byte[] receivedBytes = new byte[receivedLen1]; receivedData1.readBytes(receivedBytes, 0, receivedLen1); assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length); assertThat(unprotectedBytes).isEqualTo(receivedBytes); } @Test public void flushShouldFailAllPromises() throws Exception { doHandshake(); channel.pipeline().addFirst(new ChannelDuplexHandler() { @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { throw new Exception("Fake exception"); } }); // Write the message 1 character at a time. String message = "hello"; final AtomicInteger failures = new AtomicInteger(); for (int ix = 0; ix < message.length(); ++ix) { ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8); @SuppressWarnings("unused") // go/futurereturn-lsc Future<?> possiblyIgnoredError = channel.write(in).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { failures.incrementAndGet(); } } }); } channel.flush(); // Verify that the promises fail. assertEquals(message.length(), failures.get()); } @Test public void doNotFlushEmptyBuffer() throws Exception { doHandshake(); assertEquals(1, protectors.size()); InterceptingProtector protector = protectors.poll(); String message = "hello"; ByteBuf in = Unpooled.copiedBuffer(message, UTF_8); assertEquals(0, protector.flushes.get()); Future<?> done = channel.write(in); channel.flush(); done.get(5, TimeUnit.SECONDS); assertEquals(1, protector.flushes.get()); done = channel.write(Unpooled.EMPTY_BUFFER); channel.flush(); done.get(5, TimeUnit.SECONDS); assertEquals(1, protector.flushes.get()); } @Test public void peerPropagated() throws Exception { doHandshake(); assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.TSI_PEER_KEY)).isEqualTo(mockedTsiPeer); assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY)).isEqualTo(mockedAltsContext); assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString()).isEqualTo("embedded"); assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR).toString()).isEqualTo("embedded"); assertThat(grpcHandler.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL)) .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY); } private void doHandshake() throws Exception { // Capture the client frame and add to the server. assertEquals(1, channel.outboundMessages().size()); ByteBuf clientFrame = channel.readOutbound(); assertTrue(serverHandshaker.processBytesFromPeer(clientFrame)); // Get the server response handshake frames. ByteBuf serverFrame = channel.alloc().buffer(); serverHandshaker.getBytesToSendToPeer(serverFrame); channel.writeInbound(serverFrame); // Capture the next client frame and add to the server. assertEquals(1, channel.outboundMessages().size()); clientFrame = channel.readOutbound(); assertTrue(serverHandshaker.processBytesFromPeer(clientFrame)); // Get the server response handshake frames. serverFrame = channel.alloc().buffer(); serverHandshaker.getBytesToSendToPeer(serverFrame); channel.writeInbound(serverFrame); // Ensure that both sides have confirmed that the handshake has completed. assertFalse(serverHandshaker.isInProgress()); if (caughtException != null) { throw new RuntimeException(caughtException); } assertNotNull(grpcHandler.attrs); } private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() { // Netty Boilerplate. We don't really need any of this, but there is a tight coupling // between an Http2ConnectionHandler and its dependencies. Http2Connection connection = new DefaultHttp2Connection(true); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); Http2FrameReader frameReader = new DefaultHttp2FrameReader(false); DefaultHttp2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); DefaultHttp2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings()); } private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { private Attributes attrs; private CapturingGrpcHttp2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings) { super(null, decoder, encoder, initialSettings); } @Override public void handleProtocolNegotiationCompleted(Attributes attrs, @SuppressWarnings("UnusedVariable") InternalChannelz.Security securityInfo) { // If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler channel.pipeline().remove(this); this.attrs = attrs; } } private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory { private TsiHandshakerFactory delegate; DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) { this.delegate = delegate; } @Override public TsiHandshaker newHandshaker(String authority) { return delegate.newHandshaker(authority); } } private class DelegatingTsiHandshaker implements TsiHandshaker { private final TsiHandshaker delegate; DelegatingTsiHandshaker(TsiHandshaker delegate) { this.delegate = delegate; } @Override public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException { delegate.getBytesToSendToPeer(bytes); } @Override public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException { return delegate.processBytesFromPeer(bytes); } @Override public boolean isInProgress() { return delegate.isInProgress(); } @Override public TsiPeer extractPeer() throws GeneralSecurityException { return delegate.extractPeer(); } @Override public Object extractPeerObject() throws GeneralSecurityException { return delegate.extractPeerObject(); } @Override public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) { InterceptingProtector protector = new InterceptingProtector(delegate.createFrameProtector(alloc)); protectors.add(protector); return protector; } @Override public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) { InterceptingProtector protector = new InterceptingProtector( delegate.createFrameProtector(maxFrameSize, alloc)); protectors.add(protector); return protector; } } private static class InterceptingProtector implements TsiFrameProtector { private final TsiFrameProtector delegate; final AtomicInteger flushes = new AtomicInteger(); InterceptingProtector(TsiFrameProtector delegate) { this.delegate = delegate; } @Override public void protectFlush(List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc) throws GeneralSecurityException { flushes.incrementAndGet(); delegate.protectFlush(unprotectedBufs, ctxWrite, alloc); } @Override public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc) throws GeneralSecurityException { delegate.unprotect(in, out, alloc); } @Override public void destroy() { delegate.destroy(); } } /** Kicks off negotiation of the server. */ private static final class KickNegotiationHandler extends ChannelInboundHandlerAdapter { private final ChannelHandler next; KickNegotiationHandler(ChannelHandler next) { this.next = checkNotNull(next, "next"); } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { super.handlerAdded(ctx); ctx.pipeline().replace(ctx.name(), /*newName= */ null, next); ctx.pipeline().fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); } } }