io.reactivex.netty.protocol.http.websocket.WebSocketClientServerTest.java Source code

Java tutorial

Introduction

Here is the source code for io.reactivex.netty.protocol.http.websocket.WebSocketClientServerTest.java

Source

/*
 * Copyright 2014 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.reactivex.netty.protocol.http.websocket;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.logging.LogLevel;
import io.reactivex.netty.RxNetty;
import io.reactivex.netty.channel.ConnectionHandler;
import io.reactivex.netty.channel.ObservableConnection;
import io.reactivex.netty.server.RxServer;
import org.junit.Test;
import rx.Observable;
import rx.functions.Action1;
import rx.functions.Func1;

import java.nio.charset.Charset;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Round-trip tests that start a WebSocket server, connect a WebSocket client to it and verify the
 * frames each side receives from the other.
 *
 * <p>Every test describes its scenario with a {@link TestSequenceExecutor}: the frames each side
 * sends and the number of frames each side is expected to receive.
 *
 * @author Tomasz Bak
 */
public class WebSocketClientServerTest {

    /**
     * Charset used for every text/bytes conversion in this test. It is fixed (rather than the
     * platform default) so the assertions do not depend on the JVM's {@code file.encoding};
     * Netty's text frames carry UTF-8 encoded text.
     */
    private static final Charset UTF_8 = Charset.forName("UTF-8");

    /** A single text frame is delivered unchanged in each direction. */
    @Test
    public void testTextCommunication() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor()
                .withClientFrames(new TextWebSocketFrame("clientRequest")).withExpectedOnServer(1)
                .withServerFrames(new TextWebSocketFrame("serverResponse")).withExpectedOnClient(1).execute();

