org.apache.hadoop.ipc.TestAsyncIPC.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.ipc.TestAsyncIPC.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.ipc;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.TestIPC.CallInfo;
import org.apache.hadoop.ipc.TestIPC.TestServer;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.concurrent.AsyncGetFuture;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

/**
 * Tests for RPC calls issued through {@link Client} in asynchronous mode, where the response of
 * each call is fetched afterwards through a {@link Future} built from
 * {@link Client#getAsyncRpcResponse()}.
 */
public class TestAsyncIPC {

    private static Configuration conf;
    private static final Log LOG = LogFactory.getLog(TestAsyncIPC.class);

    /**
     * Wraps {@link Client#getAsyncRpcResponse()} in an {@link AsyncGetFuture}. Callers in this
     * class invoke it right after issuing an asynchronous call via {@code TestIPC.call}.
     */
    static <T extends Writable> AsyncGetFuture<T, IOException> getAsyncRpcResponseFuture() {
        return (AsyncGetFuture<T, IOException>) new AsyncGetFuture<>(Client.getAsyncRpcResponse());
    }

    /**
     * Creates a fresh configuration for every test (async-call limit 10000, ping interval taken
     * from {@link TestIPC#PING_INTERVAL}) and enables asynchronous mode on the setup thread.
     */
    @Before
    public void setupConf() {
        conf = new Configuration();
        conf.setInt(CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_KEY, 10000);
        Client.setPingInterval(conf, TestIPC.PING_INTERVAL);
        // set asynchronous mode for main thread
        Client.setAsynchronousMode(true);
    }

    /**
     * Stops every client that was created and then the server. Client entries may be null if
     * construction failed part way through; the server is stopped even if a client stop throws.
     */
    private static void stopAll(Client[] clients, Server server) {
        try {
            for (Client client : clients) {
                if (client != null) {
                    client.stop();
                }
            }
        } finally {
            server.stop();
        }
    }

    /**
     * Issues {@code count} asynchronous calls, each with a random {@code long} parameter, and
     * records per call index the response future and the expected value (the parameter itself,
     * which the test server is expected to return). It can be started as a thread, or
     * {@link #run()} can be invoked directly on the current thread.
     */
    static class AsyncCaller extends Thread {
        private Client client;
        private InetSocketAddress server;
        private int count;
        // Set when issuing a call throws; run() then rethrows and issues no further calls.
        private boolean failed;
        Map<Integer, Future<LongWritable>> returnFutures = new HashMap<Integer, Future<LongWritable>>();
        Map<Integer, Long> expectedValues = new HashMap<Integer, Long>();

        public AsyncCaller(Client client, InetSocketAddress server, int count) {
            this.client = client;
            this.server = server;
            this.count = count;
            // set asynchronous mode, since AsyncCaller extends Thread
            Client.setAsynchronousMode(true);
        }

        @Override
        public void run() {
            // in case Thread#Start is called, which will spawn new thread
            Client.setAsynchronousMode(true);
            for (int i = 0; i < count; i++) {
                try {
                    final long param = TestIPC.RANDOM.nextLong();
                    TestIPC.call(client, param, server, conf);
                    Future<LongWritable> returnFuture = getAsyncRpcResponseFuture();
                    returnFutures.put(i, returnFuture);
                    expectedValues.put(i, param);
                } catch (Exception e) {
                    failed = true;
                    throw new RuntimeException(e);
                }
            }
        }

        /**
         * Blocks on every recorded future and asserts each yields its expected value. Futures may
         * be fetched repeatedly.
         */
        void assertReturnValues() throws InterruptedException, ExecutionException {
            // Check first: after a failure in run() some futures were never recorded, and
            // dereferencing them would hide the real failure behind a NullPointerException.
            Assert.assertFalse(failed);
            for (int i = 0; i < count; i++) {
                LongWritable value = returnFutures.get(i).get();
                Assert.assertEquals("call" + i + " failed.", expectedValues.get(i).longValue(), value.get());
            }
            Assert.assertFalse(failed);
        }

