org.apache.flink.runtime.query.netty.KvStateClientTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.runtime.query.netty.KvStateClientTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.query.netty;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.message.KvStateRequest;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.util.NetUtils;
import org.junit.AfterClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.Deadline;
import scala.concurrent.duration.FiniteDuration;

import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests for {@link KvStateClient}.
 *
 * <p>Most tests bind a bare Netty server channel (see {@link #createServerChannel(ChannelHandler...)})
 * whose handlers inspect the raw request frames and write hand-crafted responses, so the client can
 * be exercised without a real server. {@link #testClientServerIntegration()} runs the client against
 * real {@link KvStateServer} instances.
 */
public class KvStateClientTest {

    private static final Logger LOG = LoggerFactory.getLogger(KvStateClientTest.class);

    // Event loop group for the test server channels created by createServerChannel (shared between tests)
    private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup();

    // Upper bound for the deadline-bounded waits in the tests below
    private static final FiniteDuration TEST_TIMEOUT = new FiniteDuration(100, TimeUnit.SECONDS);

    /**
     * Shuts down the event loop group shared by the test server channels.
     */
    @AfterClass
    public static void tearDown() throws Exception {
        if (NIO_GROUP != null) {
            NIO_GROUP.shutdownGracefully();
        }
    }

    /**
     * Tests simple queries, of which half succeed and half fail.
     */
    @Test
    public void testSimpleRequests() throws Exception {
        Deadline deadline = TEST_TIMEOUT.fromNow();
        AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

        KvStateClient client = null;
        Channel serverChannel = null;

        try {
            client = new KvStateClient(1, stats);

            // Random result
            final byte[] expected = new byte[1024];
            ThreadLocalRandom.current().nextBytes(expected);

            final LinkedBlockingQueue<ByteBuf> received = new LinkedBlockingQueue<>();
            final AtomicReference<Channel> channel = new AtomicReference<>();

            // The server only records the connection and the incoming frames;
            // responses are written by the test thread below.
            serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
                @Override
                public void channelActive(ChannelHandlerContext ctx) throws Exception {
                    channel.set(ctx.channel());
                }

                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                    received.add((ByteBuf) msg);
                }
            });

            KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

            List<Future<byte[]>> futures = new ArrayList<>();

            int numQueries = 1024;

            for (int i = 0; i < numQueries; i++) {
                futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
            }

            // Respond to messages
            Exception testException = new RuntimeException("Expected test Exception");

            for (int i = 0; i < numQueries; i++) {
                ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
                assertNotNull("Receive timed out", buf);

                Channel ch = channel.get();
                assertNotNull("Channel not active", ch);

                assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf));
                KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf);

                buf.release();

                // Even-numbered requests succeed, odd-numbered ones fail
                if (i % 2 == 0) {
                    ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult(serverChannel.alloc(),
                            request.getRequestId(), expected);

                    ch.writeAndFlush(response);
                } else {
                    ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestFailure(
                            serverChannel.alloc(), request.getRequestId(), testException);

                    ch.writeAndFlush(response);
                }
            }

            for (int i = 0; i < numQueries; i++) {
                if (i % 2 == 0) {
                    byte[] serializedResult = Await.result(futures.get(i), deadline.timeLeft());
                    assertArrayEquals(expected, serializedResult);
                } else {
                    try {
                        Await.result(futures.get(i), deadline.timeLeft());
                        fail("Did not throw expected Exception");
                    } catch (RuntimeException ignored) {
                        // Expected
                    }
                }
            }

            assertEquals(numQueries, stats.getNumRequests());
            int expectedRequests = numQueries / 2;

            // Counts can take some time to propagate
            while (deadline.hasTimeLeft()
                    && (stats.getNumSuccessful() != expectedRequests || stats.getNumFailed() != expectedRequests)) {
                Thread.sleep(100);
            }

            assertEquals(expectedRequests, stats.getNumSuccessful());
            assertEquals(expectedRequests, stats.getNumFailed());
        } finally {
            if (client != null) {
                client.shutDown();
            }

            if (serverChannel != null) {
                serverChannel.close();
            }

            assertEquals("Channel leak", 0, stats.getNumConnections());
        }
    }

    /**
     * Tests that a request to an unavailable host is failed with ConnectException.
     */
    @Test
    public void testRequestUnavailableHost() throws Exception {
        Deadline deadline = TEST_TIMEOUT.fromNow();
        AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
        KvStateClient client = null;

        try {
            client = new KvStateClient(1, stats);

            // No server is bound to this port by the test, so connecting is expected to fail
            int availablePort = NetUtils.getAvailablePort();

            KvStateServerAddress serverAddress = new KvStateServerAddress(InetAddress.getLocalHost(),
                    availablePort);

            Future<byte[]> future = client.getKvState(serverAddress, new KvStateID(), new byte[0]);

            try {
                Await.result(future, deadline.timeLeft());
                fail("Did not throw expected ConnectException");
            } catch (ConnectException ignored) {
                // Expected
            }
        } finally {
            if (client != null) {
                client.shutDown();
            }

            assertEquals("Channel leak", 0, stats.getNumConnections());
        }
    }

    /**
     * Multiple threads concurrently fire queries.
     */
    @Test
    public void testConcurrentQueries() throws Exception {
        Deadline deadline = TEST_TIMEOUT.fromNow();
        AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

        ExecutorService executor = null;
        KvStateClient client = null;
        Channel serverChannel = null;

        final byte[] serializedResult = new byte[1024];
        ThreadLocalRandom.current().nextBytes(serializedResult);

        try {
            int numQueryTasks = 4;
            final int numQueriesPerTask = 1024;

            executor = Executors.newFixedThreadPool(numQueryTasks);

            client = new KvStateClient(1, stats);

            // The server answers every request directly from the event loop with the same result
            serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                    ByteBuf buf = (ByteBuf) msg;
                    assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf));
                    KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf);

                    buf.release();

                    ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult(ctx.alloc(),
                            request.getRequestId(), serializedResult);

                    ctx.channel().writeAndFlush(response);
                }
            });

            final KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

            final KvStateClient finalClient = client;
            Callable<List<Future<byte[]>>> queryTask = new Callable<List<Future<byte[]>>>() {
                @Override
                public List<Future<byte[]>> call() throws Exception {
                    List<Future<byte[]>> results = new ArrayList<>(numQueriesPerTask);

                    for (int i = 0; i < numQueriesPerTask; i++) {
                        results.add(finalClient.getKvState(serverAddress, new KvStateID(), new byte[0]));
                    }

                    return results;
                }
            };

            // Submit query tasks
            List<java.util.concurrent.Future<List<Future<byte[]>>>> futures = new ArrayList<>();
            for (int i = 0; i < numQueryTasks; i++) {
                futures.add(executor.submit(queryTask));
            }

            // Verify results
            for (java.util.concurrent.Future<List<Future<byte[]>>> future : futures) {
                List<Future<byte[]>> results = future.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
                for (Future<byte[]> result : results) {
                    byte[] actual = Await.result(result, deadline.timeLeft());
                    assertArrayEquals(serializedResult, actual);
                }
            }

            int totalQueries = numQueryTasks * numQueriesPerTask;

            // Counts can take some time to propagate
            while (deadline.hasTimeLeft() && stats.getNumSuccessful() != totalQueries) {
                Thread.sleep(100);
            }

            assertEquals(totalQueries, stats.getNumRequests());
            assertEquals(totalQueries, stats.getNumSuccessful());
        } finally {
            if (executor != null) {
                executor.shutdown();
            }

            if (serverChannel != null) {
                serverChannel.close();
            }

            if (client != null) {
                client.shutDown();
            }

            assertEquals("Channel leak", 0, stats.getNumConnections());
        }
    }

    /**
     * Tests that a server failure closes the connection and removes it from
     * the established connections.
     */
    @Test
    public void testFailureClosesChannel() throws Exception {
        Deadline deadline = TEST_TIMEOUT.fromNow();
        AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

        KvStateClient client = null;
        Channel serverChannel = null;

        try {
            client = new KvStateClient(1, stats);

            final LinkedBlockingQueue<ByteBuf> received = new LinkedBlockingQueue<>();
            final AtomicReference<Channel> channel = new AtomicReference<>();

            serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
                @Override
                public void channelActive(ChannelHandlerContext ctx) throws Exception {
                    channel.set(ctx.channel());
                }

                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                    received.add((ByteBuf) msg);
                }
            });

            KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

            // Requests
            List<Future<byte[]>> futures = new ArrayList<>();
            futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
            futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));

            ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
            assertNotNull("Receive timed out", buf);
            buf.release();

            buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
            assertNotNull("Receive timed out", buf);
            buf.release();

            assertEquals(1, stats.getNumConnections());

            Channel ch = channel.get();
            assertNotNull("Channel not active", ch);

            // Respond with failure; a single server failure is expected to fail both pending requests
            ch.writeAndFlush(KvStateRequestSerializer.serializeServerFailure(serverChannel.alloc(),
                    new RuntimeException("Expected test server failure")));

            try {
                Await.result(futures.remove(0), deadline.timeLeft());
                fail("Did not throw expected server failure");
            } catch (RuntimeException ignored) {
                // Expected
            }

            try {
                Await.result(futures.remove(0), deadline.timeLeft());
                fail("Did not throw expected server failure");
            } catch (RuntimeException ignored) {
                // Expected
            }

            assertEquals(0, stats.getNumConnections());

            // Counts can take some time to propagate
            while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 || stats.getNumFailed() != 2)) {
                Thread.sleep(100);
            }

            assertEquals(2, stats.getNumRequests());
            assertEquals(0, stats.getNumSuccessful());
            assertEquals(2, stats.getNumFailed());
        } finally {
            if (client != null) {
                client.shutDown();
            }

            if (serverChannel != null) {
                serverChannel.close();
            }

            assertEquals("Channel leak", 0, stats.getNumConnections());
        }
    }

    /**
     * Tests that a server channel close, closes the connection and removes it
     * from the established connections.
     */
    @Test
    public void testServerClosesChannel() throws Exception {
        Deadline deadline = TEST_TIMEOUT.fromNow();
        AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

        KvStateClient client = null;
        Channel serverChannel = null;

        try {
            client = new KvStateClient(1, stats);

            final AtomicBoolean received = new AtomicBoolean();
            final AtomicReference<Channel> channel = new AtomicReference<>();

            serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
                @Override
                public void channelActive(ChannelHandlerContext ctx) throws Exception {
                    channel.set(ctx.channel());
                }

                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                    received.set(true);
                }
            });

            KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

            // Requests
            Future<byte[]> future = client.getKvState(serverAddress, new KvStateID(), new byte[0]);

            while (!received.get() && deadline.hasTimeLeft()) {
                Thread.sleep(50);
            }
            assertTrue("Receive timed out", received.get());

            assertEquals(1, stats.getNumConnections());

            // Close the connection from the server side without answering the request
            Channel ch = channel.get();
            assertNotNull("Channel not active", ch);

            ch.close().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

            try {
                Await.result(future, deadline.timeLeft());
                fail("Did not throw expected server failure");
            } catch (ClosedChannelException ignored) {
                // Expected
            }

            assertEquals(0, stats.getNumConnections());

            // Counts can take some time to propagate
            while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 || stats.getNumFailed() != 1)) {
                Thread.sleep(100);
            }

            assertEquals(1, stats.getNumRequests());
            assertEquals(0, stats.getNumSuccessful());
            assertEquals(1, stats.getNumFailed());
        } finally {
            if (client != null) {
                client.shutDown();
            }

            if (serverChannel != null) {
                serverChannel.close();
            }

            assertEquals("Channel leak", 0, stats.getNumConnections());
        }
    }

    /**
     * Tests a single client, driven by multiple concurrent query tasks, querying
     * multiple servers until 100k queries have been issued. At this point, the
     * client is shut down and it's verified that all ongoing requests are failed.
     */
    @Test
    public void testClientServerIntegration() throws Exception {
        // Config
        final int numServers = 2;
        final int numServerEventLoopThreads = 2;
        final int numServerQueryThreads = 2;

        final int numClientEventLoopThreads = 4;
        final int numClientsTasks = 8;

        final int batchSize = 16;

        final int numKeyGroups = 1;

        AbstractStateBackend abstractBackend = new MemoryStateBackend();
        KvStateRegistry dummyRegistry = new KvStateRegistry();
        DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
        dummyEnv.setKvStateRegistry(dummyRegistry);

        AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(dummyEnv, new JobID(),
                "test_op", IntSerializer.INSTANCE, numKeyGroups, new KeyGroupRange(0, 0),
                dummyRegistry.createTaskRegistry(new JobID(), new JobVertexID()));

        // Per-request timeout used by the query tasks
        final FiniteDuration timeout = new FiniteDuration(10, TimeUnit.SECONDS);

        AtomicKvStateRequestStats clientStats = new AtomicKvStateRequestStats();

        KvStateClient client = null;
        ExecutorService clientTaskExecutor = null;
        final KvStateServer[] server = new KvStateServer[numServers];

        try {
            client = new KvStateClient(numClientEventLoopThreads, clientStats);
            clientTaskExecutor = Executors.newFixedThreadPool(numClientsTasks);

            // Create state
            ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE);
            desc.setQueryable("any");

            // Create servers
            KvStateRegistry[] registry = new KvStateRegistry[numServers];
            AtomicKvStateRequestStats[] serverStats = new AtomicKvStateRequestStats[numServers];
            final KvStateID[] ids = new KvStateID[numServers];

            for (int i = 0; i < numServers; i++) {
                registry[i] = new KvStateRegistry();
                serverStats[i] = new AtomicKvStateRequestStats();
                server[i] = new KvStateServer(InetAddress.getLocalHost(), 0, numServerEventLoopThreads,
                        numServerQueryThreads, registry[i], serverStats[i]);

                server[i].start();

                backend.setCurrentKey(1010 + i);

                // Value per server
                ValueState<Integer> state = backend.getPartitionedState(VoidNamespace.INSTANCE,
                        VoidNamespaceSerializer.INSTANCE, desc);

                state.update(201 + i);

                // we know it must be a KvState but this is not exposed to the user via State
                InternalKvState<?> kvState = (InternalKvState<?>) state;

                // Register KvState (one state instance for all servers)
                ids[i] = registry[i].registerKvState(new JobID(), new JobVertexID(), new KeyGroupRange(0, 0), "any",
                        kvState);
            }

            final KvStateClient finalClient = client;
            // Issues batches of queries until interrupted or until a query/assertion throws;
            // key 1010 + i is expected to resolve to value 201 + i on server i.
            Callable<Void> queryTask = new Callable<Void>() {
                @Override
                public Void call() throws Exception {
                    while (true) {
                        if (Thread.interrupted()) {
                            throw new InterruptedException();
                        }

                        // Random server permutation
                        List<Integer> random = new ArrayList<>();
                        for (int j = 0; j < batchSize; j++) {
                            random.add(j);
                        }
                        Collections.shuffle(random);

                        // Dispatch queries
                        List<Future<byte[]>> futures = new ArrayList<>(batchSize);

                        for (int j = 0; j < batchSize; j++) {
                            int targetServer = random.get(j) % numServers;

                            byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
                                    1010 + targetServer, IntSerializer.INSTANCE, VoidNamespace.INSTANCE,
                                    VoidNamespaceSerializer.INSTANCE);

                            futures.add(finalClient.getKvState(server[targetServer].getAddress(), ids[targetServer],
                                    serializedKeyAndNamespace));
                        }

                        // Verify results
                        for (int j = 0; j < batchSize; j++) {
                            int targetServer = random.get(j) % numServers;

                            Future<byte[]> future = futures.get(j);
                            byte[] buf = Await.result(future, timeout);
                            int value = KvStateRequestSerializer.deserializeValue(buf, IntSerializer.INSTANCE);
                            assertEquals(201 + targetServer, value);
                        }
                    }
                }
            };

            // Submit tasks
            List<java.util.concurrent.Future<Void>> taskFutures = new ArrayList<>();
            for (int i = 0; i < numClientsTasks; i++) {
                taskFutures.add(clientTaskExecutor.submit(queryTask));
            }

            long numRequests;
            while ((numRequests = clientStats.getNumRequests()) < 100_000) {
                // The query tasks only terminate by throwing. Once every task has terminated the
                // request count can no longer grow, so report the failure instead of hanging.
                boolean anyTaskRunning = false;
                for (java.util.concurrent.Future<Void> taskFuture : taskFutures) {
                    if (!taskFuture.isDone()) {
                        anyTaskRunning = true;
                        break;
                    }
                }

                if (!anyTaskRunning) {
                    // Rethrows the first task's failure (wrapped in an ExecutionException)
                    taskFutures.get(0).get();
                    fail("All query tasks terminated before 100_000 requests were issued");
                }

                Thread.sleep(100);
                LOG.info("Number of requests {}/100_000", numRequests);
            }

            // Shut down
            client.shutDown();

            // Every task must now fail because the client is shut down
            for (java.util.concurrent.Future<Void> future : taskFutures) {
                try {
                    future.get();
                    fail("Did not throw expected Exception after shut down");
                } catch (ExecutionException t) {
                    // Report the task's actual failure, not the ExecutionException wrapper
                    Throwable cause = t.getCause();
                    if (cause instanceof ClosedChannelException
                            || cause instanceof IllegalStateException) {
                        // Expected
                    } else {
                        LOG.error("Query task failed with unexpected exception", cause);
                        fail("Failed with unexpected Exception type: " + cause.getClass().getName());
                    }
                }
            }

            assertEquals("Connection leak (client)", 0, clientStats.getNumConnections());
            // Server-side connection counts may lag behind the client shut down, so retry with back-off
            for (int i = 0; i < numServers; i++) {
                boolean success = false;
                int numRetries = 0;
                while (!success) {
                    try {
                        assertEquals("Connection leak (server)", 0, serverStats[i].getNumConnections());
                        success = true;
                    } catch (Throwable t) {
                        if (numRetries < 10) {
                            LOG.info("Retrying connection leak check (server)");
                            Thread.sleep((numRetries + 1) * 50);
                            numRetries++;
                        } else {
                            throw t;
                        }
                    }
                }
            }
        } finally {
            if (client != null) {
                client.shutDown();
            }

            for (int i = 0; i < numServers; i++) {
                if (server[i] != null) {
                    server[i].shutDown();
                }
            }

            if (clientTaskExecutor != null) {
                clientTaskExecutor.shutdown();
            }
        }
    }

    // ------------------------------------------------------------------------

    /**
     * Binds a server channel to an ephemeral port on the local host address, using the shared
     * {@link #NIO_GROUP}. Each accepted connection's pipeline gets a
     * {@link LengthFieldBasedFrameDecoder} that strips the 4-byte length prefix, followed by
     * the given handlers.
     *
     * <p>NOTE(review): the same handler instances are added to every accepted connection.
     * Netty rejects re-adding a handler that is not {@code @Sharable}, so this presumably relies
     * on the client opening a single connection per server.
     *
     * @param handlers Handlers appended to each accepted connection's pipeline after the frame decoder
     * @return The bound server channel
     */
    private Channel createServerChannel(final ChannelHandler... handlers)
            throws UnknownHostException, InterruptedException {
        ServerBootstrap bootstrap = new ServerBootstrap()
                // Bind address and port
                .localAddress(InetAddress.getLocalHost(), 0)
                // NIO server channels
                .group(NIO_GROUP).channel(NioServerSocketChannel.class)
                // See initializer for pipeline details
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
                                .addLast(handlers);
                    }
                });

        return bootstrap.bind().sync().channel();
    }

    /**
     * Returns the address a client should use to reach the given bound server channel.
     *
     * @param serverChannel A bound server channel (see {@link #createServerChannel(ChannelHandler...)})
     * @return The channel's local address and port as a {@link KvStateServerAddress}
     */
    private KvStateServerAddress getKvStateServerAddress(Channel serverChannel) {
        InetSocketAddress localAddress = (InetSocketAddress) serverChannel.localAddress();

        return new KvStateServerAddress(localAddress.getAddress(), localAddress.getPort());
    }
}