Java tutorial
/* * RED5 Open Source Flash Server - https://github.com/red5 * * Copyright 2006-2015 by respective authors (see below). All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.red5.net.websocket.codec; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Set; import org.apache.mina.core.buffer.IoBuffer; import org.apache.mina.core.future.IoFuture; import org.apache.mina.core.future.IoFutureListener; import org.apache.mina.core.future.WriteFuture; import org.apache.mina.core.session.IoSession; import org.apache.mina.filter.codec.CumulativeProtocolDecoder; import org.apache.mina.filter.codec.ProtocolDecoderOutput; import org.bouncycastle.util.encoders.Base64; import org.red5.net.websocket.Constants; import org.red5.net.websocket.WebSocketConnection; import org.red5.net.websocket.WebSocketException; import org.red5.net.websocket.WebSocketPlugin; import org.red5.net.websocket.WebSocketScopeManager; import org.red5.net.websocket.listener.IWebSocketDataListener; import org.red5.net.websocket.model.ConnectionType; import org.red5.net.websocket.model.HandshakeResponse; import org.red5.net.websocket.model.MessageType; import org.red5.net.websocket.model.WSMessage; import org.red5.server.plugin.PluginRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * This class handles the websocket decoding and its handshake process. A warning is loggged if WebSocket version 13 is not detected. <br /> * Decodes incoming buffers in a manner that makes the sender transparent to the decoders further up in the filter chain. If the sender is a native client then the buffer is simply passed through. If the sender is a websocket, it will extract the content out from the dataframe and parse it before passing it along the filter chain. * * @see <a href="https://developer.mozilla.org/en-US/docs/WebSockets/Writing_WebSocket_servers">Mozilla - Writing WebSocket Servers</a> * * @author Dhruv Chopra * @author Paul Gregoire */ public class WebSocketDecoder extends CumulativeProtocolDecoder { private static final Logger log = LoggerFactory.getLogger(WebSocketDecoder.class); private static final String DECODER_STATE_KEY = "decoder-state"; private static final String DECODED_MESSAGE_KEY = "decoded-message"; private static final String DECODED_MESSAGE_TYPE_KEY = "decoded-message-type"; private static final String DECODED_MESSAGE_FRAGMENTS_KEY = "decoded-message-fragments"; /** * Keeps track of the decoding state of a frame. Byte values start at -128 as a flag to indicate they are not set. */ private final class DecoderState { // keep track of fin == 0 to indicate a fragment byte fin = Byte.MIN_VALUE; byte opCode = Byte.MIN_VALUE; byte mask = Byte.MIN_VALUE; int frameLen = 0; // payload byte[] payload; @Override public String toString() { return "DecoderState [fin=" + fin + ", opCode=" + opCode + ", mask=" + mask + ", frameLen=" + frameLen + "]"; } } @Override protected boolean doDecode(IoSession session, IoBuffer in, ProtocolDecoderOutput out) throws Exception { IoBuffer resultBuffer; WebSocketConnection conn = (WebSocketConnection) session.getAttribute(Constants.CONNECTION); if (conn == null) { // first message on a new connection, check if its from a websocket or a native socket if (doHandShake(session, in)) { // websocket handshake was successful. Don't write anything to output as we want to abstract the handshake request message from the handler in.position(in.limit()); return true; } else { // message is from a native socket. Simply wrap and pass through resultBuffer = IoBuffer.wrap(in.array(), 0, in.limit()); in.position(in.limit()); out.write(resultBuffer); } } else if (conn.isWebConnection()) { // grab decoding state DecoderState decoderState = (DecoderState) session.getAttribute(DECODER_STATE_KEY); if (decoderState == null) { decoderState = new DecoderState(); session.setAttribute(DECODER_STATE_KEY, decoderState); } // there is incoming data from the websocket, decode it decodeIncommingData(in, session); // this will be null until all the fragments are collected WSMessage message = (WSMessage) session.getAttribute(DECODED_MESSAGE_KEY); if (log.isTraceEnabled()) { log.trace("State: {} message: {}", decoderState, message); } if (message != null) { // set the originating connection on the message message.setConnection(conn); // write the message out.write(message); // remove decoded message session.removeAttribute(DECODED_MESSAGE_KEY); } else { // there was not enough data in the buffer to parse return false; } } else { // session is known to be from a native socket. So simply wrap and pass through resultBuffer = IoBuffer.wrap(in.array(), 0, in.limit()); in.position(in.limit()); out.write(resultBuffer); } return true; } /** * Try parsing the message as a websocket handshake request. If it is such a request, then send the corresponding handshake response (as in Section 4.2.2 RFC 6455). */ @SuppressWarnings("unchecked") private boolean doHandShake(IoSession session, IoBuffer in) { // create the connection obj WebSocketConnection conn = new WebSocketConnection(session); // mark as secure if using ssl if (session.getFilterChain().contains("sslFilter")) { conn.setSecure(true); } try { Map<String, Object> headers = parseClientRequest(conn, new String(in.array())); if (log.isTraceEnabled()) { log.trace("Header map: {}", headers); } if (!headers.isEmpty() && headers.containsKey(Constants.WS_HEADER_KEY)) { // add the headers to the connection, they may be of use to implementers conn.setHeaders(headers); // add query string parameters if (headers.containsKey(Constants.URI_QS_PARAMETERS)) { conn.setQuerystringParameters( (Map<String, Object>) headers.remove(Constants.URI_QS_PARAMETERS)); } // check the version if (!"13".equals(headers.get(Constants.WS_HEADER_VERSION))) { log.info("Version 13 was not found in the request, communications may fail"); } // get the path String path = conn.getPath(); // get the scope manager WebSocketScopeManager manager = (WebSocketScopeManager) session.getAttribute(Constants.MANAGER); if (manager == null) { WebSocketPlugin plugin = (WebSocketPlugin) PluginRegistry.getPlugin("WebSocketPlugin"); manager = plugin.getManager(path); } // TODO add handling for extensions // TODO expand handling for protocols requested by the client, instead of just echoing back if (headers.containsKey(Constants.WS_HEADER_PROTOCOL)) { boolean protocolSupported = false; String protocol = (String) headers.get(Constants.WS_HEADER_PROTOCOL); log.debug("Protocol '{}' found in the request", protocol); // add protocol to the connection conn.setProtocol(protocol); // TODO check listeners for "protocol" support Set<IWebSocketDataListener> listeners = manager.getScope(path).getListeners(); for (IWebSocketDataListener listener : listeners) { if (listener.getProtocol().equals(protocol)) { //log.debug("Scope has listener support for the {} protocol", protocol); protocolSupported = true; break; } } log.debug("Scope listener does{} support the '{}' protocol", (protocolSupported ? "" : "n't"), protocol); } // store manager in the current session session.setAttribute(Constants.MANAGER, manager); // store connection in the current session session.setAttribute(Constants.CONNECTION, conn); // handshake is finished conn.setConnected(); // add connection to the manager manager.addConnection(conn); // prepare response and write it to the directly to the session HandshakeResponse wsResponse = buildHandshakeResponse(conn, (String) headers.get(Constants.WS_HEADER_KEY)); session.write(wsResponse); log.debug("Handshake complete"); return true; } // set connection as native / direct conn.setType(ConnectionType.DIRECT); } catch (Exception e) { // input is not a websocket handshake request log.warn("Handshake failed", e); } return false; } /** * Parse the client request and return a map containing the header contents. If the requested application is not enabled, return a 400 error. * * @param conn * @param requestData * @return map of headers * @throws WebSocketException */ private Map<String, Object> parseClientRequest(WebSocketConnection conn, String requestData) throws WebSocketException { String[] request = requestData.split("\r\n"); if (log.isTraceEnabled()) { log.trace("Request: {}", Arrays.toString(request)); } Map<String, Object> map = new HashMap<String, Object>(); for (int i = 0; i < request.length; i++) { log.trace("Request {}: {}", i, request[i]); if (request[i].startsWith("GET ") || request[i].startsWith("POST ") || request[i].startsWith("PUT ")) { // "GET /chat/room1?id=publisher1 HTTP/1.1" // split it on space String requestPath = request[i].split("\\s+")[1]; // get the path data for handShake int start = requestPath.indexOf('/'); int end = requestPath.length(); int ques = requestPath.indexOf('?'); if (ques > 0) { end = ques; } log.trace("Request path: {} to {} ques: {}", start, end, ques); String path = requestPath.substring(start, end).trim(); log.trace("Client request path: {}", path); conn.setPath(path); // check for '?' or included query string if (ques > 0) { // parse any included query string String qs = requestPath.substring(ques).trim(); log.trace("Request querystring: {}", qs); map.put(Constants.URI_QS_PARAMETERS, parseQuerystring(qs)); } // get the manager WebSocketPlugin plugin = (WebSocketPlugin) PluginRegistry.getPlugin("WebSocketPlugin"); if (plugin != null) { log.trace("Found plugin"); WebSocketScopeManager manager = plugin.getManager(path); log.trace("Manager was found? : {}", manager); // only check that the application is enabled, not the room or sub levels if (manager != null && manager.isEnabled(path)) { log.trace("Path enabled: {}", path); } else { // invalid scope or its application is not enabled, send disconnect message HandshakeResponse errResponse = build400Response(conn); WriteFuture future = conn.getSession().write(errResponse); future.addListener(new IoFutureListener<IoFuture>() { @Override public void operationComplete(IoFuture future) { // close connection future.getSession().closeOnFlush(); } }); throw new WebSocketException("Handshake failed, path not enabled"); } } else { log.warn("Plugin lookup failed"); HandshakeResponse errResponse = build400Response(conn); WriteFuture future = conn.getSession().write(errResponse); future.addListener(new IoFutureListener<IoFuture>() { @Override public void operationComplete(IoFuture future) { // close connection future.getSession().closeOnFlush(); } }); throw new WebSocketException("Handshake failed, missing plugin"); } } else if (request[i].contains(Constants.WS_HEADER_KEY)) { map.put(Constants.WS_HEADER_KEY, extractHeaderValue(request[i])); } else if (request[i].contains(Constants.WS_HEADER_VERSION)) { map.put(Constants.WS_HEADER_VERSION, extractHeaderValue(request[i])); } else if (request[i].contains(Constants.WS_HEADER_EXTENSIONS)) { map.put(Constants.WS_HEADER_EXTENSIONS, extractHeaderValue(request[i])); } else if (request[i].contains(Constants.WS_HEADER_PROTOCOL)) { map.put(Constants.WS_HEADER_PROTOCOL, extractHeaderValue(request[i])); } else if (request[i].contains(Constants.HTTP_HEADER_HOST)) { // get the host data conn.setHost(extractHeaderValue(request[i])); } else if (request[i].contains(Constants.HTTP_HEADER_ORIGIN)) { // get the origin data conn.setOrigin(extractHeaderValue(request[i])); } else if (request[i].contains(Constants.HTTP_HEADER_USERAGENT)) { map.put(Constants.HTTP_HEADER_USERAGENT, extractHeaderValue(request[i])); } } return map; } /** * Returns the trimmed header value. * * @param requestHeader * @return value */ private String extractHeaderValue(String requestHeader) { return requestHeader.substring(requestHeader.indexOf(':') + 1).trim(); } /** * Build a handshake response based on the given client key. * * @param clientKey * @return response * @throws WebSocketException */ private HandshakeResponse buildHandshakeResponse(WebSocketConnection conn, String clientKey) throws WebSocketException { byte[] accept; try { // performs the accept creation routine from RFC6455 @see <a href="http://tools.ietf.org/html/rfc6455">RFC6455</a> // concatenate the key and magic string, then SHA1 hash and base64 encode MessageDigest md = MessageDigest.getInstance("SHA1"); accept = Base64.encode(md.digest((clientKey + Constants.WEBSOCKET_MAGIC_STRING).getBytes())); } catch (NoSuchAlgorithmException e) { throw new WebSocketException("Algorithm is missing"); } // make up reply data... IoBuffer buf = IoBuffer.allocate(308); buf.setAutoExpand(true); buf.put("HTTP/1.1 101 Switching Protocols".getBytes()); buf.put(Constants.CRLF); buf.put("Upgrade: websocket".getBytes()); buf.put(Constants.CRLF); buf.put("Connection: Upgrade".getBytes()); buf.put(Constants.CRLF); buf.put("Server: Red5".getBytes()); buf.put(Constants.CRLF); buf.put("Sec-WebSocket-Version-Server: 13".getBytes()); buf.put(Constants.CRLF); buf.put(String.format("Sec-WebSocket-Origin: %s", conn.getOrigin()).getBytes()); buf.put(Constants.CRLF); buf.put(String.format("Sec-WebSocket-Location: %s", conn.getHost()).getBytes()); buf.put(Constants.CRLF); // send back extensions if enabled if (conn.hasExtensions()) { buf.put(String.format("Sec-WebSocket-Extensions: %s", conn.getExtensionsAsString()).getBytes()); buf.put(Constants.CRLF); } // send back protocol if enabled if (conn.hasProtocol()) { buf.put(String.format("Sec-WebSocket-Protocol: %s", conn.getProtocol()).getBytes()); buf.put(Constants.CRLF); } buf.put(String.format("Sec-WebSocket-Accept: %s", new String(accept)).getBytes()); buf.put(Constants.CRLF); buf.put(Constants.CRLF); // if any bytes follow this crlf, the follow-up data will be corrupted if (log.isTraceEnabled()) { log.trace("Handshake response size: {}", buf.limit()); } return new HandshakeResponse(buf); } /** * Build an HTTP 400 "Bad Request" response. * * @return response * @throws WebSocketException */ private HandshakeResponse build400Response(WebSocketConnection conn) throws WebSocketException { // make up reply data... IoBuffer buf = IoBuffer.allocate(32); buf.setAutoExpand(true); buf.put("HTTP/1.1 400 Bad Request".getBytes()); buf.put(Constants.CRLF); buf.put("Sec-WebSocket-Version-Server: 13".getBytes()); buf.put(Constants.CRLF); buf.put(Constants.CRLF); if (log.isTraceEnabled()) { log.trace("Handshake error response size: {}", buf.limit()); } return new HandshakeResponse(buf); } /** * Decode the in buffer according to the Section 5.2. RFC 6455. If there are multiple websocket dataframes in the buffer, this will parse all and return one complete decoded buffer. * * <pre> * 0 1 2 3 * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 * +-+-+-+-+-------+-+-------------+-------------------------------+ * |F|R|R|R| opcode|M| Payload len | Extended payload length | * |I|S|S|S| (4) |A| (7) | (16/64) | * |N|V|V|V| |S| | (if payload len==126/127) | * | |1|2|3| |K| | | * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + * | Extended payload length continued, if payload len == 127 | * + - - - - - - - - - - - - - - - +-------------------------------+ * | |Masking-key, if MASK set to 1 | * +-------------------------------+-------------------------------+ * | Masking-key (continued) | Payload Data | * +-------------------------------- - - - - - - - - - - - - - - - + * : Payload Data continued ... : * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + * | Payload Data continued ... | * +---------------------------------------------------------------+ * </pre> * * @param in * @param session */ public static void decodeIncommingData(IoBuffer in, IoSession session) { log.trace("Decoding: {}", in); // get decoder state DecoderState decoderState = (DecoderState) session.getAttribute(DECODER_STATE_KEY); if (decoderState.fin == Byte.MIN_VALUE) { byte frameInfo = in.get(); // get FIN (1 bit) //log.debug("frameInfo: {}", Integer.toBinaryString((frameInfo & 0xFF) + 256)); decoderState.fin = (byte) ((frameInfo >>> 7) & 1); log.trace("FIN: {}", decoderState.fin); // the next 3 bits are for RSV1-3 (not used here at the moment) // get the opcode (4 bits) decoderState.opCode = (byte) (frameInfo & 0x0f); log.trace("Opcode: {}", decoderState.opCode); // opcodes 3-7 and b-f are reserved for non-control frames } if (decoderState.mask == Byte.MIN_VALUE) { byte frameInfo2 = in.get(); // get mask bit (1 bit) decoderState.mask = (byte) ((frameInfo2 >>> 7) & 1); log.trace("Mask: {}", decoderState.mask); // get payload length (7, 7+16, 7+64 bits) decoderState.frameLen = (frameInfo2 & (byte) 0x7F); log.trace("Payload length: {}", decoderState.frameLen); if (decoderState.frameLen == 126) { decoderState.frameLen = in.getUnsignedShort(); log.trace("Payload length updated: {}", decoderState.frameLen); } else if (decoderState.frameLen == 127) { long extendedLen = in.getLong(); if (extendedLen >= Integer.MAX_VALUE) { log.error("Data frame is too large for this implementation. Length: {}", extendedLen); } else { decoderState.frameLen = (int) extendedLen; } log.trace("Payload length updated: {}", decoderState.frameLen); } } // ensure enough bytes left to fill payload, if masked add 4 additional bytes if (decoderState.frameLen + (decoderState.mask == 1 ? 4 : 0) > in.remaining()) { log.info("Not enough data available to decode, socket may be closed/closing"); } else { // if the data is masked (xor'd) if (decoderState.mask == 1) { // get the mask key byte maskKey[] = new byte[4]; for (int i = 0; i < 4; i++) { maskKey[i] = in.get(); } /* now un-mask frameLen bytes as per Section 5.3 RFC 6455 Octet i of the transformed data ("transformed-octet-i") is the XOR of octet i of the original data ("original-octet-i") with octet at index i modulo 4 of the masking key ("masking-key-octet-j"): j = i MOD 4 transformed-octet-i = original-octet-i XOR masking-key-octet-j */ decoderState.payload = new byte[decoderState.frameLen]; for (int i = 0; i < decoderState.frameLen; i++) { byte maskedByte = in.get(); decoderState.payload[i] = (byte) (maskedByte ^ maskKey[i % 4]); } } else { decoderState.payload = new byte[decoderState.frameLen]; in.get(decoderState.payload); } // if FIN == 0 we have fragments if (decoderState.fin == 0) { // store the fragment and continue IoBuffer fragments = (IoBuffer) session.getAttribute(DECODED_MESSAGE_FRAGMENTS_KEY); if (fragments == null) { fragments = IoBuffer.allocate(decoderState.frameLen); fragments.setAutoExpand(true); session.setAttribute(DECODED_MESSAGE_FRAGMENTS_KEY, fragments); // store message type since following type may be a continuation MessageType messageType = MessageType.CLOSE; switch (decoderState.opCode) { case 0: // continuation messageType = MessageType.CONTINUATION; break; case 1: // text messageType = MessageType.TEXT; break; case 2: // binary messageType = MessageType.BINARY; break; case 9: // ping messageType = MessageType.PING; break; case 0xa: // pong messageType = MessageType.PONG; break; } session.setAttribute(DECODED_MESSAGE_TYPE_KEY, messageType); } fragments.put(decoderState.payload); // remove decoder state session.removeAttribute(DECODER_STATE_KEY); } else { // create a message WSMessage message = new WSMessage(); // check for previously set type from the first fragment (if we have fragments) MessageType messageType = (MessageType) session.getAttribute(DECODED_MESSAGE_TYPE_KEY); if (messageType == null) { switch (decoderState.opCode) { case 0: // continuation messageType = MessageType.CONTINUATION; break; case 1: // text messageType = MessageType.TEXT; break; case 2: // binary messageType = MessageType.BINARY; break; case 9: // ping messageType = MessageType.PING; break; case 0xa: // pong messageType = MessageType.PONG; break; case 8: // close messageType = MessageType.CLOSE; // handler or listener should close upon receipt break; default: // TODO throw ex? log.info("Unhandled opcode: {}", decoderState.opCode); } } // set message type message.setMessageType(messageType); // check for fragments and piece them together, otherwise just send the single completed frame IoBuffer fragments = (IoBuffer) session.removeAttribute(DECODED_MESSAGE_FRAGMENTS_KEY); if (fragments != null) { fragments.put(decoderState.payload); fragments.flip(); message.setPayload(fragments); } else { // add the payload message.addPayload(decoderState.payload); } // set the message on the session session.setAttribute(DECODED_MESSAGE_KEY, message); // remove decoder state session.removeAttribute(DECODER_STATE_KEY); // remove type session.removeAttribute(DECODED_MESSAGE_TYPE_KEY); } } } /** * Returns a map of key / value pairs from a given querystring. * * @param query * @return k/v map */ public static Map<String, Object> parseQuerystring(String query) { String[] params = query.split("&"); Map<String, Object> map = new HashMap<String, Object>(); for (String param : params) { String name = param.split("=")[0]; String value = param.split("=")[1]; map.put(name, value); } return map; } }