org.apache.flink.runtime.query.netty.KvStateServerTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.flink.runtime.query.netty.KvStateServerTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.query.netty;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.junit.AfterClass;
import org.junit.Test;

import java.net.InetAddress;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class KvStateServerTest {

    // Thread pool for client bootstrap (shared between tests)
    private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup();

    private final static int TIMEOUT_MILLIS = 10000;

    @AfterClass
    public static void tearDown() throws Exception {
        if (NIO_GROUP != null) {
            NIO_GROUP.shutdownGracefully();
        }
    }

    /**
     * Tests a simple successful query via a SocketChannel.
     */
    @Test
    public void testSimpleRequest() throws Exception {
        KvStateServer server = null;
        Bootstrap bootstrap = null;
        try {
            KvStateRegistry registry = new KvStateRegistry();
            KvStateRequestStats stats = new AtomicKvStateRequestStats();

            server = new KvStateServer(InetAddress.getLocalHost(), 0, 1, 1, registry, stats);
            server.start();

            KvStateServerAddress serverAddress = server.getAddress();
            int numKeyGroups = 1;
            AbstractStateBackend abstractBackend = new MemoryStateBackend();
            DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
            dummyEnv.setKvStateRegistry(registry);
            AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(dummyEnv,
                    new JobID(), "test_op", IntSerializer.INSTANCE, numKeyGroups, new KeyGroupRange(0, 0),
                    registry.createTaskRegistry(new JobID(), new JobVertexID()));

            final KvStateServerHandlerTest.TestRegistryListener registryListener = new KvStateServerHandlerTest.TestRegistryListener();

            registry.registerListener(registryListener);

            ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE);
            desc.setQueryable("vanilla");

            ValueState<Integer> state = backend.getPartitionedState(VoidNamespace.INSTANCE,
                    VoidNamespaceSerializer.INSTANCE, desc);

            // Update KvState
            int expectedValue = 712828289;

            int key = 99812822;
            backend.setCurrentKey(key);
            state.update(expectedValue);

            // Request
            byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(key,
                    IntSerializer.INSTANCE, VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);

            // Connect to the server
            final BlockingQueue<ByteBuf> responses = new LinkedBlockingQueue<>();
            bootstrap = createBootstrap(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
                    new ChannelInboundHandlerAdapter() {
                        @Override
                        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                            responses.add((ByteBuf) msg);
                        }
                    });

            Channel channel = bootstrap.connect(serverAddress.getHost(), serverAddress.getPort()).sync().channel();

            long requestId = Integer.MAX_VALUE + 182828L;

            assertTrue(registryListener.registrationName.equals("vanilla"));
            ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(channel.alloc(), requestId,
                    registryListener.kvStateId, serializedKeyAndNamespace);

            channel.writeAndFlush(request);

            ByteBuf buf = responses.poll(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);

            assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf));
            KvStateRequestResult response = KvStateRequestSerializer.deserializeKvStateRequestResult(buf);

            assertEquals(requestId, response.getRequestId());
            int actualValue = KvStateRequestSerializer.deserializeValue(response.getSerializedResult(),
                    IntSerializer.INSTANCE);
            assertEquals(expectedValue, actualValue);
        } finally {
            if (server != null) {
                server.shutDown();
            }

            if (bootstrap != null) {
                EventLoopGroup group = bootstrap.group();
                if (group != null) {
                    group.shutdownGracefully();
                }
            }
        }
    }

    /**
     * Creates a client bootstrap.
     */
    private Bootstrap createBootstrap(final ChannelHandler... handlers) {
        return new Bootstrap().group(NIO_GROUP).channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        ch.pipeline().addLast(handlers);
                    }
                });
    }

}