        /**
         * Repeatedly polls each unchecked future with the given timeout. A timed-out future is
         * logged and retried on the next pass; all values are checked before this returns.
         *
         * @param timeout per-get timeout
         * @param unit unit of {@code timeout}
         */
        void assertReturnValues(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException {
            // See assertReturnValues(): missing futures would otherwise cause an NPE.
            Assert.assertFalse(failed);
            final boolean[] checked = new boolean[count];
            for (boolean done = false; !done;) {
                done = true;
                for (int i = 0; i < count; i++) {
                    if (checked[i]) {
                        continue;
                    } else {
                        done = false;
                    }

                    final LongWritable value;
                    try {
                        value = returnFutures.get(i).get(timeout, unit);
                    } catch (TimeoutException e) {
                        LOG.info("call" + i + " caught ", e);
                        continue;
                    }

                    Assert.assertEquals("call" + i + " failed.", expectedValues.get(i).longValue(), value.get());
                    checked[i] = true;
                }
            }
            Assert.assertFalse(failed);
        }
    }

    /**
     * Like {@link AsyncCaller}, but tolerates {@link AsyncCallLimitExceededException}. When the
     * client rejects a call, the caller fetches the results of the calls issued since its last
     * drain and retries. Failures are logged and recorded in {@code failed} rather than thrown.
     */
    static class AsyncLimitlCaller extends Thread {
        private Client client;
        private InetSocketAddress server;
        private int count;
        private boolean failed;
        Map<Integer, Future<LongWritable>> returnFutures = new HashMap<Integer, Future<LongWritable>>();
        Map<Integer, Long> expectedValues = new HashMap<Integer, Long>();
        // [start, end) is the window of call indices drained by the most recent limit hit.
        int start = 0, end = 0;

        int getStart() {
            return start;
        }

        int getEnd() {
            return end;
        }

        int getCount() {
            return count;
        }

        public AsyncLimitlCaller(Client client, InetSocketAddress server, int count) {
            this(0, client, server, count);
        }

        final int callerId;

        public AsyncLimitlCaller(int callerId, Client client, InetSocketAddress server, int count) {
            this.client = client;
            this.server = server;
            this.count = count;
            // set asynchronous mode, since AsyncLimitlCaller extends Thread
            Client.setAsynchronousMode(true);
            this.callerId = callerId;
        }

        @Override
        public void run() {
            // in case Thread#Start is called, which will spawn new thread
            Client.setAsynchronousMode(true);
            for (int i = 0; i < count; i++) {
                try {
                    final long param = TestIPC.RANDOM.nextLong();
                    runCall(i, param);
                } catch (Exception e) {
                    LOG.fatal(String.format("Caller-%d Call-%d caught: %s", callerId, i,
                            StringUtils.stringifyException(e)));
                    failed = true;
                }
            }
        }

        /** Issues call {@code idx}, draining earlier results and retrying while over the limit. */
        private void runCall(final int idx, final long param)
                throws InterruptedException, ExecutionException, IOException {
            for (;;) {
                try {
                    doCall(idx, param);
                    return;
                } catch (AsyncCallLimitExceededException e) {
                    /**
                     * reached limit of async calls, fetch results of finished async calls
                     * to let follow-on calls go
                     */
                    start = end;
                    end = idx;
                    waitForReturnValues(start, end);
                }
            }
        }

        /** Issues one asynchronous call and records its future and expected value. */
        private void doCall(final int idx, final long param) throws IOException {
            TestIPC.call(client, param, server, conf);
            Future<LongWritable> returnFuture = getAsyncRpcResponseFuture();
            returnFutures.put(idx, returnFuture);
            expectedValues.put(idx, param);
        }