        assertEquals("Expected original client request", "clientRequest",
                asText(executor.getReceivedClientFrames().get(0)));
        assertEquals("Expected original server response", "serverResponse",
                asText(executor.getReceivedServerFrames().get(0)));
    }

    /** A single binary frame is delivered unchanged in each direction. */
    @Test
    public void testBinaryCommunication() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor()
                .withClientFrames(new BinaryWebSocketFrame(toByteBuf("clientRequest"))).withExpectedOnServer(1)
                .withServerFrames(new BinaryWebSocketFrame(toByteBuf("serverResponse"))).withExpectedOnClient(1)
                .execute();

        assertEquals("Expected original client request", "clientRequest",
                asText(executor.getReceivedClientFrames().get(0)));
        assertEquals("Expected original server response", "serverResponse",
                asText(executor.getReceivedServerFrames().get(0)));
    }

    /** Without aggregation, the three fragments of a message reach the server as three frames. */
    @Test
    public void testFragmentedMessage() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor()
                .withClientFrames(new TextWebSocketFrame(false, 0, "first"),
                        new ContinuationWebSocketFrame(false, 0, "middle"),
                        new ContinuationWebSocketFrame(true, 0, "last"))
                .withExpectedOnServer(3).execute();

        assertEquals("Expected first frame content", "first", asText(executor.getReceivedClientFrames().get(0)));
        assertEquals("Expected middle frame content", "middle",
                asText(executor.getReceivedClientFrames().get(1)));
        assertEquals("Expected last frame content", "last", asText(executor.getReceivedClientFrames().get(2)));
    }

    /** With aggregation enabled, two fragments sent by the client reach the server as one frame. */
    @Test
    public void testMessageAggregationOnServer() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor().withMessageAggregation(true)
                .withClientFrames(new TextWebSocketFrame(false, 0, "0123456789"),
                        new ContinuationWebSocketFrame(true, 0, "ABCDEFGHIJ"))
                .withExpectedOnServer(1).execute();
        assertEquals("Expected aggregated message", "0123456789ABCDEFGHIJ",
                asText(executor.getReceivedClientFrames().get(0)));
    }

    /** With aggregation enabled, two fragments sent by the server reach the client as one frame. */
    @Test
    public void testMessageAggregationOnClient() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor().withMessageAggregation(true)
                .withServerFrames(new TextWebSocketFrame(false, 0, "0123456789"),
                        new ContinuationWebSocketFrame(true, 0, "ABCDEFGHIJ"))
                .withExpectedOnClient(1).execute();
        assertEquals("Expected aggregated message", "0123456789ABCDEFGHIJ",
                asText(executor.getReceivedServerFrames().get(0)));
    }

    /** A ping sent by the client and a pong sent by the server arrive with their frame types intact. */
    @Test
    public void testPingPong() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor().withClientFrames(new PingWebSocketFrame())
                .withExpectedOnServer(1).withServerFrames(new PongWebSocketFrame()).withExpectedOnClient(1)
                .execute();

        assertTrue("Expected ping on server",
                executor.getReceivedClientFrames().get(0) instanceof PingWebSocketFrame);
        assertTrue("Expected pong on client",
                executor.getReceivedServerFrames().get(0) instanceof PongWebSocketFrame);
    }

    /** Close frames sent by either side are delivered to the other side as close frames. */
    @Test
    public void testConnectionClose() throws Exception {
        TestSequenceExecutor executor = new TestSequenceExecutor()
                .withClientFrames(new CloseWebSocketFrame(1000, "close")).withExpectedOnServer(1)
                .withServerFrames(new CloseWebSocketFrame(1001, "close requested")).withExpectedOnClient(1)
                .execute();

        assertTrue("Expected close on server",
                executor.getReceivedClientFrames().get(0) instanceof CloseWebSocketFrame);
        assertTrue("Expected close on client",
                executor.getReceivedServerFrames().get(0) instanceof CloseWebSocketFrame);
    }

    /**
     * Copies the UTF-8 bytes of {@code text} into a newly allocated unpooled buffer.
     *
     * @param text text to encode
     * @return buffer containing the encoded text
     */
    private static ByteBuf toByteBuf(String text) {
        byte[] bytes = text.getBytes(UTF_8);
        ByteBuf byteBuf = UnpooledByteBufAllocator.DEFAULT.buffer(bytes.length);
        return byteBuf.writeBytes(bytes);
    }

    /**
     * Decodes the payload of {@code frame} as UTF-8 text.
     *
     * @param frame frame whose content is decoded
     * @return the frame content as a string
     */
    private static String asText(WebSocketFrame frame) {
        return frame.content().toString(UTF_8);
    }

    /**
     * Builder-style description of one client/server exchange, run by {@link #execute()}.
     *
     * <p>Note the naming of the result lists: "received client frames" are frames sent by the client
     * and collected on the server, and "received server frames" are frames sent by the server and
     * collected on the client.
     */
    private static class TestSequenceExecutor {
        private WebSocketFrame[] clientFrames;
        private int expectedOnServer;
        private WebSocketFrame[] serverFrames;
        private int expectedOnClient;

        // Filled from the server/client callbacks and read from the test thread afterwards.
        private final List<WebSocketFrame> receivedClientFrames = new CopyOnWriteArrayList<WebSocketFrame>();
        private final List<WebSocketFrame> receivedServerFrames = new CopyOnWriteArrayList<WebSocketFrame>();
        private boolean messageAggregation;

        /** Returns the frames the server received from the client, in arrival order. */
        public List<WebSocketFrame> getReceivedClientFrames() {
            return receivedClientFrames;
        }

        /** Returns the frames the client received from the server, in arrival order. */
        public List<WebSocketFrame> getReceivedServerFrames() {
            return receivedServerFrames;
        }

        /** Sets the message aggregation flag passed to both the server and the client builder. */
        public TestSequenceExecutor withMessageAggregation(boolean messageAggregation) {
            this.messageAggregation = messageAggregation;
            return this;
        }

        /**
         * Sets the frames the client writes right after connecting. When left unset, the server
         * does not wait for input and sends its own frames as soon as a connection is accepted.
         */
        public TestSequenceExecutor withClientFrames(WebSocketFrame... clientFrames) {
            this.clientFrames = clientFrames;
            return this;
        }

        /** Sets how many frames the server must receive before it replies and the test proceeds. */
        public TestSequenceExecutor withExpectedOnServer(int expectedOnServer) {
            this.expectedOnServer = expectedOnServer;
            return this;
        }

        /** Sets the frames the server writes to the client; may be left unset to send nothing. */
        public TestSequenceExecutor withServerFrames(WebSocketFrame... serverFrames) {
            this.serverFrames = serverFrames;
            return this;
        }

        /** Sets how many frames the client must receive before the test proceeds. */
        public TestSequenceExecutor withExpectedOnClient(int expectedOnClient) {
            this.expectedOnClient = expectedOnClient;
            return this;
        }

        /**
         * Starts a server on port 0, connects a client to the port the server reports, runs the
         * configured exchange and waits (up to 30 seconds per side) for the expected number of
         * frames. The server is shut down before returning, whether the exchange succeeded or not.
         *
         * @return this executor, so the received frames can be inspected
         * @throws AssertionError if a side times out or receives an unexpected number of frames
         * @throws InterruptedException if interrupted while waiting or while shutting down
         */
        public TestSequenceExecutor execute() throws InterruptedException, TimeoutException, ExecutionException {
            final CountDownLatch serverLatch = new CountDownLatch(expectedOnServer);
            RxServer<WebSocketFrame, WebSocketFrame> server = RxNetty
                    .newWebSocketServerBuilder(0, new ConnectionHandler<WebSocketFrame, WebSocketFrame>() {
                        @Override
                        public Observable<Void> handle(
                                final ObservableConnection<WebSocketFrame, WebSocketFrame> connection) {
                            // Server-push scenario: nothing to wait for, send right away.
                            if (clientFrames == null) {
                                return sendBatchOfFrames(connection, serverFrames);
                            }
                            return connection.getInput().flatMap(new Func1<WebSocketFrame, Observable<Void>>() {
                                @Override
                                public Observable<Void> call(WebSocketFrame frame) {
                                    // Retained because the frame is read by the test after this
                                    // callback returns (presumably it would be released otherwise).
                                    frame.retain();
                                    receivedClientFrames.add(frame);
                                    serverLatch.countDown();
                                    // Reply only once all expected client frames have arrived.
                                    if (serverLatch.getCount() == 0) {
                                        return sendBatchOfFrames(connection, serverFrames);
                                    }
                                    return Observable.empty();
                                }
                            });
                        }
                    }).withMessageAggregator(messageAggregation).enableWireLogging(LogLevel.ERROR).build().start();

            // Everything after a successful start runs under try/finally so a failed assertion
            // or timeout does not leave the server running for the remainder of the test run.
            try {
                final CountDownLatch clientLatch = new CountDownLatch(expectedOnClient);
                RxNetty.newWebSocketClientBuilder("localhost", server.getServerPort())
                        .withWebSocketVersion(WebSocketVersion.V13).withMessageAggregation(messageAggregation)
                        .enableWireLogging(LogLevel.ERROR).build().connect().flatMap(
                                new Func1<ObservableConnection<WebSocketFrame, WebSocketFrame>, Observable<WebSocketFrame>>() {
                                    @Override
                                    public Observable<WebSocketFrame> call(
                                            final ObservableConnection<WebSocketFrame, WebSocketFrame> connection) {
                                        sendBatchOfFrames(connection, clientFrames);
                                        return connection.getInput().doOnNext(new Action1<WebSocketFrame>() {
                                            @Override
                                            public void call(WebSocketFrame webSocketFrame) {
                                                // Same reason as on the server side: the frame
                                                // is inspected later by the test.
                                                webSocketFrame.retain();
                                            }
                                        });
                                    }
                                })
                        .subscribe(new Action1<WebSocketFrame>() {
                            @Override
                            public void call(WebSocketFrame webSocketFrame) {
                                receivedServerFrames.add(webSocketFrame);
                                clientLatch.countDown();
                            }
                        });
                assertTrue("Timeout on server", serverLatch.await(30, TimeUnit.SECONDS));
                assertTrue("Timeout on client", clientLatch.await(30, TimeUnit.SECONDS));
            } finally {
                server.shutdown();
            }

            assertEquals("Invalid number of server frames received", expectedOnClient, receivedServerFrames.size());
            assertEquals("Invalid number of client frames received", expectedOnServer, receivedClientFrames.size());

            return this;
        }

        /**
         * Writes every frame in {@code frames} to {@code connection} and flushes once at the end.
         *
         * @param connection connection to write to
         * @param frames frames to send; {@code null} means there is nothing to send
         * @return the result of the flush, or an empty observable when {@code frames} is {@code null}
         */
        private static Observable<Void> sendBatchOfFrames(
                ObservableConnection<WebSocketFrame, WebSocketFrame> connection, WebSocketFrame[] frames) {
            if (frames != null) {
                for (WebSocketFrame frame : frames) {
                    connection.write(frame);
                }
                return connection.flush();
            }
            return Observable.empty();
        }
    }
}