org.springframework.integration.x.ip.websocket.WebSocketSerializer.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.integration.x.ip.websocket.WebSocketSerializer.java

Source

/*
 * Copyright 2002-2013 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.integration.x.ip.websocket;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.codec.binary.Base64;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.core.serializer.Serializer;
import org.springframework.integration.MessagingException;
import org.springframework.integration.ip.tcp.serializer.SoftEndOfStreamException;
import org.springframework.integration.x.ip.serializer.AbstractHttpSwitchingDeserializer;
import org.springframework.integration.x.ip.serializer.DataFrame;
import org.springframework.util.Assert;

/**
 * @author Gary Russell
 * @since 3.0
 *
 */
public class WebSocketSerializer extends AbstractHttpSwitchingDeserializer implements Serializer<Object> {

    private static final String HTTP_1_1_101_WEB_SOCKET_PROTOCOL_HANDSHAKE_SPRING_INTEGRATION = "HTTP/1.1 101 Web Socket Protocol Handshake - Spring Integration\r\n";

    /** GUID appended to the client key when computing Sec-WebSocket-Accept (RFC 6455, section 1.3). */
    private static final String WEBSOCKET_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /**
     * Close status codes treated as invalid when received; such a close frame causes the
     * connection to be closed without sending a close reply.
     */
    private static final Set<Short> INVALID_STATUS = new HashSet<Short>(Arrays.asList((short) 1004, (short) 1005,
            (short) 1006, (short) 1012, (short) 1013, (short) 1014, (short) 1015));

    /** Source of client masking keys; RFC 6455 section 5.3 requires them to be unpredictable. */
    private static final SecureRandom MASKING_KEY_SOURCE = new SecureRandom();

    private volatile boolean server;

    private boolean validateUtf8;

    private volatile Boolean streamChecked;

    private volatile boolean nio;

    /**
     * Set whether this serializer is used on the server side of the connection. A server
     * requires masked inbound frames and sends unmasked frames; a client does the opposite.
     * @param server true when used by a server.
     */
    public void setServer(boolean server) {
        this.server = server;
    }

    /**
     * Validate UTF-8 (required for Autobahn tests).
     * @param validateUtf8 true to reject text/close payloads that are not valid UTF-8.
     */
    public void setValidateUtf8(boolean validateUtf8) {
        this.validateUtf8 = validateUtf8;
    }

    @Override
    protected DataFrame createDataFrame(int type, String frameData) {
        return new WebSocketFrame(type, frameData);
    }

    @Override
    protected BasicState createState() {
        return new WebSocketState();
    }

    /**
     * Write a single WebSocket frame (or a raw HTTP response) to the stream.
     * <p>
     * A {@code String} is sent as a final text frame. A {@link WebSocketFrame} is sent using
     * its type, and its binary content takes precedence over its text payload. Any payload
     * starting with {@code HTTP/1.1} is written verbatim without framing, for handshake
     * responses. Client frames are masked with a random key; server frames are not.
     * @param frame a {@code String} or a {@link WebSocketFrame}.
     * @param outputStream the stream to write to.
     * @throws IOException if writing fails.
     * @throws IllegalArgumentException if {@code frame} is of any other type.
     */
    @Override
    public void serialize(final Object frame, OutputStream outputStream) throws IOException {
        String data = "";
        WebSocketFrame theFrame;
        if (frame instanceof String) {
            data = (String) frame;
            theFrame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
        } else if (frame instanceof WebSocketFrame) {
            theFrame = (WebSocketFrame) frame;
            data = theFrame.getPayload();
        } else {
            throw new IllegalArgumentException("Expected a String or WebSocketFrame, got: "
                    + (frame == null ? "null" : frame.getClass().getName()));
        }
        if (data != null && data.startsWith("HTTP/1.1")) {
            // Handshake/upgrade responses go out as plain HTTP, not as a WebSocket frame
            outputStream.write(data.getBytes("UTF-8"));
            return;
        }
        byte[] payload = theFrame.getBinary() != null ? theFrame.getBinary() : data.getBytes("UTF-8");
        if (theFrame.getType() == WebSocketFrame.TYPE_CLOSE) {
            // The status code is part of the payload, so it must be masked along with the reason
            payload = prependStatus(theFrame.getStatus(), payload);
        }
        int length = payload.length;
        int lenBytes;
        int secondByte = this.server ? 0 : 0x80; // MASK bit: only clients mask
        if (length > 0xffff) {
            lenBytes = 8;
            secondByte |= 127;
        } else if (length > 125) {
            lenBytes = 2;
            secondByte |= 126;
        } else {
            lenBytes = 0;
            secondByte |= length;
        }
        int maskLength = this.server ? 0 : 4;
        ByteBuffer buffer = ByteBuffer.allocate(2 + lenBytes + maskLength + length);
        buffer.put(opcodeByte(theFrame));
        buffer.put((byte) secondByte);
        if (lenBytes == 2) {
            // The cast keeps the bit pattern of the unsigned 16-bit length
            buffer.putShort((short) length);
        } else if (lenBytes == 8) {
            buffer.putLong(length);
        }
        if (this.server) {
            buffer.put(payload);
        } else {
            byte[] mask = new byte[4];
            MASKING_KEY_SOURCE.nextBytes(mask);
            buffer.put(mask);
            for (int i = 0; i < length; i++) {
                buffer.put((byte) (payload[i] ^ mask[i % 4]));
            }
        }
        outputStream.write(buffer.array(), 0, buffer.position());
    }