        /**
         * Fetches the responses of calls {@code [start, end)} and compares them with the expected
         * values. A mismatch or a missing future marks the caller as failed. Every future is
         * still fetched, since fetching results is what lets follow-on calls get past the limit.
         */
        private void waitForReturnValues(final int start, final int end)
                throws InterruptedException, ExecutionException {
            for (int i = start; i < end; i++) {
                final Future<LongWritable> future = returnFutures.get(i);
                if (future == null) {
                    // doCall threw for this index, so no future was recorded (run() logged it).
                    LOG.fatal(String.format("Caller-%d Call-%d has no response future!", callerId, i));
                    failed = true;
                    continue;
                }
                LongWritable value = future.get();
                if (expectedValues.get(i).longValue() != value.get()) {
                    LOG.fatal(String.format("Caller-%d Call-%d failed!", callerId, i));
                    failed = true;
                }
            }
        }
    }

    /** Runs async callers against servers whose handlers do not sleep and do sleep. */
    @Test(timeout = 60000)
    public void testAsyncCall() throws IOException, InterruptedException, ExecutionException {
        internalTestAsyncCall(3, false, 2, 5, 100);
        internalTestAsyncCall(3, true, 2, 5, 10);
    }

    /** Runs limit-aware callers against clients configured with a small async-call limit. */
    @Test(timeout = 60000)
    public void testAsyncCallLimit() throws IOException, InterruptedException, ExecutionException {
        internalTestAsyncCallLimit(100, false, 5, 10, 500);
    }

    /**
     * Starts a test server, spreads {@code callerCount} {@link AsyncCaller} threads round-robin
     * over {@code clientCount} clients, and checks every response. Clients and server are always
     * stopped, even when an assertion fails.
     */
    public void internalTestAsyncCall(int handlerCount, boolean handlerSleep, int clientCount, int callerCount,
            int callCount) throws IOException, InterruptedException, ExecutionException {
        Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf);
        InetSocketAddress addr = NetUtils.getConnectAddress(server);
        server.start();

        Client[] clients = new Client[clientCount];
        try {
            for (int i = 0; i < clientCount; i++) {
                clients[i] = new Client(LongWritable.class, conf);
            }

            AsyncCaller[] callers = new AsyncCaller[callerCount];
            for (int i = 0; i < callerCount; i++) {
                callers[i] = new AsyncCaller(clients[i % clientCount], addr, callCount);
                callers[i].start();
            }
            for (int i = 0; i < callerCount; i++) {
                callers[i].join();
                callers[i].assertReturnValues();
            }
        } finally {
            stopAll(clients, server);
        }
    }

    /**
     * Fetching the same futures several times must keep returning the expected values and must
     * leave the client's async call count where it was before the calls.
     */
    @Test(timeout = 60000)
    public void testCallGetReturnRpcResponseMultipleTimes()
            throws IOException, InterruptedException, ExecutionException {
        int handlerCount = 10, callCount = 100;
        Server server = new TestIPC.TestServer(handlerCount, false, conf);
        InetSocketAddress addr = NetUtils.getConnectAddress(server);
        server.start();
        final Client client = new Client(LongWritable.class, conf);

        int asyncCallCount = client.getAsyncCallCount();

        try {
            AsyncCaller caller = new AsyncCaller(client, addr, callCount);
            caller.run();
            caller.assertReturnValues();
            caller.assertReturnValues();
            caller.assertReturnValues();
            Assert.assertEquals(asyncCallCount, client.getAsyncCallCount());
        } finally {
            client.stop();
            server.stop();
        }
    }

    /**
     * Polls futures with a short (10 ms) timeout against a server constructed with handler sleep
     * enabled, retrying timed-out gets until every value has been checked.
     */
    @Test(timeout = 60000)
    public void testFutureGetWithTimeout() throws IOException, InterruptedException, ExecutionException {
        //    GenericTestUtils.setLogLevel(AsyncGetFuture.LOG, Level.ALL);
        final Server server = new TestIPC.TestServer(10, true, conf);
        final InetSocketAddress addr = NetUtils.getConnectAddress(server);
        server.start();

        final Client client = new Client(LongWritable.class, conf);

        try {
            final AsyncCaller caller = new AsyncCaller(client, addr, 10);
            caller.run();
            caller.assertReturnValues(10, TimeUnit.MILLISECONDS);
        } finally {
            client.stop();
            server.stop();
        }
    }

