Java tutorial
/******************************************************************************* * Copyright (c) 2012-2014 Codenvy, S.A. * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * Codenvy, S.A. - initial API and implementation *******************************************************************************/ package org.everrest.websockets.client; import org.apache.commons.codec.binary.Base64; import org.everrest.core.util.Logger; import org.everrest.websockets.message.MessageConverter; import java.io.BufferedReader; import java.io.ByteArrayOutputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.PrintWriter; import java.net.Socket; import java.net.SocketTimeoutException; import java.net.URI; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; /** * @author andrew00x */ public class WSClient { /** Max size of message payload. See http://tools.ietf.org/html/rfc6455#section-5.2 */ public static final int DEFAULT_MAX_MESSAGE_PAYLOAD_SIZE = 2 * 1024 * 1024; private static final int DEFAULT_BUFFER_SIZE = 8 * 1024; private static final Logger LOG = Logger.getLogger(WSClient.class); private static final String GLOBAL_WS_SERVER_UUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; private static final Random RANDOM = new Random(); private static final Charset UTF8_CS = Charset.forName("UTF-8"); private static final char[] CHARS = new char[36]; private static final int MASK_SIZE = 4; private static final AtomicLong sequence = new AtomicLong(1); static { int i = 0; for (int c = 48; c <= 57; c++) { CHARS[i++] = (char) c; } for (int c = 97; c <= 122; c++) { CHARS[i++] = (char) c; } } private final ExecutorService executor; private final URI target; private final int maxMessagePayloadSize; private final String secWebSocketKey; private final List<ClientMessageListener> listeners; private Socket socket; private InputStream in; private OutputStream out; private ByteBuffer inputBuffer; // Thread that reads from socket check this. private volatile boolean connected; /** * Create new websocket client. * * @param target * connection URI, e.g. <i>ws://localhost:8080/websocket</i> * @param listeners * message listeners * @throws IllegalArgumentException * if any of the following conditions are met: * <ul> * <li><code>target</code> is <code>null</code></li> * <li>protocol specified in <code>target</code> not supported</li> * <li><code>listeners</code> is <code>null</code></li> * </ul> * @see #DEFAULT_MAX_MESSAGE_PAYLOAD_SIZE */ public WSClient(URI target, ClientMessageListener... listeners) { this(target, DEFAULT_MAX_MESSAGE_PAYLOAD_SIZE, listeners); } /** * Create new websocket client. * * @param target * connection URI, e.g. <i>ws://localhost:8080/websocket</i> * @param maxMessagePayloadSize * max size of data in message. If received message contains payload greater then this value IOException thrown * when read such message * @param listeners * message listeners * @throws IllegalArgumentException * if any of the following conditions are met: * <ul> * <li><code>target</code> is <code>null</code></li> * <li>protocol specified in <code>target</code> not supported</li> * <li><code>maxMessagePayloadSize</code> is zero or negative</li> * <li><code>listeners</code> is <code>null</code></li> * </ul> * @see #DEFAULT_MAX_MESSAGE_PAYLOAD_SIZE * @see MessageConverter */ public WSClient(URI target, int maxMessagePayloadSize, ClientMessageListener... listeners) { if (target == null) { throw new IllegalArgumentException("Connection URI may not be null. "); } if (!"ws".equals(target.getScheme())) { // TODO: add 'wss' support throw new IllegalArgumentException(String.format("Unsupported scheme: %s", target.getScheme())); } if (maxMessagePayloadSize < 1) { throw new IllegalArgumentException( String.format("Invalid max message payload size: %d", maxMessagePayloadSize)); } if (listeners == null) { throw new IllegalArgumentException("listeners may not be null. "); } this.target = target; this.maxMessagePayloadSize = maxMessagePayloadSize; executor = Executors.newSingleThreadExecutor(new ThreadFactory() { @Override public Thread newThread(Runnable r) { final Thread t = new Thread(r, "everrest.WSClient" + sequence.getAndIncrement()); t.setDaemon(true); return t; } }); this.listeners = new ArrayList<>(listeners.length); Collections.addAll(this.listeners, listeners); secWebSocketKey = generateSecKey(); } public URI getUri() { return target; } public synchronized boolean isConnected() { return connected; } /** * Connect to remote server. * * @param timeout * connection timeout value in seconds * @throws IOException * if connection failed * @throws IllegalArgumentException * if <code>timeout</code> zero or negative */ public synchronized void connect(long timeout) throws IOException { if (timeout < 1) { throw new IllegalArgumentException(String.format("Invalid timeout: %d", timeout)); } if (connected) { throw new IOException("Already connected."); } try { executor.submit(new Runnable() { @Override public void run() { try { int port = target.getPort(); if (port == -1) { port = 80; } socket = new Socket(target.getHost(), port); in = socket.getInputStream(); out = socket.getOutputStream(); out.write(getHandshake()); validateResponseHeaders(); inputBuffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE); connected = true; } catch (IOException e) { throw new RuntimeException(e); } } }).get(timeout, TimeUnit.SECONDS); // Wait for connection. } catch (InterruptedException e) { // throw new IOException(e.getMessage(), e); } catch (ExecutionException e) { // It is RuntimeException for sure. RuntimeException re = (RuntimeException) e.getCause(); Throwable cause = re.getCause(); if (cause instanceof IOException) { throw (IOException) cause; } throw re; } catch (TimeoutException e) { // Connection time out reached. throw new SocketTimeoutException("Connection timeout. "); } finally { if (!connected) { executor.shutdown(); } } // Start reading from socket. executor.submit(new Runnable() { @Override public void run() { try { read(); } catch (ConnectionException e) { LOG.error(e.getMessage(), e); onClose(e.status, e.getMessage()); } catch (Exception e) { // All unexpected errors represents as protocol error, status: 1002. LOG.error(e.getMessage(), e); onClose(1002, e.getMessage()); } } }); // Notify listeners about connection open. onOpen(); } /** * Close connection to remote server. Method has no effect if connection already closed. * * @throws IOException * if i/o error occurred when try to close connection. */ public synchronized void disconnect() throws IOException { if (!connected) { // Already closed or not connected. return; } writeFrame((byte) 0x88, new byte[0]); } /** * Send text message. * * @param message * text message * @throws IOException * if any i/o errors occurred * @throws IllegalArgumentException * if message is <code>null</code> */ public synchronized void send(String message) throws IOException { if (!connected) { throw new IOException("Not connected. "); } if (message == null) { throw new IllegalArgumentException("Message may not be null. "); } // Send 'text' message without fragments. writeFrame((byte) 0x81, UTF8_CS.encode(message).array()); } /** * Send bin message. * * @param message * min message * @throws IOException * if any i/o errors occurred * @throws IllegalArgumentException * if message is <code>null</code> */ public synchronized void send(byte[] message) throws IOException { if (!connected) { throw new IOException("Not connected. "); } if (message == null) { throw new IllegalArgumentException("Message may not be null. "); } // Send 'bin' message without fragments. writeFrame((byte) 0x82, message); } /** * Send ping message * * @param message * message body * @throws IOException * if any i/o errors occurred * @throws IllegalArgumentException * if message length is greater than 125 bytes */ public synchronized void ping(byte[] message) throws IOException { if (!connected) { throw new IOException("Not connected. "); } if (message == null) { message = new byte[0]; } else if (message.length > 125) { throw new IllegalArgumentException("Ping message to large, may not be greater than 125 bytes. "); } writeFrame((byte) 0x89, message); } /** * Get value for "Origin" header for sending to server when handshake. By default this method returns * <code>null</code>. * * @return value for "Origin" header for sending to server when handshake */ protected String getOrigin() { return null; } /** * Get value for "Sec-WebSocket-Protocol" header for sending to server when handshake. By default this method * returns<code>null</code>. * * @return value for "Sec-WebSocket-Protocol" header for sending to server when handshake */ protected String[] getSubProtocols() { return null; } // private byte[] getHandshake() { ByteArrayOutputStream out = new ByteArrayOutputStream(); PrintWriter handshake = new PrintWriter(out); handshake.format("GET %s HTTP/1.1\r\n", target.getPath()); final int port = target.getPort(); if (port == 80) { handshake.format("Host: %s\r\n", target.getHost()); } else { handshake.format("Host: %s:%d\r\n", target.getHost(), port); } handshake.append("Upgrade: Websocket\r\n"); handshake.append("Connection: Upgrade\r\n"); String[] subProtocol = getSubProtocols(); if (subProtocol != null && subProtocol.length > 0) { handshake.format("Sec-WebSocket-Protocol: %s\r\n", Arrays.toString(subProtocol)); } handshake.format("Sec-WebSocket-Key: %s\r\n", secWebSocketKey); handshake.format("Sec-WebSocket-Version: %d\r\n", 13); handshake.append("Sec-WebSocket-Protocol: chat\r\n"); String origin = getOrigin(); if (origin != null) { handshake.format("Origin: %s\r\n", origin); } handshake.append('\r'); handshake.append('\n'); handshake.flush(); return out.toByteArray(); } private void onOpen() { for (ClientMessageListener listener : listeners) { try { listener.onOpen(this); } catch (Exception e) { LOG.error(e.getMessage(), e); } } } private void onMessage(String message) { for (ClientMessageListener listener : listeners) { try { listener.onMessage(message); } catch (Exception e) { LOG.error(e.getMessage(), e); } } } private void onMessage(byte[] message) { for (ClientMessageListener listener : listeners) { try { listener.onMessage(message); } catch (Exception e) { LOG.error(e.getMessage(), e); } } } private void onPong(byte[] message) { for (ClientMessageListener listener : listeners) { try { listener.onPong(message); } catch (Exception e) { LOG.error(e.getMessage(), e); } } } private void onClose(int status, String message) { try { socket.close(); } catch (IOException e) { LOG.error(e.getMessage(), e); } inputBuffer.clear(); for (ClientMessageListener listener : listeners) { try { listener.onClose(status, message); } catch (Exception e) { LOG.error(e.getMessage(), e); } } listeners.clear(); executor.shutdown(); connected = false; } private String generateSecKey() { int length = RANDOM.nextInt(CHARS.length); byte[] b = new byte[length]; for (int i = 0; i < length; i++) { b[i] = (byte) CHARS[RANDOM.nextInt(CHARS.length)]; } return Base64.encodeBase64String(b); } private byte[] generateMask() { byte[] mask = new byte[MASK_SIZE]; RANDOM.nextBytes(mask); return mask; } private byte[] getLengthAsBytes(long length) { if (length <= 125) { return new byte[] { (byte) length }; } if (length <= 0xFFFF) { byte[] bytes = new byte[3]; bytes[0] = 126; bytes[1] = (byte) (length >> 8); bytes[2] = (byte) (length & 0xFF); return bytes; } byte[] bytes = new byte[9]; // Payload length never greater then max integer: (2^31)-1 bytes[0] = 127; bytes[1] = 0; bytes[2] = 0; bytes[3] = 0; bytes[4] = 0; bytes[5] = (byte) (length >> 24); bytes[6] = (byte) (length >> 16); bytes[7] = (byte) (length >> 8); bytes[8] = (byte) (length & 0xFF); return bytes; } private void validateResponseHeaders() throws IOException { BufferedReader br = new BufferedReader(new InputStreamReader(in)); String line = br.readLine(); if (line != null && !line.startsWith("HTTP/1.1 101")) { throw new IOException("Invalid server response. Expected status is 101 'Switching Protocols'. "); } Map<String, String> headers = new HashMap<>(); while (!((line = br.readLine()) == null || line.isEmpty())) { int colon = line.indexOf(':'); if (colon > 0 && colon < line.length()) { headers.put(line.substring(0, colon).trim().toLowerCase(), line.substring(colon + 1).trim()); } } // 'Upgrade' header String header = headers.get("upgrade"); if (!"websocket".equals(header)) { throw new IOException(String .format("Invalid 'Upgrade' response header. Returned '%s' but 'websocket' expected. ", header)); } // 'Connection' header header = headers.get("connection"); if (!"upgrade".equals(header)) { throw new IOException(String.format( "Invalid 'Connection' response header. Returned '%s' but 'upgrade' expected. ", header)); } // 'Sec-WebSocket-Accept' header MessageDigest md; try { md = MessageDigest.getInstance("SHA-1"); } catch (NoSuchAlgorithmException e) { // should never happen. throw new IllegalStateException(e.getMessage(), e); } md.reset(); byte[] digest = md.digest((secWebSocketKey + GLOBAL_WS_SERVER_UUID).getBytes()); final String expectedWsSecurityAccept = Base64.encodeBase64String(digest); header = headers.get("sec-websocket-accept"); if (!expectedWsSecurityAccept.equals(header)) { throw new IOException("Invalid 'Sec-WebSocket-Accept' response header."); } } private static final int TEXT = 1; private static final int BIN = 1 << 1; private int type; private void read() throws IOException { while (connected) { final int firstByte = in.read(); if (firstByte < 0) { throw new EOFException("Failed read next websocket frame, end of the stream was reached. "); } // Check most significant bit in this byte. It always set in '1' if this fragment is final fragment. // In other word each message may not be sent in more then one fragment. final boolean fin = (firstByte & 0x80) != 0; final byte opCode = (byte) (firstByte & 0x0F); byte[] payload; switch (opCode) { case 0: // continuation frame payload = readFrame(); saveInInputBuffer(payload); // Only data frames might be fragmented. Control frames may not be fragmented. // So we can't get here with any control frames, e.g. with ping/pong messages. if (fin) { if (type == TEXT) { onMessage(getStringFormInputBuffer()); } else if (type == BIN) { onMessage(getBytesFormInputBuffer()); } } break; case 1: // text frame payload = readFrame(); if (fin) { onMessage(new String(payload, UTF8_CS)); } else { saveInInputBuffer(payload); type = TEXT; } break; case 2: // binary frame payload = readFrame(); if (fin) { onMessage(payload); } else { saveInInputBuffer(payload); type = BIN; } break; case 3: case 4: case 5: case 6: case 7: // Do nothing fo this. They are reserved for further non-control frames. break; case 8: // connection close payload = readFrame(); int status; // Read status. if (payload.length > 0) { status = ((payload[0] & 0xFF) << 8); status += (payload[1] & 0xFF); } else { status = 0; // No status. } String message = null; if (!(status == 0 || status == 1000)) { // Two bytes contains status code. The rest of bytes is message. if (payload.length > 2) { message = new String(payload, 2, payload.length - 2, UTF8_CS); } LOG.warn("Close status: {}, message: {} ", status, message); } // Specification says: body is not guaranteed to be human readable. // Send body to the listeners here if server provides it and let listeners decide what to do. onClose(status, message); break; case 9: // ping payload = readFrame(); // 'pong' response for the 'ping' message. writeFrame((byte) 0x8A, payload); LOG.debug("Ping: {} ", new String(payload, UTF8_CS)); break; case 0x0A: // pong payload = readFrame(); onPong(payload); break; case 0x0B: case 0x0C: case 0x0D: case 0x0E: case 0x0F: // Do nothing fo this. break; default: throw new ConnectionException(1003, String.format("Invalid opcode: '%s' ", Integer.toHexString(opCode))); } if (socket.isClosed()) { // May be server going down, we did not receive 'close' op_code but connection is lost. onClose(1006, null); } } } private byte[] readFrame() throws IOException { // This byte contains info about message mask and about length of payload. final int secondByte = in.read(); if (secondByte < 0) { throw new EOFException("Failed read next websocket frame, end of the stream was reached. "); } final boolean masked = (secondByte & 0x80) > 0; long length = (secondByte & 0x7F); if (length == 126) { byte[] block = new byte[2]; readBlock(block); length = getPayloadLength(block); } else if (length == 127) { byte[] block = new byte[8]; readBlock(block); length = getPayloadLength(block); } byte[] mask = null; if (masked) { mask = new byte[MASK_SIZE]; readBlock(mask); } if (length > maxMessagePayloadSize) { throw new IOException(String.format("Message payload is to large, may not be greater than %d", maxMessagePayloadSize)); } // Payload may not greater then max integer: (2^31)-1 final byte[] payload = new byte[(int) length]; readBlock(payload); if (mask != null) { // Unmask payload bytes if they masked. for (int i = 0; i < payload.length; i++) { payload[i] = (byte) (payload[i] ^ mask[i % 4]); } } return payload; } private void saveInInputBuffer(byte[] frame) { final int fSize = frame.length; if (inputBuffer.remaining() < fSize) { LOG.debug("Increase input buffer: {}", fSize); final int capacity = inputBuffer.capacity() + fSize; final ByteBuffer buf = ByteBuffer.allocate(capacity); inputBuffer.flip(); buf.put(inputBuffer); inputBuffer = buf; LOG.debug("New input buffer size {}", inputBuffer.capacity()); } inputBuffer.put(frame); } private String getStringFormInputBuffer() { inputBuffer.flip(); final String str = UTF8_CS.decode(inputBuffer).toString(); inputBuffer.clear(); return str; } private byte[] getBytesFormInputBuffer() { inputBuffer.flip(); final byte[] bytes = new byte[inputBuffer.remaining()]; inputBuffer.get(bytes); inputBuffer.clear(); return bytes; } private void writeFrame(byte opCode, byte[] payload) throws IOException { // Represent length of payload data as described in section 5.2. Base Framing Protocol of RFC-6455 // See for details: http://tools.ietf.org/html/rfc6455#section-5.2 final byte[] lengthBytes = getLengthAsBytes(payload.length); // Turn on 'mask' bit. lengthBytes[0] |= 0x80; // Generate mask bytes. final byte[] mask = generateMask(); out.write(opCode); // Payload length bytes. out.write(lengthBytes); // Mask bytes. out.write(mask); for (int i = 0, length = payload.length; i < length; i++) { // Mask each byte of payload. out.write((payload[i] ^ mask[i % 4])); } out.flush(); } private long getPayloadLength(byte[] bytes) throws IOException { if (!(bytes.length == 2 || bytes.length == 8)) { // Should never happen. Caller of this method must check to full reading of byte range. throw new IOException(String.format( "Unable get payload length. Invalid length bytes. Length must be represented by 2 or 8 bytes but %d reached. ", bytes.length)); } return getLongFromBytes(bytes); } private long getLongFromBytes(byte[] bytes) throws IOException { long length = 0; for (int i = bytes.length - 1, shift = 0; i >= 0; i--, shift += 8) { length += ((bytes[i] & 0xFF) << shift); } return length; } private void readBlock(byte[] buff) throws IOException { int offset = 0; int length = buff.length; int r; while (offset < buff.length) { r = in.read(buff, offset, length - offset); if (r < 0) { throw new EOFException("Failed read next websocket frame, end of the stream was reached. "); } offset += r; } } @SuppressWarnings("serial") private static class ConnectionException extends IOException { private final int status; private ConnectionException(int status, String message) { super(message); this.status = status; } } }