    /**
     * Return the first header byte (FIN bit plus opcode) for the frame's type; unknown types
     * are sent as final text frames.
     */
    private static byte opcodeByte(WebSocketFrame frame) {
        if (frame.getType() == WebSocketFrame.TYPE_PING) {
            return (byte) 0x89;
        }
        if (frame.getType() == WebSocketFrame.TYPE_PONG) {
            return (byte) 0x8a;
        }
        if (frame.getType() == WebSocketFrame.TYPE_CLOSE) {
            return (byte) 0x88;
        }
        if (frame.getType() == WebSocketFrame.TYPE_DATA_BINARY) {
            return (byte) 0x82;
        }
        // Final fragment; text
        return (byte) 0x81;
    }

    /** Return a new array holding the big-endian 16-bit status followed by {@code body}. */
    private static byte[] prependStatus(int status, byte[] body) {
        byte[] result = new byte[body.length + 2];
        result[0] = (byte) (status >> 8);
        result[1] = (byte) status;
        System.arraycopy(body, 0, result, 2, body.length);
        return result;
    }

    /**
     * Read the next complete message from the stream.
     * <p>
     * Non-final fragments are accumulated, and reading continues until a complete frame is
     * available. A partially assembled message is stored in the connection state (when
     * state is available), so a control frame interleaved between fragments can be returned
     * without losing the message in progress.
     * @param inputStream the connection's input stream.
     * @return the next complete frame (data, control, headers, or an invalid-frame marker).
     * @throws IOException on protocol violations that are not reported as frames, or I/O errors.
     */
    @Override
    public DataFrame deserialize(InputStream inputStream) throws IOException {
        if (this.streamChecked == null) {
            // The stream class is fixed per connection factory, so detect NIO once
            this.nio = inputStream.getClass().getName().endsWith("TcpNioConnection$ChannelInputStream");
            this.streamChecked = Boolean.TRUE;
        }
        DataFrame frame = null;
        BasicState state = this.getState(inputStream);
        if (state != null) {
            frame = state.getPendingFrame();
        }
        while (frame == null || isIncomplete(frame)) {
            frame = doDeserialize(inputStream, frame);
            if (state != null && isIncomplete(frame)) {
                state.setPendingFrame(frame);
            }
        }
        return frame;
    }

    /** Return true if the frame is a placeholder for a message whose fragments are still arriving. */
    private static boolean isIncomplete(DataFrame frame) {
        return frame.getPayload() == null && frame.getBinary() == null;
    }

