io.grpc.alts.internal.AltsTsiTest.java Source code

Java tutorial

Introduction

Here is the source code for io.grpc.alts.internal.AltsTsiTest.java

Source

/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.alts.internal;

import static org.junit.Assert.assertEquals;

import com.google.common.testing.GcFinalization;
import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.internal.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Unit tests for {@link AltsTsiHandshaker}.
 *
 * <p>Every test builds a fresh client/server handshaker pair backed by
 * {@link MockAltsHandshakerStub} and hands it to one of the shared scenarios in {@link TsiTest}.
 */
@RunWith(JUnit4.class)
public class AltsTsiTest {
    /**
     * Sum of the fake crypter's tag size and the frame protector's header size; passed to the
     * small-frame scenarios.
     */
    private static final int OVERHEAD =
            FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();

    /** Buffers registered through {@link #ref}; released and cleared in {@link #tearDown()}. */
    private final List<ReferenceCounted> references = new ArrayList<>();
    /** Handshaker client used by the client-side {@link AltsTsiHandshaker}; rebuilt per test. */
    private AltsHandshakerClient client;
    /** Handshaker client used by the server-side {@link AltsTsiHandshaker}; rebuilt per test. */
    private AltsHandshakerClient server;
    /**
     * Callback handed to the {@link TsiTest} scenarios. It remembers every non-null buffer it is
     * given so that {@link #tearDown()} can release it, and returns the buffer unchanged.
     */
    private final RegisterRef ref = new RegisterRef() {
        @Override
        public ByteBuf register(ByteBuf buf) {
            if (buf != null) {
                references.add(buf);
            }
            return buf;
        }
    };

    /**
     * Switches Netty leak detection to {@link Level#PARANOID} and creates the client and server
     * {@link AltsHandshakerClient}s, each with its own {@link MockAltsHandshakerStub}.
     */
    @Before
    public void setUp() throws Exception {
        ResourceLeakDetector.setLevel(Level.PARANOID);
        // Use MockAltsHandshakerStub for all the tests.
        AltsHandshakerOptions handshakerOptions = new AltsHandshakerOptions(null);
        MockAltsHandshakerStub clientStub = new MockAltsHandshakerStub();
        MockAltsHandshakerStub serverStub = new MockAltsHandshakerStub();
        client = new AltsHandshakerClient(clientStub, handshakerOptions);
        server = new AltsHandshakerClient(serverStub, handshakerOptions);
    }

    /** Releases every buffer registered during the test, then forces a full GC. */
    @After
    public void tearDown() {
        for (ReferenceCounted reference : references) {
            reference.release();
        }
        references.clear();
        // Increase our chances to detect ByteBuf leaks.
        GcFinalization.awaitFullGc();
    }

    /** Creates a new client/server {@link AltsTsiHandshaker} pair on top of the mock stubs. */
    private Handshakers newHandshakers() {
        TsiHandshaker clientHandshaker = new AltsTsiHandshaker(true, client);
        TsiHandshaker serverHandshaker = new AltsTsiHandshaker(false, server);
        return new Handshakers(clientHandshaker, serverHandshaker);
    }

    /**
     * Performs a handshake and checks that both sides expose exactly one peer property,
     * {@code service_account}, equal to the mock's test peer account.
     */
    @Test
    public void verifyHandshakePeer() throws Exception {
        Handshakers handshakers = newHandshakers();
        TsiTest.performHandshake(TsiTest.getDefaultTransportBufferSize(), handshakers);
        TsiPeer clientPeer = handshakers.getClient().extractPeer();
        assertEquals(1, clientPeer.getProperties().size());
        assertEquals(MockAltsHandshakerResp.getTestPeerAccount(),
                clientPeer.getProperty("service_account").getValue());
        TsiPeer serverPeer = handshakers.getServer().extractPeer();
        assertEquals(1, serverPeer.getProperties().size());
        assertEquals(MockAltsHandshakerResp.getTestPeerAccount(),
                serverPeer.getProperty("service_account").getValue());
    }

    /** Runs {@link TsiTest#handshakeTest} against a fresh handshaker pair. */
    @Test
    public void handshake() throws GeneralSecurityException {
        TsiTest.handshakeTest(newHandshakers());
    }

    /** Runs {@link TsiTest#handshakeSmallBufferTest} against a fresh handshaker pair. */
    @Test
    public void handshakeSmallBuffer() throws GeneralSecurityException {
        TsiTest.handshakeSmallBufferTest(newHandshakers());
    }

