org.springframework.integration.ip.tcp.connection.ConnectionFactoryTests.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.integration.ip.tcp.connection.ConnectionFactoryTests.java, a JUnit test class for the Spring Integration TCP connection factories.

Source

/*
 * Copyright 2002-2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.integration.ip.tcp.connection;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.contains;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.beans.DirectFieldAccessor;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.integration.channel.NullChannel;
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.ip.config.TcpConnectionFactoryFactoryBean;
import org.springframework.integration.ip.event.IpIntegrationEvent;
import org.springframework.integration.ip.tcp.TcpReceivingChannelAdapter;
import org.springframework.integration.test.support.LogAdjustingTestSupport;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

/**
 * Tests for the TCP connection factories: the object types reported by the factory bean,
 * the tracking and closing of open connection ids (together with the events published
 * along the way), and stopping a server factory right after it has started listening.
 *
 * @author Gary Russell
 * @author Artem Bilan
 * @since 3.0
 *
 */
public class ConnectionFactoryTests extends LogAdjustingTestSupport {

    /**
     * Verifies the object type reported by the factory bean for the "client" type,
     * the "server" type and the no-arg constructor.
     */
    @Test
    public void factoryBeanTests() {
        TcpConnectionFactoryFactoryBean fb = new TcpConnectionFactoryFactoryBean("client");
        assertEquals(AbstractClientConnectionFactory.class, fb.getObjectType());
        fb = new TcpConnectionFactoryFactoryBean("server");
        assertEquals(AbstractServerConnectionFactory.class, fb.getObjectType());
        fb = new TcpConnectionFactoryFactoryBean();
        assertEquals(AbstractConnectionFactory.class, fb.getObjectType());
    }

    /**
     * Runs {@link #testObtainConnectionIds(AbstractServerConnectionFactory)} against a
     * {@code TcpNetServerConnectionFactory} created with port 0.
     */
    @Test
    public void testObtainConnectionIdsNet() throws Exception {
        TcpNetServerConnectionFactory serverFactory = new TcpNetServerConnectionFactory(0);
        testObtainConnectionIds(serverFactory);
    }

    /**
     * Runs {@link #testObtainConnectionIds(AbstractServerConnectionFactory)} against a
     * {@code TcpNioServerConnectionFactory} created with port 0.
     */
    @Test
    public void testObtainConnectionIdsNio() throws Exception {
        TcpNioServerConnectionFactory serverFactory = new TcpNioServerConnectionFactory(0);
        testObtainConnectionIds(serverFactory);
    }