    /**
     * Read exactly one frame from the wire, or return the upgrade headers if the connection
     * is still in its HTTP phase.
     * @param inputStream the connection's input stream.
     * @param protoFrame the pending (incomplete) message, or null if none.
     * @return the decoded frame; a frame with neither payload nor binary content signals a
     * non-final fragment.
     */
    private DataFrame doDeserialize(InputStream inputStream, DataFrame protoFrame) throws IOException {
        List<DataFrame> headers = checkStreaming(inputStream);
        if (headers != null) {
            return headers.get(0);
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Available to read:" + inputStream.available());
        }

        // Byte 0: FIN, RSV1-3, opcode
        int bite = readFrameByte(inputStream, true);
        boolean fin = (bite & 0x80) != 0;
        int rsv = (bite & 0x70) >> 4;
        int opcode = bite & 0x0f;
        boolean ping = false;
        boolean pong = false;
        boolean close = false;
        boolean binary = false;
        boolean invalid = false;
        String invalidText = null;
        boolean fragmentedControl = false;
        switch (opcode) {
        case 0x00:
            logger.debug("Continuation, fin=" + fin);
            if (protoFrame == null) {
                invalid = true;
                invalidText = "Unexpected continuation frame";
            } else {
                binary = protoFrame.getType() == WebSocketFrame.TYPE_DATA_BINARY;
            }
            this.getState(inputStream).setPendingFrame(null);
            break;
        case 0x01:
            logger.debug("Text, fin=" + fin);
            if (protoFrame != null) {
                invalid = true;
                invalidText = "Expected continuation frame";
            }
            break;
        case 0x02:
            logger.debug("Binary, fin=" + fin);
            if (protoFrame != null) {
                invalid = true;
                invalidText = "Expected continuation frame";
            }
            binary = true;
            break;
        case 0x08:
            logger.debug("Close, fin=" + fin);
            fragmentedControl = !fin;
            close = true;
            break;
        case 0x09:
            ping = true;
            binary = true;
            fragmentedControl = !fin;
            logger.debug("Ping, fin=" + fin);
            break;
        case 0x0a:
            pong = true;
            fragmentedControl = !fin;
            logger.debug("Pong, fin=" + fin);
            break;
        case 0x03:
        case 0x04:
        case 0x05:
        case 0x06:
        case 0x07:
        case 0x0b:
        case 0x0c:
        case 0x0d:
        case 0x0e:
        case 0x0f:
            invalid = true;
            invalidText = "Reserved opcode " + Integer.toHexString(opcode);
            break;
        default:
            throw new IOException("Unexpected opcode " + Integer.toHexString(opcode));
        }

        // Byte 1: MASK bit and 7-bit payload length indicator
        bite = readFrameByte(inputStream, false);
        if (this.server) {
            if ((bite & 0x80) == 0) {
                throw new IOException("Illegal: Expected masked data from client");
            }
            bite &= 0x7f;
        }
        if ((bite & 0x80) != 0) {
            throw new IOException("Illegal: Received masked data from server");
        }
        int len = readPayloadLength(inputStream, bite);

        // Masking key: present only on frames sent by a client, i.e. read by a server
        byte[] mask = new byte[4];
        if (this.server) {
            for (int i = 0; i < mask.length; i++) {
                mask[i] = (byte) readFrameByte(inputStream, false);
            }
        }

        // Payload: read exactly len bytes, so a zero-length frame does not consume the next frame
        byte[] buffer = new byte[len];
        for (int i = 0; i < len; i++) {
            int b = readFrameByte(inputStream, false);
            if (this.server) {
                b ^= mask[i % 4];
            }
            buffer[i] = (byte) b;
        }

        WebSocketFrame frame;
        if (fragmentedControl) {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_FRAGMENTED_CONTROL, "Fragmented control frame", buffer);
        } else if (invalid) {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID, invalidText, buffer);
        } else if (!fin) {
            List<byte[]> fragments = this.getState(inputStream).getFragments();
            fragments.add(buffer);
            logger.debug("Fragment");
            return new WebSocketFrame(binary ? WebSocketFrame.TYPE_DATA_BINARY : WebSocketFrame.TYPE_DATA,
                    (String) null);
        } else if (ping) {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_PING, buffer);
        } else if (pong) {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_PONG, new String(buffer, "UTF-8"));
        } else if (close) {
            frame = createCloseFrame(inputStream, buffer);
        } else {
            frame = completeDataFrame(inputStream, buffer, binary);
        }
        if (rsv > 0) {
            frame.setRsv(rsv);
        }
        return frame;
    }

    /**
     * Read one byte of a frame, applying the NIO workaround and closure checks.
     * @param firstByte true for the first byte of a frame; end-of-stream there is a clean
     * close between messages rather than a truncated frame.
     */
    private int readFrameByte(InputStream inputStream, boolean firstByte) throws IOException {
        int bite = inputStream.read();
        if (this.nio) {
            bite = checkclosed(bite, inputStream);
        }
        if (bite < 0 && firstByte) {
            throw new SoftEndOfStreamException("Stream closed between payloads");
        }
        checkClosure(bite);
        return bite;
    }

    /**
     * Decode the payload length from the 7-bit indicator and any extended length bytes.
     * Only lengths that fit in a non-negative {@code int} are supported.
     */
    private int readPayloadLength(InputStream inputStream, int indicator) throws IOException {
        if (indicator < 126) {
            return indicator;
        }
        if (indicator == 126) {
            int high = readFrameByte(inputStream, false) & 0xff;
            int low = readFrameByte(inputStream, false) & 0xff;
            return high << 8 | low;
        }
        int len = 0;
        for (int i = 0; i < 8; i++) {
            int b = readFrameByte(inputStream, false);
            if (i < 4 && b != 0) {
                throw new IOException("Max supported length exceeded");
            }
            if (i == 4 && (b & 0x80) != 0) {
                throw new IOException("Max supported length exceeded");
            }
            len = len << 8 | (b & 0xff);
        }
        return len;
    }

    /**
     * Build a close frame from its payload: an optional 2-byte status code followed by a
     * UTF-8 reason. An invalid close (bad length, reason, or status) marks the connection
     * state as close-initiated, so no close reply is sent.
     */
    private WebSocketFrame createCloseFrame(InputStream inputStream, byte[] buffer) throws IOException {
        // Decode only the reason bytes; decoding the status bytes as UTF-8 can swallow reason bytes
        String data = buffer.length >= 2
                ? new String(buffer, 2, buffer.length - 2, "UTF-8")
                : new String(buffer, "UTF-8");
        WebSocketFrame closeFrame = new WebSocketFrame(WebSocketFrame.TYPE_CLOSE, data);
        short status = 1000;
        if (buffer.length >= 2) {
            status = (short) ((buffer[0] << 8) | (buffer[1] & 0xff));
            closeFrame.setStatus(status);
        }
        if (buffer.length == 1 || buffer.length > 125
                || (buffer.length > 2 && !validateUtf8IfNecessary(buffer, 2, data)) || status < 1000
                || INVALID_STATUS.contains(status) || (status >= 1016 && status < 3000) || status >= 5000) {
            // Simply close in this case; no close reply
            ((WebSocketState) this.getState(inputStream)).setCloseInitiated(true);
        }
        return closeFrame;
    }

    /**
     * Complete a final text/binary frame, joining it with any previously received fragments
     * (which are then cleared from the connection state).
     */
    private WebSocketFrame completeDataFrame(InputStream inputStream, byte[] buffer, boolean binary)
            throws IOException {
        List<byte[]> fragments = this.getState(inputStream).getFragments();
        byte[] payload;
        if (fragments.isEmpty()) {
            payload = buffer;
        } else {
            fragments.add(buffer);
            payload = concatenate(fragments);
            fragments.clear();
        }
        if (binary) {
            return new WebSocketFrame(WebSocketFrame.TYPE_DATA_BINARY, payload);
        }
        String data = new String(payload, "UTF-8");
        if (!validateUtf8IfNecessary(payload, 0, data)) {
            return new WebSocketFrame(WebSocketFrame.TYPE_INVALID_UTF8, "Invalid UTF-8", payload);
        }
        return new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
    }

    /** Join the fragments, in order, into a single array. */
    private static byte[] concatenate(List<byte[]> fragments) {
        int total = 0;
        for (byte[] fragment : fragments) {
            total += fragment.length;
        }
        byte[] result = new byte[total];
        int position = 0;
        for (byte[] fragment : fragments) {
            System.arraycopy(fragment, 0, result, position, fragment.length);
            position += fragment.length;
        }
        return result;
    }

    /**
     * TODO: workaround for INT-2936
     * <p>
     * On the NIO input stream, a negative read result is treated as end-of-stream only when
     * the stream's {@code isClosed} field is true and no data is available. Otherwise it is
     * returned as {@code 0xff} (it appears to be a data byte mis-reported as -1). The field is
     * read from the stream passed in, because one serializer serves many connections.
     */
    private int checkclosed(int bite, InputStream inputStream) {
        if (bite >= 0) {
            return bite;
        }
        try {
            DirectFieldAccessor accessor = new DirectFieldAccessor(inputStream);
            if ((Boolean) accessor.getPropertyValue("isClosed") && inputStream.available() == 0) {
                return -1;
            }
            return bite & 0xff;
        } catch (Exception e) {
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to check closed", e);
            }
            return bite;
        }
    }

    /**
     * When UTF-8 validation is enabled, check that {@code buffer[offset..]} is exactly the
     * UTF-8 encoding of {@code data}. A malformed input decodes to replacement characters,
     * so the re-encoded form then differs.
     * @return true if validation is disabled or the bytes are valid UTF-8.
     */
    private boolean validateUtf8IfNecessary(byte[] buffer, int offset, String data) {
        if (!this.validateUtf8) {
            return true;
        }
        byte[] reEncoded;
        try {
            reEncoded = data.getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new MessagingException("UTF-8 Conversion error", e);
        }
        if (reEncoded.length != buffer.length - offset) {
            return false;
        }
        for (int i = 0; i < reEncoded.length; i++) {
            if (buffer[i + offset] != reEncoded[i]) {
                return false;
            }
        }
        return true;
    }

    @Override
    protected void checkClosure(int bite) throws IOException {
        if (bite < 0) {
            logger.debug("Socket closed during message assembly");
            throw new IOException("Socket closed during message assembly");
        }
    }

    /**
     * Delegates to the superclass (presumably overridden to make it publicly callable; confirm).
     */
    @Override
    public void removeState(Object inputStream) {
        super.removeState(inputStream);
    }

    /**
     * Build the 101 handshake response for a client upgrade request.
     * @param frame a headers frame containing the raw request headers, separated by CRLF.
     * @return a data frame whose payload is the complete HTTP response.
     * @throws WebSocketUpgradeException with 400 if no key is present, or 426 if the
     * version is not 13.
     * @throws Exception if the accept key cannot be computed.
     */
    public WebSocketFrame generateHandshake(WebSocketFrame frame) throws Exception {
        Assert.isTrue(frame.getType() == WebSocketFrame.TYPE_HEADERS, "Expected headers:" + frame);
        String[] headers = frame.getPayload().split("\\r\\n");
        String key = null;
        String version = null;
        for (String header : headers) {
            // Header names are case-insensitive; regionMatches avoids locale-dependent lower-casing
            if (startsWithIgnoreCase(header, "sec-websocket-key")) {
                key = headerValue(header);
            } else if (startsWithIgnoreCase(header, "sec-websocket-version")) {
                version = headerValue(header);
            }
        }
        if (key == null) {
            throw new WebSocketUpgradeException("400 Bad Request: No sec-websocket-key header detected");
        } else if (!"13".equals(version)) {
            throw new WebSocketUpgradeException("426 Upgrade Required", "sec-websocket-version: 13\r\n");
        }
        String handshake = HTTP_1_1_101_WEB_SOCKET_PROTOCOL_HANDSHAKE_SPRING_INTEGRATION + "Upgrade: WebSocket\r\n"
                + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + this.generateWebSocketAccept(key)
                + "\r\n\r\n";
        return new WebSocketFrame(WebSocketFrame.TYPE_DATA, handshake);
    }

    /** Locale-independent, case-insensitive prefix test. */
    private static boolean startsWithIgnoreCase(String text, String prefix) {
        return text.regionMatches(true, 0, prefix, 0, prefix.length());
    }

    /**
     * Return the trimmed value of a {@code name: value} header line, or null if the line
     * has no colon or an empty value.
     */
    private static String headerValue(String header) {
        int colon = header.indexOf(':');
        if (colon < 0) {
            return null;
        }
        String value = header.substring(colon + 1).trim();
        return value.length() == 0 ? null : value;
    }

    /** Compute Sec-WebSocket-Accept: base64(SHA-1(key + GUID)). */
    private String generateWebSocketAccept(String key) throws NoSuchAlgorithmException,
            UnsupportedEncodingException {
        MessageDigest md = MessageDigest.getInstance("SHA-1");
        byte[] digest = md.digest((key + WEBSOCKET_ACCEPT_GUID).getBytes("UTF-8"));
        return new String(Base64.encodeBase64(digest), "US-ASCII");
    }

    /**
     * Per-connection state that adds WebSocket-specific flags to the base state.
     */
    public static class WebSocketState extends BasicState {

        /** Set when a close is under way; also set on receipt of an invalid close frame. */
        private volatile boolean closeInitiated;

        /** Whether a pong is awaited; not read or written within this class. */
        private volatile boolean expectingPong;

        public boolean isCloseInitiated() {
            return this.closeInitiated;
        }

        public void setCloseInitiated(boolean closeInitiated) {
            this.closeInitiated = closeInitiated;
        }

        public boolean isExpectingPong() {
            return this.expectingPong;
        }

        public void setExpectingPong(boolean expectingPong) {
            this.expectingPong = expectingPong;
        }

    }
}