    /**
     * Like {@link #internalTestAsyncCall}, but the clients are created with an async-call limit
     * of 100 and driven by {@link AsyncLimitlCaller}s. Clients and server are always stopped,
     * even when an assertion fails.
     */
    public void internalTestAsyncCallLimit(int handlerCount, boolean handlerSleep, int clientCount, int callerCount,
            int callCount) throws IOException, InterruptedException, ExecutionException {
        // Local configuration (shadows the static field) with a much lower async-call limit.
        // NOTE(review): AsyncLimitlCaller passes the static conf, not this one, to TestIPC.call;
        // confirm that is intended.
        Configuration conf = new Configuration();
        conf.setInt(CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_KEY, 100);
        Client.setPingInterval(conf, TestIPC.PING_INTERVAL);

        Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf);
        InetSocketAddress addr = NetUtils.getConnectAddress(server);
        server.start();

        Client[] clients = new Client[clientCount];
        try {
            for (int i = 0; i < clientCount; i++) {
                clients[i] = new Client(LongWritable.class, conf);
            }

            AsyncLimitlCaller[] callers = new AsyncLimitlCaller[callerCount];
            for (int i = 0; i < callerCount; i++) {
                callers[i] = new AsyncLimitlCaller(i, clients[i % clientCount], addr, callCount);
                callers[i].start();
            }
            for (int i = 0; i < callerCount; i++) {
                callers[i].join();
                // Fetch every response from the start of the caller's last drain window onward.
                callers[i].waitForReturnValues(callers[i].getStart(), callers[i].getCount());
                String msg = String.format("Expected not failed for caller-%d: %s.", i, callers[i]);
                assertFalse(msg, callers[i].failed);
            }
        } finally {
            stopAll(clients, server);
        }
    }

    /**
     * Test if (1) the rpc server uses the call id/retry provided by the rpc
     * client, and (2) the rpc client receives the same call id/retry from the rpc
     * server.
     *
     * @throws ExecutionException
     * @throws InterruptedException
     */
    @Test(timeout = 60000)
    public void testCallIdAndRetry() throws IOException, InterruptedException, ExecutionException {
        // Synchronized because it is written by the createCall override and read by
        // checkResponse and the server call listener, which may run on other threads.
        final Map<Integer, CallInfo> infoMap = Collections.synchronizedMap(new HashMap<Integer, CallInfo>());

        // Override client to store the call info and check response
        final Client client = new Client(LongWritable.class, conf) {
            @Override
            Call createCall(RpcKind rpcKind, Writable rpcRequest) {
                // Set different call id and retry count for the next call
                Client.setCallIdAndRetryCount(Client.nextCallId(), TestIPC.RANDOM.nextInt(255), null);

                final Call call = super.createCall(rpcKind, rpcRequest);

                CallInfo info = new CallInfo();
                info.id = call.id;
                info.retry = call.retry;
                infoMap.put(call.id, info);

                return call;
            }

            @Override
            void checkResponse(RpcResponseHeaderProto header) throws IOException {
                super.checkResponse(header);
                Assert.assertEquals(infoMap.get(header.getCallId()).retry, header.getRetryCount());
            }
        };

        // Attach a listener that tracks every call received by the server.
        final TestServer server = new TestIPC.TestServer(1, false, conf);
        server.callListener = new Runnable() {
            @Override
            public void run() {
                Assert.assertEquals(infoMap.get(Server.getCallId()).retry, Server.getCallRetryCount());
            }
        };

        try {
            InetSocketAddress addr = NetUtils.getConnectAddress(server);
            server.start();
            final AsyncCaller caller = new AsyncCaller(client, addr, 4);
            caller.run();
            caller.assertReturnValues();
        } finally {
            client.stop();
            server.stop();
        }
    }

    /**
     * Test if the rpc server gets the retry count from client.
     *
     * @throws ExecutionException
     * @throws InterruptedException
     */
    @Test(timeout = 60000)
    public void testCallRetryCount() throws IOException, InterruptedException, ExecutionException {
        final int retryCount = 255;
        // Plain client; the call id and retry count are set explicitly below.
        final Client client = new Client(LongWritable.class, conf);
        Client.setCallIdAndRetryCount(Client.nextCallId(), retryCount, null);

        // Attach a listener that tracks every call ID received by the server.
        final TestServer server = new TestIPC.TestServer(1, false, conf);
        server.callListener = new Runnable() {
            @Override
            public void run() {
                // the retry count was set to retryCount above, so the server side
                // should see exactly that value
                Assert.assertEquals(retryCount, Server.getCallRetryCount());
            }
        };

        try {
            InetSocketAddress addr = NetUtils.getConnectAddress(server);
            server.start();
            final AsyncCaller caller = new AsyncCaller(client, addr, 10);
            caller.run();
            caller.assertReturnValues();
        } finally {
            client.stop();
            server.stop();
        }
    }

    /**
     * Test if the rpc server gets the default retry count (0) from client.
     *
     * @throws ExecutionException
     * @throws InterruptedException
     */
    @Test(timeout = 60000)
    public void testInitialCallRetryCount() throws IOException, InterruptedException, ExecutionException {
        // Plain client; no call id or retry count is set.
        final Client client = new Client(LongWritable.class, conf);

        // Attach a listener that tracks every call ID received by the server.
        final TestServer server = new TestIPC.TestServer(1, false, conf);
        server.callListener = new Runnable() {
            @Override
            public void run() {
                // we have not set the retry count for the client, thus on the server
                // side we should see retry count as 0
                Assert.assertEquals(0, Server.getCallRetryCount());
            }
        };

        try {
            InetSocketAddress addr = NetUtils.getConnectAddress(server);
            server.start();
            final AsyncCaller caller = new AsyncCaller(client, addr, 10);
            caller.run();
            caller.assertReturnValues();
        } finally {
            client.stop();
            server.stop();
        }
    }

    /**
     * Tests that client generates a unique sequential call ID for each RPC call,
     * even if multiple threads are using the same client.
     *
     * @throws InterruptedException
     * @throws ExecutionException
     */
    @Test(timeout = 60000)
    public void testUniqueSequentialCallIds() throws IOException, InterruptedException, ExecutionException {
        int serverThreads = 10, callerCount = 100, perCallerCallCount = 100;
        TestServer server = new TestIPC.TestServer(serverThreads, false, conf);

        // Attach a listener that tracks every call ID received by the server. This
        // list must be synchronized, because multiple server threads will add to
        // it.
        final List<Integer> callIds = Collections.synchronizedList(new ArrayList<Integer>());
        server.callListener = new Runnable() {
            @Override
            public void run() {
                callIds.add(Server.getCallId());
            }
        };

        Client client = new Client(LongWritable.class, conf);

        try {
            InetSocketAddress addr = NetUtils.getConnectAddress(server);
            server.start();
            AsyncCaller[] callers = new AsyncCaller[callerCount];
            for (int i = 0; i < callerCount; ++i) {
                callers[i] = new AsyncCaller(client, addr, perCallerCallCount);
                callers[i].start();
            }
            for (int i = 0; i < callerCount; ++i) {
                callers[i].join();
                callers[i].assertReturnValues();
            }
        } finally {
            client.stop();
            server.stop();
        }

        int expectedCallCount = callerCount * perCallerCallCount;
        assertEquals(expectedCallCount, callIds.size());

        // The server is not guaranteed to execute requests in client call-ID order,
        // so sort the call IDs before checking that they form one contiguous
        // sequence.
        Collections.sort(callIds);
        final int startID = callIds.get(0).intValue();
        for (int i = 0; i < expectedCallCount; ++i) {
            assertEquals(startID + i, callIds.get(i).intValue());
        }
    }
}