    /**
     * Shared scenario: starts the given server factory through a receiving channel adapter,
     * connects one client, checks that both sides report exactly one open connection id,
     * closes the server-side connection and waits for the client side to drop to zero.
     * Along the way it records every published event and checks the expected minimum count,
     * then checks that a connection only publishes events that have itself as the source.
     * <p>
     * The scheduler and both factories are released in a {@code finally} block so that a
     * failing assertion does not leave threads or sockets behind for subsequent tests.
     *
     * @param serverFactory the server connection factory under test (not yet started).
     * @throws Exception if a latch wait or the polling sleep is interrupted, or on any
     * failure raised by the factories.
     */
    public void testObtainConnectionIds(AbstractServerConnectionFactory serverFactory) throws Exception {
        final List<IpIntegrationEvent> events = Collections.synchronizedList(new ArrayList<IpIntegrationEvent>());
        int expectedEvents = serverFactory instanceof TcpNetServerConnectionFactory
                ? 7 // Listening, + OPEN, CLOSE, EXCEPTION for each side
                : 5; // Listening, + OPEN, CLOSE (but we *might* get exceptions, depending on timing).
        final CountDownLatch serverListeningLatch = new CountDownLatch(1);
        final CountDownLatch eventLatch = new CountDownLatch(expectedEvents);
        ApplicationEventPublisher publisher = new ApplicationEventPublisher() {

            @Override
            public void publishEvent(ApplicationEvent event) {
                LogFactory.getLog(this.getClass()).trace("Received: " + event);
                events.add((IpIntegrationEvent) event);
                if (event instanceof TcpConnectionServerListeningEvent) {
                    serverListeningLatch.countDown();
                }
                eventLatch.countDown();
            }

            @Override
            public void publishEvent(Object event) {
                // intentionally ignored: only ApplicationEvent instances are recorded
            }

        };
        serverFactory.setBeanName("serverFactory");
        serverFactory.setApplicationEventPublisher(publisher);
        serverFactory = spy(serverFactory);
        // Signals once the server side has wrapped (i.e. finished setting up) its connection.
        final CountDownLatch serverConnectionInitLatch = new CountDownLatch(1);
        doAnswer(invocation -> {
            Object result = invocation.callRealMethod();
            serverConnectionInitLatch.countDown();
            return result;
        }).when(serverFactory).wrapConnection(any(TcpConnectionSupport.class));
        ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
        scheduler.setPoolSize(10);
        scheduler.afterPropertiesSet();
        TcpNetClientConnectionFactory clientFactory = null;
        try {
            BeanFactory bf = mock(BeanFactory.class);
            when(bf.containsBean(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME)).thenReturn(true);
            when(bf.getBean(IntegrationContextUtils.TASK_SCHEDULER_BEAN_NAME, TaskScheduler.class))
                    .thenReturn(scheduler);
            serverFactory.setBeanFactory(bf);
            TcpReceivingChannelAdapter adapter = new TcpReceivingChannelAdapter();
            adapter.setOutputChannel(new NullChannel());
            adapter.setConnectionFactory(serverFactory);
            adapter.start();
            assertTrue("Listening event not received", serverListeningLatch.await(10, TimeUnit.SECONDS));
            assertThat(events.get(0), instanceOf(TcpConnectionServerListeningEvent.class));
            assertThat(((TcpConnectionServerListeningEvent) events.get(0)).getPort(),
                    equalTo(serverFactory.getPort()));
            int port = serverFactory.getPort();
            clientFactory = new TcpNetClientConnectionFactory("localhost", port);
            clientFactory.registerListener(message -> false);
            clientFactory.setBeanName("clientFactory");
            clientFactory.setApplicationEventPublisher(publisher);
            clientFactory.start();
            TcpConnectionSupport client = clientFactory.getConnection();
            List<String> clients = clientFactory.getOpenConnectionIds();
            assertEquals(1, clients.size());
            assertTrue(clients.contains(client.getConnectionId()));
            assertTrue("Server connection failed to register",
                    serverConnectionInitLatch.await(1, TimeUnit.SECONDS));
            List<String> servers = serverFactory.getOpenConnectionIds();
            assertEquals(1, servers.size());
            assertTrue(serverFactory.closeConnection(servers.get(0)));
            servers = serverFactory.getOpenConnectionIds();
            assertEquals(0, servers.size());
            // The client notices the close asynchronously; poll for up to ~10 seconds.
            int n = 0;
            clients = clientFactory.getOpenConnectionIds();
            while (n++ < 100 && clients.size() > 0) {
                Thread.sleep(100);
                clients = clientFactory.getOpenConnectionIds();
            }
            assertEquals(0, clients.size());
            assertTrue(eventLatch.await(10, TimeUnit.SECONDS));
            assertThat("Expected at least " + expectedEvents + " events; got: " + events.size() + " : " + events,
                    events.size(), greaterThanOrEqualTo(expectedEvents));

            // A connection may publish an event that has itself as the source...
            FooEvent event = new FooEvent(client, "foo");
            client.publishEvent(event);
            assertThat("Expected at least " + (expectedEvents + 1) + " events; got: " + events.size()
                            + " : " + events,
                    events.size(), greaterThanOrEqualTo(expectedEvents + 1));

            // ...but not one whose source is a different connection.
            try {
                event = new FooEvent(mock(TcpConnectionSupport.class), "foo");
                client.publishEvent(event);
                fail("Expected exception");
            } catch (IllegalArgumentException e) {
                assertEquals("Can only publish events with this as the source", e.getMessage());
            }

            SocketAddress address = serverFactory.getServerSocketAddress();
            if (address instanceof InetSocketAddress) {
                InetSocketAddress inetAddress = (InetSocketAddress) address;
                assertEquals(port, inetAddress.getPort());
            }
        } finally {
            if (clientFactory != null) {
                clientFactory.stop();
            }
            serverFactory.stop();
            scheduler.shutdown();
        }
    }

