org.springframework.integration.aggregator.AggregatorTests.java Source code

Java tutorial

Introduction

Here is the source code for the test class org.springframework.integration.aggregator.AggregatorTests (AggregatorTests.java).

Source

/*
 * Copyright 2002-2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.integration.aggregator;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.ReentrantLock;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;

import org.springframework.beans.factory.BeanFactory;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.handler.AbstractMessageHandler;
import org.springframework.integration.store.MessageGroup;
import org.springframework.integration.store.SimpleMessageGroupFactory;
import org.springframework.integration.store.SimpleMessageStore;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandlingException;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.util.StopWatch;

/**
 * Tests for {@link AggregatingMessageHandler}.
 *
 * <p>Covers group completion, partial results and discards when groups expire, several correlation
 * groups in flight at once, and handling of duplicate or late messages. It also includes a few
 * throughput checks that compare the aggregator against a minimal hand-written handler.
 *
 * @author Mark Fisher
 * @author Marius Bogoevici
 * @author Iwein Fuld
 * @author Gary Russell
 * @author Artem Bilan
 */
public class AggregatorTests {

    private static final Log logger = LogFactory.getLog(AggregatorTests.class);

    private AggregatingMessageHandler aggregator;

    // Store backing the shared aggregator; tests force expiry via store.expireMessageGroups(...).
    private final SimpleMessageStore store = new SimpleMessageStore(50);

    // Events published by the shared aggregator; cleared in configureAggregator().
    private final List<MessageGroupExpiredEvent> expiryEvents = new ArrayList<>();

    /**
     * Creates the shared aggregator used by most tests.
     *
     * <p>The aggregator multiplies Integer payloads, is backed by {@link #store}, and records every
     * published application event as a {@link MessageGroupExpiredEvent} in {@link #expiryEvents}.
     */
    @Before
    public void configureAggregator() {
        this.aggregator = new AggregatingMessageHandler(new MultiplyingProcessor(), store);
        this.aggregator.setBeanFactory(mock(BeanFactory.class));
        this.aggregator.setApplicationEventPublisher(event -> expiryEvents.add((MessageGroupExpiredEvent) event));
        this.aggregator.setBeanName("testAggregator");
        this.aggregator.afterPropertiesSet();
        expiryEvents.clear();
    }

    @Test
    public void testAggPerf() throws InterruptedException, ExecutionException, TimeoutException {
        AggregatingMessageHandler handler = new AggregatingMessageHandler(
                new DefaultAggregatingMessageGroupProcessor());
        handler.setCorrelationStrategy(message -> "foo");
        handler.setReleaseStrategy(new MessageCountReleaseStrategy(60000));
        handler.setExpireGroupsUponCompletion(true);
        handler.setSendPartialResultOnExpiry(true);
        DirectChannel outputChannel = new DirectChannel();
        handler.setOutputChannel(outputChannel);

        CompletableFuture<Collection<?>> resultFuture = collectFirstResult(outputChannel);

        handler.setMessageStore(blockingQueueMessageStore());

        Message<?> message = new GenericMessage<String>("foo");
        StopWatch stopwatch = new StopWatch();
        stopwatch.start();
        for (int i = 0; i < 120000; i++) {
            if (i % 10000 == 0) {
                stopwatch.stop();
                logSent(i, stopwatch);
                stopwatch.start();
            }
            handler.handleMessage(message);
        }
        stopwatch.stop();
        logSent(120000, stopwatch);

        Collection<?> result = resultFuture.get(10, TimeUnit.SECONDS);
        assertNotNull(result);
        assertEquals(60000, result.size());
    }

