org.apache.reef.io.network.NetworkConnectionServiceTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.reef.io.network.NetworkConnectionServiceTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.io.network;

import org.apache.commons.lang3.StringUtils;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.network.util.*;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.wake.Identifier;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.Codec;
import org.apache.reef.wake.remote.address.LocalAddressProvider;
import org.apache.reef.wake.remote.impl.ObjectSerializableCodec;
import org.junit.Assume;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;

import java.util.concurrent.*;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Default Network connection service test.
 */
public class NetworkConnectionServiceTest {
    /** Logger for this test class; the benchmark tests also consult its level to decide whether to run. */
    private static final Logger LOG = Logger.getLogger(NetworkConnectionServiceTest.class.getName());

    /** Address provider obtained through Tang injection in the constructor. */
    private final LocalAddressProvider localAddressProvider;
    /** Local address resolved once from {@code localAddressProvider}; passed to every test service created below. */
    private final String localAddress;
    /** Connection factory identifier ("groupComm"); the tests register it with String codecs. */
    private final Identifier groupCommClientId;
    /** Connection factory identifier ("shuffle"); the multi-factory tests register it with Integer codecs. */
    private final Identifier shuffleClientId;

    /**
     * Resolves the local address through an injected {@code LocalAddressProvider} and creates
     * the two connection factory identifiers ("groupComm" and "shuffle") shared by the tests.
     *
     * @throws InjectionException if the address provider cannot be injected
     */
    public NetworkConnectionServiceTest() throws InjectionException {
        final IdentifierFactory identifierFactory = new StringIdentifierFactory();

        this.localAddressProvider =
                Tang.Factory.getTang().newInjector().getInstance(LocalAddressProvider.class);
        this.localAddress = this.localAddressProvider.getLocalAddress();

        this.groupCommClientId = identifierFactory.getNewInstance("groupComm");
        this.shuffleClientId = identifierFactory.getNewInstance("shuffle");
    }

    /** JUnit rule exposing the running test method's name; each test logs it at FINEST level. */
    @Rule
    public TestName name = new TestName();

    /**
     * Opens one sender-to-receiver connection registered under {@code groupCommClientId},
     * writes 2000 messages of the form {@code "hello" + index}, and then blocks on the monitor
     * that was handed to the test service at registration time.
     *
     * @param codec codec registered together with the connection factory
     * @throws Exception if the test service or the connection fails to be created or closed
     */
    private void runMessagingNetworkConnectionService(final Codec<String> codec) throws Exception {
        final Monitor done = new Monitor();
        final int messageCount = 2000;

        try (final NetworkMessagingTestService service = new NetworkMessagingTestService(localAddress)) {
            service.registerTestConnectionFactory(groupCommClientId, messageCount, done, codec);

            try (final Connection<String> connection =
                    service.getConnectionFromSenderToReceiver(groupCommClientId)) {
                try {
                    connection.open();
                    int sent = 0;
                    while (sent < messageCount) {
                        // push the next message to the receiver
                        connection.write("hello" + sent);
                        sent++;
                    }
                    done.mwait();
                } catch (final NetworkException networkFailure) {
                    networkFailure.printStackTrace();
                    throw new RuntimeException(networkFailure);
                }
            }
        }
    }

    /**
     * Runs the single-connection messaging scenario with a {@code StringCodec}.
     */
    @Test
    public void testMessagingNetworkConnectionService() throws Exception {
        final String testName = name.getMethodName();
        LOG.log(Level.FINEST, testName);

        final Codec<String> codec = new StringCodec();
        runMessagingNetworkConnectionService(codec);
    }

    /**
     * Runs the single-connection messaging scenario with a {@code StreamingStringCodec}.
     */
    @Test
    public void testStreamingMessagingNetworkConnectionService() throws Exception {
        final String testName = name.getMethodName();
        LOG.log(Level.FINEST, testName);

        final Codec<String> streamingCodec = new StreamingStringCodec();
        runMessagingNetworkConnectionService(streamingCodec);
    }