    /**
     * Runs the early-close scenario against a {@code TcpNetServerConnectionFactory},
     * watching its {@code serverSocket} field.
     */
    @Test
    public void testEarlyCloseNet() throws Exception {
        AbstractServerConnectionFactory factory = new TcpNetServerConnectionFactory(0);
        testEarlyClose(factory, "serverSocket", " stopped before accept");
    }

    /**
     * Runs the early-close scenario against a {@code TcpNioServerConnectionFactory},
     * watching its {@code serverChannel} field.
     */
    @Test
    public void testEarlyCloseNio() throws Exception {
        AbstractServerConnectionFactory factory = new TcpNioServerConnectionFactory(0);
        testEarlyClose(factory, "serverChannel", " stopped before registering the server channel");
    }

    /**
     * Shared early-close scenario: the thread that logs the "Listening" info message is held
     * inside the (spied) logger until {@code stop()} - invoked on another thread - has nulled
     * the given field; the test then verifies that the expected debug message was logged.
     *
     * @param factory the server connection factory under test.
     * @param property name of the factory field that is polled until it becomes {@code null}.
     * @param message suffix of the debug message expected after the early stop.
     * @throws Exception if a latch wait or the polling sleep is interrupted, or on any
     * failure raised by the factory.
     */
    private void testEarlyClose(final AbstractServerConnectionFactory factory, String property, String message)
            throws Exception {
        factory.setApplicationEventPublisher(mock(ApplicationEventPublisher.class));
        factory.setBeanName("foo");
        factory.registerListener(mock(TcpListener.class));
        factory.afterPropertiesSet();
        Log logger = spy(TestUtils.getPropertyValue(factory, "logger", Log.class));
        new DirectFieldAccessor(factory).setPropertyValue("logger", logger);
        final CountDownLatch latch1 = new CountDownLatch(1);
        final CountDownLatch latch2 = new CountDownLatch(1);
        final CountDownLatch latch3 = new CountDownLatch(1);
        when(logger.isInfoEnabled()).thenReturn(true);
        when(logger.isDebugEnabled()).thenReturn(true);
        doAnswer(invocation -> {
            latch1.countDown();
            // wait until the stop nulls the channel
            latch2.await(10, TimeUnit.SECONDS);
            return null;
        }).when(logger).info(contains("Listening"));
        doAnswer(invocation -> {
            latch3.countDown();
            return null;
        }).when(logger).debug(contains(message));
        ExecutorService stopExecutor = Executors.newSingleThreadExecutor();
        factory.start();
        try {
            assertTrue("missing info log", latch1.await(10, TimeUnit.SECONDS));
            // stop on a different thread because it waits for the executor
            stopExecutor.execute(() -> factory.stop());
            DirectFieldAccessor accessor = new DirectFieldAccessor(factory);
            // Poll for up to ~20 seconds; the verdict is taken from the field itself rather than
            // from the loop counter, so a stop observed on the very last check still counts.
            int n = 0;
            while (accessor.getPropertyValue(property) != null && n++ < 200) {
                Thread.sleep(100);
            }
            assertTrue("Stop was not invoked in time", accessor.getPropertyValue(property) == null);
            latch2.countDown();
            assertTrue("missing debug log", latch3.await(10, TimeUnit.SECONDS));
            String expected = "foo, port=" + factory.getPort() + message;
            ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
            verify(logger, atLeast(1)).debug(captor.capture());
            assertThat(captor.getAllValues(), hasItem(expected));
        } finally {
            // Release the thread held in the logger first, otherwise stop() would wait on it.
            latch2.countDown();
            factory.stop();
            stopExecutor.shutdown();
        }
    }

    /**
     * Minimal event subclass used to publish a custom event through a connection.
     * Declared static because it needs no reference to the enclosing test instance.
     */
    @SuppressWarnings("serial")
    private static class FooEvent extends TcpConnectionOpenEvent {

        FooEvent(TcpConnectionSupport connection, String connectionFactoryName) {
            super(connection, connectionFactoryName);
        }

    }

}