// Java tutorial
/* * Copyright 2014 Netflix, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.reactivex.netty.protocol.http.websocket; import io.netty.buffer.ByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.logging.LogLevel; import io.reactivex.netty.RxNetty; import io.reactivex.netty.channel.ConnectionHandler; import io.reactivex.netty.channel.ObservableConnection; import io.reactivex.netty.server.RxServer; import org.junit.Test; import rx.Observable; import rx.functions.Action1; import rx.functions.Func1; import java.nio.charset.Charset; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * @author Tomasz Bak */ public class 
WebSocketClientServerTest { @Test public void testTextCommunication() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor() .withClientFrames(new TextWebSocketFrame("clientRequest")).withExpectedOnServer(1) .withServerFrames(new TextWebSocketFrame("serverResponse")).withExpectedOnClient(1).execute(); assertEquals("Expected original client request", "clientRequest", asText(executor.getReceivedClientFrames().get(0))); assertEquals("Expected original server response", "serverResponse", asText(executor.getReceivedServerFrames().get(0))); } @Test public void testBinaryCommunication() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor() .withClientFrames(new BinaryWebSocketFrame(toByteBuf("clientRequest"))).withExpectedOnServer(1) .withServerFrames(new BinaryWebSocketFrame(toByteBuf("serverResponse"))).withExpectedOnClient(1) .execute(); assertEquals("Expected original client request", "clientRequest", asText(executor.getReceivedClientFrames().get(0))); assertEquals("Expected original server response", "serverResponse", asText(executor.getReceivedServerFrames().get(0))); } @Test public void testFragmentedMessage() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor() .withClientFrames(new TextWebSocketFrame(false, 0, "first"), new ContinuationWebSocketFrame(false, 0, "middle"), new ContinuationWebSocketFrame(true, 0, "last")) .withExpectedOnServer(3).execute(); assertEquals("Expected first frame content", "first", asText(executor.getReceivedClientFrames().get(0))); assertEquals("Expected first frame content", "middle", asText(executor.getReceivedClientFrames().get(1))); assertEquals("Expected first frame content", "last", asText(executor.getReceivedClientFrames().get(2))); } @Test public void testMessageAggregationOnServer() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor().withMessageAggregation(true) .withClientFrames(new TextWebSocketFrame(false, 0, "0123456789"), new 
ContinuationWebSocketFrame(true, 0, "ABCDEFGHIJ")) .withExpectedOnServer(1).execute(); assertEquals("Expected aggregated message", "0123456789ABCDEFGHIJ", asText(executor.getReceivedClientFrames().get(0))); } @Test public void testMessageAggregationOnClient() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor().withMessageAggregation(true) .withServerFrames(new TextWebSocketFrame(false, 0, "0123456789"), new ContinuationWebSocketFrame(true, 0, "ABCDEFGHIJ")) .withExpectedOnClient(1).execute(); assertEquals("Expected aggregated message", "0123456789ABCDEFGHIJ", asText(executor.getReceivedServerFrames().get(0))); } @Test public void testPingPong() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor().withClientFrames(new PingWebSocketFrame()) .withExpectedOnServer(1).withServerFrames(new PongWebSocketFrame()).withExpectedOnClient(1) .execute(); assertTrue("Expected ping on server", executor.getReceivedClientFrames().get(0) instanceof PingWebSocketFrame); assertTrue("Expected pong on client", executor.getReceivedServerFrames().get(0) instanceof PongWebSocketFrame); } @Test public void testConnectionClose() throws Exception { TestSequenceExecutor executor = new TestSequenceExecutor() .withClientFrames(new CloseWebSocketFrame(1000, "close")).withExpectedOnServer(1) .withServerFrames(new CloseWebSocketFrame(1001, "close requested")).withExpectedOnClient(1) .execute(); assertTrue("Expected close on server", executor.getReceivedClientFrames().get(0) instanceof CloseWebSocketFrame); assertTrue("Expected close on server", executor.getReceivedServerFrames().get(0) instanceof CloseWebSocketFrame); } private static ByteBuf toByteBuf(String text) { byte[] bytes = text.getBytes(Charset.defaultCharset()); ByteBuf byteBuf = UnpooledByteBufAllocator.DEFAULT.buffer(bytes.length); return byteBuf.writeBytes(bytes); } private static String asText(WebSocketFrame frame) { return frame.content().toString(Charset.defaultCharset()); } private 
static class TestSequenceExecutor { private WebSocketFrame[] clientFrames; private int expectedOnServer; private WebSocketFrame[] serverFrames; private int expectedOnClient; private final List<WebSocketFrame> receivedClientFrames = new CopyOnWriteArrayList<WebSocketFrame>(); private final List<WebSocketFrame> receivedServerFrames = new CopyOnWriteArrayList<WebSocketFrame>(); private boolean messageAggregation; public List<WebSocketFrame> getReceivedClientFrames() { return receivedClientFrames; } public List<WebSocketFrame> getReceivedServerFrames() { return receivedServerFrames; } public TestSequenceExecutor withMessageAggregation(boolean messageAggregation) { this.messageAggregation = messageAggregation; return this; } public TestSequenceExecutor withClientFrames(WebSocketFrame... clientFrames) { this.clientFrames = clientFrames; return this; } public TestSequenceExecutor withExpectedOnServer(int expectedOnServer) { this.expectedOnServer = expectedOnServer; return this; } public TestSequenceExecutor withServerFrames(WebSocketFrame... 
serverFrames) { this.serverFrames = serverFrames; return this; } public TestSequenceExecutor withExpectedOnClient(int expectedOnClient) { this.expectedOnClient = expectedOnClient; return this; } public TestSequenceExecutor execute() throws InterruptedException, TimeoutException, ExecutionException { final CountDownLatch serverLatch = new CountDownLatch(expectedOnServer); RxServer<WebSocketFrame, WebSocketFrame> server = RxNetty .newWebSocketServerBuilder(0, new ConnectionHandler<WebSocketFrame, WebSocketFrame>() { @Override public Observable<Void> handle( final ObservableConnection<WebSocketFrame, WebSocketFrame> connection) { if (clientFrames == null) { return sendBatchOfFrames(connection, serverFrames); } return connection.getInput().flatMap(new Func1<WebSocketFrame, Observable<Void>>() { @Override public Observable<Void> call(WebSocketFrame frame) { frame.retain(); receivedClientFrames.add(frame); serverLatch.countDown(); if (serverLatch.getCount() == 0) { return sendBatchOfFrames(connection, serverFrames); } return Observable.empty(); } }); } }).withMessageAggregator(messageAggregation).enableWireLogging(LogLevel.ERROR).build().start(); final CountDownLatch clientLatch = new CountDownLatch(expectedOnClient); RxNetty.newWebSocketClientBuilder("localhost", server.getServerPort()) .withWebSocketVersion(WebSocketVersion.V13).withMessageAggregation(messageAggregation) .enableWireLogging(LogLevel.ERROR).build().connect().flatMap( new Func1<ObservableConnection<WebSocketFrame, WebSocketFrame>, Observable<WebSocketFrame>>() { @Override public Observable<WebSocketFrame> call( final ObservableConnection<WebSocketFrame, WebSocketFrame> connection) { sendBatchOfFrames(connection, clientFrames); return connection.getInput().doOnNext(new Action1<WebSocketFrame>() { @Override public void call(WebSocketFrame webSocketFrame) { webSocketFrame.retain(); } }); } }) .subscribe(new Action1<WebSocketFrame>() { @Override public void call(WebSocketFrame webSocketFrame) { 
receivedServerFrames.add(webSocketFrame); clientLatch.countDown(); } }); assertTrue("Timeout on server", serverLatch.await(30, TimeUnit.SECONDS)); assertTrue("Timeout on client", clientLatch.await(30, TimeUnit.SECONDS)); server.shutdown(); assertEquals("Invalid number of server frames received", expectedOnClient, receivedServerFrames.size()); assertEquals("Invalid number of client frames received", expectedOnServer, receivedClientFrames.size()); return this; } private static Observable<Void> sendBatchOfFrames( ObservableConnection<WebSocketFrame, WebSocketFrame> connection, WebSocketFrame[] frames) { if (frames != null) { for (WebSocketFrame frame : frames) { connection.write(frame); } return connection.flush(); } return Observable.empty(); } } }