    /**
     * Registers two connection factories on one test service and drives both concurrently:
     * 1000 String messages under {@code groupCommClientId} and 2000 Integer messages under
     * {@code shuffleClientId}. Each sender runs on its own executor thread, writes all of its
     * messages and then blocks on the monitor registered for its factory.
     *
     * <p>The calling thread waits on the senders' futures, so a sender failure is rethrown here
     * (wrapped in an {@code ExecutionException}) instead of being silently dropped, and the test
     * service is only closed after both senders have closed their connections.
     *
     * @param stringCodec  codec registered for the "groupComm" factory
     * @param integerCodec codec registered for the "shuffle" factory
     * @throws Exception if setup fails, a sender fails, or the calling thread is interrupted
     */
    public void runNetworkConnServiceWithMultipleConnFactories(final Codec<String> stringCodec,
            final Codec<Integer> integerCodec) throws Exception {
        final ExecutorService executor = Executors.newFixedThreadPool(5);

        final int groupcommMessages = 1000;
        final Monitor monitor = new Monitor();
        try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(
                localAddress)) {

            messagingTestService.registerTestConnectionFactory(groupCommClientId, groupcommMessages, monitor,
                    stringCodec);

            final int shuffleMessages = 2000;
            final Monitor monitor2 = new Monitor();
            messagingTestService.registerTestConnectionFactory(shuffleClientId, shuffleMessages, monitor2,
                    integerCodec);

            final Future<?> groupCommSender = executor.submit(new Runnable() {
                @Override
                public void run() {
                    try (final Connection<String> conn = messagingTestService
                            .getConnectionFromSenderToReceiver(groupCommClientId)) {
                        conn.open();
                        for (int count = 0; count < groupcommMessages; ++count) {
                            // send messages to the receiver.
                            conn.write("hello" + count);
                        }
                        monitor.mwait();
                    } catch (final Exception e) {
                        // Surfaces to the test thread through the Future.
                        throw new RuntimeException(e);
                    }
                }
            });

            final Future<?> shuffleSender = executor.submit(new Runnable() {
                @Override
                public void run() {
                    try (final Connection<Integer> conn = messagingTestService
                            .getConnectionFromSenderToReceiver(shuffleClientId)) {
                        conn.open();
                        for (int count = 0; count < shuffleMessages; ++count) {
                            // send messages to the receiver.
                            conn.write(count);
                        }
                        monitor2.mwait();
                    } catch (final Exception e) {
                        // Surfaces to the test thread through the Future.
                        throw new RuntimeException(e);
                    }
                }
            });

            // Each sender already waits on its own monitor, so waiting on the futures covers
            // "all messages handled" and additionally propagates any sender failure.
            groupCommSender.get();
            shuffleSender.get();
        } finally {
            // Always release the pool threads, including on failure paths.
            executor.shutdownNow();
        }
    }

    /**
     * Runs the two-factory scenario with a {@code StringCodec} for String messages and an
     * {@code ObjectSerializableCodec} for Integer messages.
     */
    @Test
    public void testMultipleConnectionFactoriesTest() throws Exception {
        final String testName = name.getMethodName();
        LOG.log(Level.FINEST, testName);

        final Codec<String> stringCodec = new StringCodec();
        final Codec<Integer> integerCodec = new ObjectSerializableCodec<Integer>();
        runNetworkConnServiceWithMultipleConnFactories(stringCodec, integerCodec);
    }

    /**
     * Runs the two-factory scenario with the streaming codecs
     * ({@code StreamingStringCodec} and {@code StreamingIntegerCodec}).
     */
    @Test
    public void testMultipleConnectionFactoriesStreamingTest() throws Exception {
        final String testName = name.getMethodName();
        LOG.log(Level.FINEST, testName);

        final Codec<String> stringCodec = new StreamingStringCodec();
        final Codec<Integer> integerCodec = new StreamingIntegerCodec();
        runNetworkConnServiceWithMultipleConnFactories(stringCodec, integerCodec);
    }

    /**
     * Single-sender throughput benchmark, repeated for several message sizes. For each size a
     * fresh test service is created, the payload is written repeatedly over one connection, and
     * the elapsed wall-clock time up to the monitor being released is logged as messages/s and
     * bytes/s. The test is skipped (JUnit assumption) whenever FINEST logging is enabled.
     */
    @Test
    public void testMessagingNetworkConnServiceRate() throws Exception {

        Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

        LOG.log(Level.FINEST, name.getMethodName());

        final int[] payloadSizes = { 1, 16, 32, 64, 512, 64 * 1024, 1024 * 1024 };

        for (int idx = 0; idx < payloadSizes.length; idx++) {
            final int size = payloadSizes[idx];
            // Fewer messages for larger payloads.
            final int numMessages = 300000 / (Math.max(1, size / 512));
            final String payload = StringUtils.repeat('1', size);
            final Codec<String> codec = new StringCodec();
            final Monitor done = new Monitor();

            try (final NetworkMessagingTestService service = new NetworkMessagingTestService(localAddress)) {
                service.registerTestConnectionFactory(groupCommClientId, numMessages, done, codec);

                try (final Connection<String> connection =
                        service.getConnectionFromSenderToReceiver(groupCommClientId)) {

                    final long began = System.currentTimeMillis();
                    try {
                        connection.open();
                        int sent = 0;
                        while (sent < numMessages) {
                            connection.write(payload);
                            sent++;
                        }
                        done.mwait();
                    } catch (final NetworkException networkFailure) {
                        networkFailure.printStackTrace();
                        throw new RuntimeException(networkFailure);
                    }
                    final long finished = System.currentTimeMillis();

                    final double seconds = ((double) finished - began) / 1000;
                    final double messagesPerSecond = numMessages / seconds;
                    // x2 for unicode chars
                    final double bytesPerSecond = ((double) numMessages * 2 * size) / seconds;
                    LOG.log(Level.INFO, "size: " + size + "; messages/s: " + messagesPerSecond
                            + " bandwidth(bytes/s): " + bytesPerSecond);
                }
            }
        }
    }

    /**
     * Throughput benchmark with disjoint senders: four worker threads each create their own
     * test service and connection and write the same 2000-character payload.
     *
     * <p>Workers block on {@code barrier} before opening their connection; the test thread
     * records the start time and only then releases one token per worker, so no message is sent
     * before the clock is running. Worker failures are rethrown on the test thread through their
     * futures. If the workers do not finish within 100 seconds, a warning is logged and no rate
     * is reported. The test is skipped (JUnit assumption) whenever FINEST logging is enabled.
     */
    @Test
    public void testMessagingNetworkConnServiceRateDisjoint() throws Exception {

        Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

        LOG.log(Level.FINEST, name.getMethodName());

        final BlockingQueue<Object> barrier = new LinkedBlockingQueue<>();

        final int numThreads = 4;
        final int size = 2000;
        final int numMessages = 300000 / (Math.max(1, size / 512));
        final int totalNumMessages = numMessages * numThreads;
        final String message = StringUtils.repeat('1', size);

        final ExecutorService executor = Executors.newCachedThreadPool();
        final Future<?>[] senders = new Future<?>[numThreads];
        try {
            for (int t = 0; t < numThreads; t++) {
                senders[t] = executor.submit(new Runnable() {
                    @Override
                    public void run() {
                        try (final NetworkMessagingTestService messagingTestService =
                                new NetworkMessagingTestService(localAddress)) {
                            final Monitor monitor = new Monitor();
                            final Codec<String> codec = new StringCodec();

                            messagingTestService.registerTestConnectionFactory(groupCommClientId, numMessages,
                                    monitor, codec);
                            try (final Connection<String> conn = messagingTestService
                                    .getConnectionFromSenderToReceiver(groupCommClientId)) {
                                // Wait for the test thread to start the clock.
                                barrier.take();
                                conn.open();
                                for (int count = 0; count < numMessages; ++count) {
                                    // send messages to the receiver.
                                    conn.write(message);
                                }
                                monitor.mwait();
                            }
                        } catch (final InterruptedException interrupted) {
                            // Restore the flag so the pool thread sees the interruption.
                            Thread.currentThread().interrupt();
                            throw new RuntimeException(interrupted);
                        } catch (final Exception failure) {
                            // Surfaces to the test thread through the Future.
                            throw new RuntimeException(failure);
                        }
                    }
                });
            }

            // start and time
            final long start = System.currentTimeMillis();
            final Object token = new Object();
            for (int i = 0; i < numThreads; i++) {
                barrier.add(token);
            }
            executor.shutdown();
            if (!executor.awaitTermination(100, TimeUnit.SECONDS)) {
                // A rate computed from an unfinished run would be meaningless.
                LOG.log(Level.WARNING, "Benchmark senders did not finish within 100 seconds; no rate reported");
                return;
            }
            final long end = System.currentTimeMillis();

            // All tasks have terminated, so get() returns immediately or rethrows a worker failure.
            for (final Future<?> sender : senders) {
                sender.get();
            }

            final double runtime = ((double) end - start) / 1000;
            LOG.log(Level.INFO, "size: " + size + "; messages/s: " + totalNumMessages / runtime
                    + " bandwidth(bytes/s): " + ((double) totalNumMessages * 2 * size) / runtime); // x2 for unicode chars
        } finally {
            // Stops any worker still running (timeout or failure path); no-op after normal termination.
            executor.shutdownNow();
        }
    }

    /**
     * Throughput benchmark in which two sender threads share a single connection. Each thread
     * calls {@code open()} on the shared connection and writes {@code numMessages} copies of the
     * payload; the receiver side is registered for the combined message count.
     *
     * <p>The test thread waits on the senders' futures before waiting on the monitor, so a
     * sender failure is rethrown here rather than leaving the monitor wait blocked on messages
     * that will never arrive. The test is skipped (JUnit assumption) whenever FINEST logging is
     * enabled.
     */
    @Test
    public void testMultithreadedSharedConnMessagingNetworkConnServiceRate() throws Exception {

        Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

        LOG.log(Level.FINEST, name.getMethodName());

        // Other sizes worth trying: 1, 16, 32, 64, 512, 64 * 1024, 1024 * 1024.
        final int[] messageSizes = { 2000 };

        for (final int size : messageSizes) {
            final String message = StringUtils.repeat('1', size);
            final int numMessages = 300000 / (Math.max(1, size / 512));
            final int numThreads = 2;
            final int totalNumMessages = numMessages * numThreads;
            final Monitor monitor = new Monitor();
            final Codec<String> codec = new StringCodec();
            final ExecutorService executor = Executors.newCachedThreadPool();
            try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(
                    localAddress)) {
                messagingTestService.registerTestConnectionFactory(groupCommClientId, totalNumMessages, monitor,
                        codec);

                try (final Connection<String> conn = messagingTestService
                        .getConnectionFromSenderToReceiver(groupCommClientId)) {

                    final Future<?>[] senders = new Future<?>[numThreads];
                    final long start = System.currentTimeMillis();
                    for (int i = 0; i < numThreads; i++) {
                        senders[i] = executor.submit(new Runnable() {
                            @Override
                            public void run() {
                                try {
                                    conn.open();
                                    for (int count = 0; count < numMessages; ++count) {
                                        // send messages to the receiver.
                                        conn.write(message);
                                    }
                                } catch (final Exception failure) {
                                    // Surfaces to the test thread through the Future.
                                    throw new RuntimeException(failure);
                                }
                            }
                        });
                    }

                    executor.shutdown();
                    // Wait for every sender and rethrow the first failure, if any.
                    for (final Future<?> sender : senders) {
                        sender.get();
                    }
                    monitor.mwait();
                    final long end = System.currentTimeMillis();
                    final double runtime = ((double) end - start) / 1000;
                    LOG.log(Level.INFO, "size: " + size + "; messages/s: " + totalNumMessages / runtime
                            + " bandwidth(bytes/s): " + ((double) totalNumMessages * 2 * size) / runtime); // x2 for unicode chars
                }
            } finally {
                // Release pool threads on every path; no-op after normal completion.
                executor.shutdownNow();
            }
        }
    }

    /**
     * Batched messaging rate benchmark. Every message written to the connection is a
     * {@code batchSize}-character payload; the reported figures treat each payload as
     * {@code batchSize / size} application-level messages of {@code size} characters, for each
     * {@code size} in the tested set. The test is skipped (JUnit assumption) whenever FINEST
     * logging is enabled.
     */
    @Test
    public void testMessagingNetworkConnServiceBatchingRate() throws Exception {

        Assume.assumeFalse("Use log level INFO to run benchmarking", LOG.isLoggable(Level.FINEST));

        LOG.log(Level.FINEST, name.getMethodName());

        final int batchSize = 1024 * 1024;
        final int[] messageSizes = { 32, 64, 512 };
        // The payload depends only on batchSize, so build the 1 MiB string once.
        final String message = StringUtils.repeat('1', batchSize);

        for (final int size : messageSizes) {
            final int numMessages = 300 / (Math.max(1, size / 512));
            final Monitor monitor = new Monitor();
            final Codec<String> codec = new StringCodec();
            try (final NetworkMessagingTestService messagingTestService = new NetworkMessagingTestService(
                    localAddress)) {
                messagingTestService.registerTestConnectionFactory(groupCommClientId, numMessages, monitor, codec);
                try (final Connection<String> conn = messagingTestService
                        .getConnectionFromSenderToReceiver(groupCommClientId)) {
                    final long start = System.currentTimeMillis();
                    try {
                        conn.open();
                        for (int i = 0; i < numMessages; i++) {
                            conn.write(message);
                        }
                        monitor.mwait();
                    } catch (final NetworkException e) {
                        // The rethrown exception carries the cause; no separate stack dump needed.
                        throw new RuntimeException(e);
                    }

                    final long end = System.currentTimeMillis();
                    final double runtime = ((double) end - start) / 1000;
                    // Widen before multiplying so the product cannot overflow int.
                    final long numAppMessages = (long) numMessages * batchSize / size;
                    LOG.log(Level.INFO, "size: " + size + "; messages/s: " + numAppMessages / runtime
                            + " bandwidth(bytes/s): " + ((double) numAppMessages * 2 * size) / runtime); // x2 for unicode chars
                }
            }
        }
    }
}