    @Test
    public void testAggPerfDefaultPartial() throws InterruptedException, ExecutionException, TimeoutException {
        AggregatingMessageHandler handler = new AggregatingMessageHandler(
                new DefaultAggregatingMessageGroupProcessor());
        handler.setCorrelationStrategy(message -> "foo");
        handler.setReleasePartialSequences(true);
        DirectChannel outputChannel = new DirectChannel();
        handler.setOutputChannel(outputChannel);

        CompletableFuture<Collection<?>> resultFuture = collectFirstResult(outputChannel);

        handler.setMessageStore(blockingQueueMessageStore());

        StopWatch stopwatch = new StopWatch();
        stopwatch.start();
        for (int i = 0; i < 120000; i++) {
            if (i % 10000 == 0) {
                stopwatch.stop();
                logSent(i, stopwatch);
                stopwatch.start();
            }
            handler.handleMessage(
                    MessageBuilder.withPayload("foo").setSequenceSize(120000).setSequenceNumber(i + 1).build());
        }
        stopwatch.stop();
        logSent(120000, stopwatch);

        Collection<?> result = resultFuture.get(10, TimeUnit.SECONDS);
        assertNotNull(result);
        assertEquals(120000, result.size());
        assertThat(stopwatch.getTotalTimeSeconds(), lessThan(60.0)); // actually < 2.0, was many minutes
    }

    @Test
    public void testCustomAggPerf() throws InterruptedException, ExecutionException, TimeoutException {
        class CustomHandler extends AbstractMessageHandler {

            // custom aggregator, only handles a single correlation

            private final ReentrantLock lock = new ReentrantLock();

            private final Collection<Message<?>> messages = new ArrayList<>(60000);

            private final MessageChannel outputChannel;

            private CustomHandler(MessageChannel outputChannel) {
                this.outputChannel = outputChannel;
            }

            @Override
            public void handleMessageInternal(Message<?> requestMessage) {
                lock.lock();
                try {
                    this.messages.add(requestMessage);
                    if (this.messages.size() == 60000) {
                        List<Object> payloads = new ArrayList<>(this.messages.size());
                        for (Message<?> message : this.messages) {
                            payloads.add(message.getPayload());
                        }
                        this.messages.clear();
                        outputChannel.send(getMessageBuilderFactory().withPayload(payloads)
                                .copyHeaders(requestMessage.getHeaders()).build());
                    }
                } finally {
                    lock.unlock();
                }
            }

        }

        DirectChannel outputChannel = new DirectChannel();
        CustomHandler handler = new CustomHandler(outputChannel);

        CompletableFuture<Collection<?>> resultFuture = collectFirstResult(outputChannel);
        Message<?> message = new GenericMessage<String>("foo");
        StopWatch stopwatch = new StopWatch();
        stopwatch.start();
        for (int i = 0; i < 120000; i++) {
            if (i % 10000 == 0) {
                stopwatch.stop();
                logSent(i, stopwatch);
                stopwatch.start();
            }
            handler.handleMessage(message);
        }
        stopwatch.stop();
        logSent(120000, stopwatch);

        Collection<?> result = resultFuture.get(10, TimeUnit.SECONDS);
        assertNotNull(result);
        assertEquals(60000, result.size());
    }

    @Test
    public void testCompleteGroupWithinTimeout() throws InterruptedException {
        QueueChannel replyChannel = new QueueChannel();
        Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
        Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
        Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null);

        this.aggregator.handleMessage(message1);
        this.aggregator.handleMessage(message2);
        this.aggregator.handleMessage(message3);