    /** Runs {@link TsiTest#pingPongTest} against a fresh handshaker pair. */
    @Test
    public void pingPong() throws GeneralSecurityException {
        TsiTest.pingPongTest(newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#pingPongExactFrameSizeTest} against a fresh handshaker pair. */
    @Test
    public void pingPongExactFrameSize() throws GeneralSecurityException {
        TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#pingPongSmallBufferTest} against a fresh handshaker pair. */
    @Test
    public void pingPongSmallBuffer() throws GeneralSecurityException {
        TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#pingPongSmallFrameTest} with {@link #OVERHEAD}. */
    @Test
    public void pingPongSmallFrame() throws GeneralSecurityException {
        TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#pingPongSmallFrameSmallBufferTest} with {@link #OVERHEAD}. */
    @Test
    public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
        TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#corruptedCounterTest} against a fresh handshaker pair. */
    @Test
    public void corruptedCounter() throws GeneralSecurityException {
        TsiTest.corruptedCounterTest(newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#corruptedCiphertextTest} against a fresh handshaker pair. */
    @Test
    public void corruptedCiphertext() throws GeneralSecurityException {
        TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#corruptedTagTest} against a fresh handshaker pair. */
    @Test
    public void corruptedTag() throws GeneralSecurityException {
        TsiTest.corruptedTagTest(newHandshakers(), ref);
    }

    /** Runs {@link TsiTest#reflectedCiphertextTest} against a fresh handshaker pair. */
    @Test
    public void reflectedCiphertext() throws GeneralSecurityException {
        TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
    }

    /**
     * In-memory {@link AltsHandshakerStub} that answers requests without a handshaker service.
     *
     * <p>The stub is stateful: the first accepted request must be a client-start or server-start
     * request that lists at least one record protocol; every request after that must be a
     * "next" request. Anything else is answered with
     * {@link MockAltsHandshakerResp#getErrorResponse()}.
     */
    private static class MockAltsHandshakerStub extends AltsHandshakerStub {
        /** {@code getReqOneofCase().getNumber()} value handled as a client-start request. */
        private static final int CLIENT_START_CASE = 1;
        /** {@code getReqOneofCase().getNumber()} value handled as a server-start request. */
        private static final int SERVER_START_CASE = 2;
        /** {@code getReqOneofCase().getNumber()} value handled as a "next" request. */
        private static final int NEXT_CASE = 3;

        /** Set once a valid start request has been answered. */
        private boolean started = false;

        /**
         * Produces the canned response for {@code req} based on whether a start request has
         * already been accepted.
         *
         * @param req the handshaker request to answer
         * @return an OK response for a valid start request, a finished response for a "next"
         *     request received after the start, and an error response otherwise
         */
        @Override
        public HandshakerResp send(HandshakerReq req) {
            int reqCase = req.getReqOneofCase().getNumber();
            if (started) {
                // Expect handshake next message.
                if (reqCase != NEXT_CASE) {
                    return MockAltsHandshakerResp.getErrorResponse();
                }
                return MockAltsHandshakerResp.getFinishedResponse(req.getNext().getInBytes().size());
            }

            List<String> recordProtocols;
            int bytesConsumed = 0;
            switch (reqCase) {
            case CLIENT_START_CASE:
                recordProtocols = req.getClientStart().getRecordProtocolsList();
                break;
            case SERVER_START_CASE: {
                int altsKey = HandshakeProtocol.ALTS.getNumber();
                // A server-start request without ALTS parameters is malformed; answer it like any
                // other bad request instead of failing with a NullPointerException.
                if (!req.getServerStart().getHandshakeParametersMap().containsKey(altsKey)) {
                    return MockAltsHandshakerResp.getErrorResponse();
                }
                recordProtocols = req.getServerStart().getHandshakeParametersMap()
                        .get(altsKey).getRecordProtocolsList();
                bytesConsumed = req.getServerStart().getInBytes().size();
                break;
            }
            default:
                return MockAltsHandshakerResp.getErrorResponse();
            }
            if (recordProtocols.isEmpty()) {
                return MockAltsHandshakerResp.getErrorResponse();
            }
            started = true;
            return MockAltsHandshakerResp.getOkResponse(bytesConsumed);
        }

        /** No-op: the mock holds no resources. */
        @Override
        public void close() {
        }
    }
}