        Message<?> reply = replyChannel.receive(10000);
        assertNotNull(reply);
        // expected value first so a failure reports "expected 105 but was ..."
        assertEquals(105, reply.getPayload());
    }

    @Test
    public void testShouldNotSendPartialResultOnTimeoutByDefault() throws InterruptedException {
        QueueChannel discardChannel = new QueueChannel();
        this.aggregator.setDiscardChannel(discardChannel);
        QueueChannel replyChannel = new QueueChannel();
        Message<?> message = createMessage(3, "ABC", 2, 1, replyChannel, null);
        this.aggregator.handleMessage(message);
        this.store.expireMessageGroups(-10000);
        Message<?> reply = replyChannel.receive(1000);
        assertNull("No message should have been sent normally", reply);
        Message<?> discardedMessage = discardChannel.receive(1000);
        assertNotNull("A message should have been discarded", discardedMessage);
        assertEquals(message, discardedMessage);
        assertEquals(1, expiryEvents.size());
        assertSame(this.aggregator, expiryEvents.get(0).getSource());
        assertEquals("ABC", this.expiryEvents.get(0).getGroupId());
        assertEquals(1, this.expiryEvents.get(0).getMessageCount());
        assertEquals(true, this.expiryEvents.get(0).isDiscarded());
    }

    @Test
    public void testShouldSendPartialResultOnTimeoutTrue() throws InterruptedException {
        this.aggregator.setSendPartialResultOnExpiry(true);
        QueueChannel replyChannel = new QueueChannel();
        Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
        Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
        this.aggregator.handleMessage(message1);
        this.aggregator.handleMessage(message2);
        this.store.expireMessageGroups(-10000);
        Message<?> reply = replyChannel.receive(1000);
        assertNotNull("A reply message should have been received", reply);
        assertEquals(15, reply.getPayload());
        assertEquals(1, expiryEvents.size());
        assertSame(this.aggregator, expiryEvents.get(0).getSource());
        assertEquals("ABC", this.expiryEvents.get(0).getGroupId());
        assertEquals(2, this.expiryEvents.get(0).getMessageCount());
        assertEquals(false, this.expiryEvents.get(0).isDiscarded());
        Message<?> message3 = createMessage(5, "ABC", 3, 3, replyChannel, null);
        this.aggregator.handleMessage(message3);
        assertEquals(1, this.store.getMessageGroup("ABC").size());
    }

    @Test
    public void testGroupRemainsAfterTimeout() throws InterruptedException {
        this.aggregator.setSendPartialResultOnExpiry(true);
        this.aggregator.setExpireGroupsUponTimeout(false);
        QueueChannel replyChannel = new QueueChannel();
        QueueChannel discardChannel = new QueueChannel();
        this.aggregator.setDiscardChannel(discardChannel);
        Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
        Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
        this.aggregator.handleMessage(message1);
        this.aggregator.handleMessage(message2);
        this.store.expireMessageGroups(-10000);
        Message<?> reply = replyChannel.receive(1000);
        assertNotNull("A reply message should have been received", reply);
        assertEquals(15, reply.getPayload());
        assertEquals(1, expiryEvents.size());
        assertSame(this.aggregator, expiryEvents.get(0).getSource());
        assertEquals("ABC", this.expiryEvents.get(0).getGroupId());
        assertEquals(2, this.expiryEvents.get(0).getMessageCount());
        assertEquals(false, this.expiryEvents.get(0).isDiscarded());
        assertEquals(0, this.store.getMessageGroup("ABC").size());
        Message<?> message3 = createMessage(5, "ABC", 3, 3, replyChannel, null);
        this.aggregator.handleMessage(message3);
        assertEquals(0, this.store.getMessageGroup("ABC").size());
        Message<?> discardedMessage = discardChannel.receive(1000);
        assertNotNull("A message should have been discarded", discardedMessage);
        assertSame(message3, discardedMessage);
    }

    @Test
    public void testMultipleGroupsSimultaneously() throws InterruptedException {
        QueueChannel replyChannel1 = new QueueChannel();
        QueueChannel replyChannel2 = new QueueChannel();
        Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel1, null);
        Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel1, null);
        Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel1, null);
        Message<?> message4 = createMessage(11, "XYZ", 3, 1, replyChannel2, null);
        Message<?> message5 = createMessage(13, "XYZ", 3, 2, replyChannel2, null);
        Message<?> message6 = createMessage(17, "XYZ", 3, 3, replyChannel2, null);
        aggregator.handleMessage(message1);
        aggregator.handleMessage(message5);
        aggregator.handleMessage(message3);
        aggregator.handleMessage(message6);
        aggregator.handleMessage(message4);
        aggregator.handleMessage(message2);
        @SuppressWarnings("unchecked")
        Message<Integer> reply1 = (Message<Integer>) replyChannel1.receive(1000);
        assertNotNull(reply1);
        assertThat(reply1.getPayload(), is(105));
        @SuppressWarnings("unchecked")
        Message<Integer> reply2 = (Message<Integer>) replyChannel2.receive(1000);
        assertNotNull(reply2);
        assertThat(reply2.getPayload(), is(2431));
    }

    @Test
    @Ignore
    // dropped backwards compatibility for setting capacity limit (it's always Integer.MAX_VALUE)
    public void testTrackedCorrelationIdsCapacityAtLimit() {
        QueueChannel replyChannel = new QueueChannel();
        QueueChannel discardChannel = new QueueChannel();

        this.aggregator.setDiscardChannel(discardChannel);
        this.aggregator.handleMessage(createMessage(1, 1, 1, 1, replyChannel, null));
        assertEquals(1, replyChannel.receive(1000).getPayload());
        this.aggregator.handleMessage(createMessage(3, 2, 1, 1, replyChannel, null));
        assertEquals(3, replyChannel.receive(1000).getPayload());
        this.aggregator.handleMessage(createMessage(4, 3, 1, 1, replyChannel, null));
        assertEquals(4, replyChannel.receive(1000).getPayload());
        // next message with same correlation ID is discarded
        this.aggregator.handleMessage(createMessage(2, 1, 1, 1, replyChannel, null));
        assertEquals(2, discardChannel.receive(1000).getPayload());
    }

    @Test
    @Ignore
    // dropped backwards compatibility for setting capacity limit (it's always Integer.MAX_VALUE)
    public void testTrackedCorrelationIdsCapacityPassesLimit() {
        QueueChannel replyChannel = new QueueChannel();
        QueueChannel discardChannel = new QueueChannel();

        this.aggregator.setDiscardChannel(discardChannel);
        this.aggregator.handleMessage(createMessage(1, 1, 1, 1, replyChannel, null));
        assertEquals(1, replyChannel.receive(1000).getPayload());
        this.aggregator.handleMessage(createMessage(2, 2, 1, 1, replyChannel, null));
        assertEquals(2, replyChannel.receive(1000).getPayload());
        this.aggregator.handleMessage(createMessage(3, 3, 1, 1, replyChannel, null));
        assertEquals(3, replyChannel.receive(1000).getPayload());
        this.aggregator.handleMessage(createMessage(4, 4, 1, 1, replyChannel, null));
        assertEquals(4, replyChannel.receive(1000).getPayload());
        this.aggregator.handleMessage(createMessage(5, 1, 1, 1, replyChannel, null));
        assertEquals(5, replyChannel.receive(1000).getPayload());
        assertNull(discardChannel.receive(0));
    }

    @Test(expected = MessageHandlingException.class)
    public void testExceptionThrownIfNoCorrelationId() throws InterruptedException {
        Message<?> message = createMessage(3, null, 2, 1, new QueueChannel(), null);
        this.aggregator.handleMessage(message);
    }

    @Test
    public void testAdditionalMessageAfterCompletion() throws InterruptedException {
        QueueChannel replyChannel = new QueueChannel();
        Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
        Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
        Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null);
        Message<?> message4 = createMessage(7, "ABC", 3, 3, replyChannel, null);

        this.aggregator.handleMessage(message1);
        this.aggregator.handleMessage(message2);
        this.aggregator.handleMessage(message3);
        this.aggregator.handleMessage(message4);

        Message<?> reply = replyChannel.receive(10000);
        assertNotNull("A message should be aggregated", reply);
        assertThat(((Integer) reply.getPayload()), is(105));
    }

    @Test
    public void shouldRejectDuplicatedSequenceNumbers() throws InterruptedException {
        QueueChannel replyChannel = new QueueChannel();
        Message<?> message1 = createMessage(3, "ABC", 3, 1, replyChannel, null);
        Message<?> message2 = createMessage(5, "ABC", 3, 2, replyChannel, null);
        Message<?> message3 = createMessage(7, "ABC", 3, 3, replyChannel, null);
        Message<?> message4 = createMessage(7, "ABC", 3, 3, replyChannel, null);
        this.aggregator.setReleaseStrategy(new SequenceSizeReleaseStrategy());

        this.aggregator.handleMessage(message1);
        this.aggregator.handleMessage(message3);
        // duplicated sequence number, either message3 or message4 should be rejected
        this.aggregator.handleMessage(message4);
        this.aggregator.handleMessage(message2);

        Message<?> reply = replyChannel.receive(10000);
        assertNotNull("A message should be aggregated", reply);
        assertThat(((Integer) reply.getPayload()), is(105));
    }

    /**
     * Subscribes to {@code outputChannel} and returns a future holding the payload of the first message received.
     * <p>The payload is cast to a {@link Collection}, and its size is logged for every message received.
     */
    private static CompletableFuture<Collection<?>> collectFirstResult(DirectChannel outputChannel) {
        CompletableFuture<Collection<?>> resultFuture = new CompletableFuture<>();
        outputChannel.subscribe(message -> {
            Collection<?> payload = (Collection<?>) message.getPayload();
            logger.warn("Received " + payload.size());
            // complete() has no effect once the future is done, so later releases are ignored
            resultFuture.complete(payload);
        });
        return resultFuture;
    }

    /**
     * Creates a fresh {@link SimpleMessageStore} whose group factory uses the
     * {@code BLOCKING_QUEUE} group type.
     */
    private static SimpleMessageStore blockingQueueMessageStore() {
        SimpleMessageStore messageStore = new SimpleMessageStore();
        messageStore.setMessageGroupFactory(
                new SimpleMessageGroupFactory(SimpleMessageGroupFactory.GroupType.BLOCKING_QUEUE));
        return messageStore;
    }

    /**
     * Logs how many messages have been sent, plus the stopwatch's total time and its most recent task time.
     */
    private static void logSent(int sent, StopWatch stopwatch) {
        logger.warn("Sent " + sent + " in " + stopwatch.getTotalTimeSeconds() + " (10k in "
                + stopwatch.getLastTaskTimeMillis() + "ms)");
    }

    /**
     * Builds a test message with the given payload and sequence/correlation headers.
     *
     * @param payload the message payload
     * @param correlationId the correlation id header (may be {@code null})
     * @param sequenceSize the sequence size header
     * @param sequenceNumber the sequence number header
     * @param replyChannel the reply channel header
     * @param predefinedId if non-null, set as the {@link MessageHeaders#ID} header
     * @return the built message
     */
    private static Message<?> createMessage(Object payload, Object correlationId, int sequenceSize,
            int sequenceNumber, MessageChannel replyChannel, String predefinedId) {
        MessageBuilder<Object> builder = MessageBuilder.withPayload(payload).setCorrelationId(correlationId)
                .setSequenceSize(sequenceSize).setSequenceNumber(sequenceNumber).setReplyChannel(replyChannel);
        if (predefinedId != null) {
            builder.setHeader(MessageHeaders.ID, predefinedId);
        }
        return builder.build();
    }

    /**
     * Group processor that returns the product of all (Integer) payloads in the group.
     * <p>Static so it does not capture the enclosing test instance.
     */
    private static final class MultiplyingProcessor implements MessageGroupProcessor {

        @Override
        public Object processMessageGroup(MessageGroup group) {
            // primitive accumulator avoids re-boxing on every multiplication
            int product = 1;
            for (Message<?> message : group.getMessages()) {
                product *= (Integer) message.getPayload();
            }
            return product;
        }

    }

}