org.apache.weasel.V06Handshake.java Source code

Java tutorial

Introduction

Here is the source code for the class org.apache.weasel.V06Handshake (file V06Handshake.java).

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.    
 */
package org.apache.weasel;

import java.io.BufferedReader;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ScatteringByteChannel;
import java.nio.channels.SelectableChannel;
import java.nio.charset.Charset;
import java.security.SecureRandom;
import java.util.Map;
import java.util.Random;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.digest.DigestUtils;

/**
 * Opening handshake for the version-6 WebSocket protocol draft ({@code Sec-WebSocket-Version: 6}), covering both
 * the client and the server side.
 *
 * <p>Client side: writes an HTTP/1.1 upgrade request carrying a random {@code Sec-WebSocket-Key}. It then validates
 * the server's {@code 101} response, including the {@code Sec-WebSocket-Accept} value, which is
 * base64(SHA-1(key + GUID)).
 *
 * <p>Server side: checks an incoming request line and header map ({@link #matches}) and writes the
 * {@code 101 Switching Protocols} response ({@link #serverHandshake}).
 *
 * <p>Header maps are always looked up with lower-cased header names. This assumes the producer of the map
 * (presumably {@code HttpUtils.readHttpHeaders}) stores its keys in lower case. TODO confirm.
 *
 * @param <T> the channel type: selectable, with gathering writes and scattering reads
 */
public class V06Handshake<T extends SelectableChannel & GatheringByteChannel & ScatteringByteChannel>
        implements WebSocketHandshake<T> {

    private static final Charset UTF8 = Charset.forName("UTF-8");
    private static final String HOST_HEADER = "Host";
    private static final String UPGRADE_HEADER = "Upgrade";
    private static final String CONNECTION_HEADER = "Connection";
    private static final String KEY_HEADER = "Sec-WebSocket-Key";
    private static final String ORIGIN_HEADER = "Sec-WebSocket-Origin";
    private static final String PROTOCOL_HEADER = "Sec-WebSocket-Protocol";
    private static final String VERSION_HEADER = "Sec-WebSocket-Version";
    private static final String VERSION = "6";
    private static final String ACCEPT_HEADER = "Sec-WebSocket-Accept";
    private static final String COOKIE_HEADER = "Cookie";
    private static final String[] MANDATORY_HEADERS = { HOST_HEADER, KEY_HEADER, VERSION_HEADER };
    private static final String SERVER_KEY_ADDON = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    private static final String CRLF = "\r\n";

    /**
     * Source of client nonces. The nonce must be unpredictable. SecureRandom is used instead of a
     * time-seeded Random: with a time seed, two handshakes started in the same millisecond would send
     * identical keys. SecureRandom instances are safe to share between threads.
     */
    private static final SecureRandom KEY_RANDOM = new SecureRandom();

    /**
     * Performs a blocking client-side handshake on {@code channel}.
     *
     * @param channel connected channel. It must be in blocking mode, because the response is read through
     *        {@link Channels#newReader}.
     * @param uri target URI. Its host, port, raw path and raw query form the request line and {@code Host}
     *        header.
     * @param subprotocol value for {@code Sec-WebSocket-Protocol}; omitted when null or empty
     * @param origin value for {@code Sec-WebSocket-Origin}; omitted when null or empty
     * @param cookies value for the {@code Cookie} header; omitted when null or empty
     * @param otherHeaders extra header lines, each written verbatim followed by CRLF. Presumably each is a
     *        complete {@code "Name: value"} line. May be null.
     * @return a client-side WebSocket bound to {@code channel}
     * @throws IOException if writing the request or reading the response fails. The channel is closed.
     * @throws WebSocketException if the response is not an acceptable handshake. The channel is closed.
     */
    public WebSocket<T> clientHandshake(T channel, URI uri, String subprotocol, String origin, String cookies,
            String... otherHeaders) throws IOException {
        // TODO: support non-blocking
        String key = generateKey();
        ByteBuffer request = buildRequest(uri, key, origin, subprotocol, cookies, otherHeaders);
        boolean accepted = false;
        try {
            writeFully(channel, request);
            // NOTE(review): the decoder behind Channels.newReader and the BufferedReader both read ahead.
            // Bytes the server sends right after the handshake response (e.g. a first frame) may be
            // consumed here and never reach the WebSocket. Reading the headers byte by byte would avoid this.
            BufferedReader reader = new BufferedReader(Channels.newReader(channel, "UTF-8"));
            accepted = processResponse(reader, key);
        } finally {
            // Close on every failure path, including I/O errors, so the connection is not leaked.
            if (!accepted) {
                channel.close();
            }
        }
        if (!accepted) {
            throw new WebSocketException("Handshake failed");
        }
        return new V06WebSocket<T>(channel, true);
    }

    /** Serializes the client upgrade request, header by header, into a UTF-8 buffer ready to write. */
    private ByteBuffer buildRequest(URI uri, String key, String origin, String subprotocol, String cookies,
            String... otherHeaders) {
        StringBuilder request = new StringBuilder();
        request.append("GET ").append(resourceName(uri)).append(" HTTP/1.1").append(CRLF);
        appendHeader(request, HOST_HEADER, hostHeaderValue(uri));
        appendHeader(request, UPGRADE_HEADER, "websocket");
        appendHeader(request, CONNECTION_HEADER, "Upgrade");
        appendHeader(request, KEY_HEADER, key);
        appendOptionalHeader(request, ORIGIN_HEADER, origin);
        appendOptionalHeader(request, PROTOCOL_HEADER, subprotocol);
        appendHeader(request, VERSION_HEADER, VERSION);
        appendOptionalHeader(request, COOKIE_HEADER, cookies);
        // TODO is this the expected format of otherHeaders?
        if (otherHeaders != null) {
            for (String header : otherHeaders) {
                request.append(header).append(CRLF);
            }
        }
        request.append(CRLF);
        return ByteBuffer.wrap(request.toString().getBytes(UTF8));
    }

    /**
     * Builds the request-target: the raw (still percent-encoded) path, or "/" when empty, plus "?query" when
     * the URI has one. The raw form keeps escaped characters such as spaces out of the request line.
     */
    private static String resourceName(URI uri) {
        String path = uri.getRawPath();
        StringBuilder resource = new StringBuilder(path == null || path.isEmpty() ? "/" : path);
        String query = uri.getRawQuery();
        if (query != null) {
            resource.append('?').append(query);
        }
        return resource.toString();
    }

    /** Returns the {@code Host} header value: the host, plus ":port" when the URI names a port explicitly. */
    private static String hostHeaderValue(URI uri) {
        int port = uri.getPort();
        return port == -1 ? uri.getHost() : uri.getHost() + ":" + port;
    }

    /**
     * Reads the status line and headers from the server and validates them.
     *
     * @return false if the peer closed before sending a status line, or if the response is not an acceptable
     *         handshake
     */
    private boolean processResponse(BufferedReader reader, String key) throws IOException {
        String responseLine = reader.readLine();
        if (responseLine == null) {
            // End of stream: the server closed the connection without answering.
            return false;
        }
        Map<String, String> headers = HttpUtils.readHttpHeaders(reader);
        DebugUtils.printHttpMessage(responseLine, headers);
        return checkResponseLine(responseLine) && checkResponseHeaders(headers, key);
    }

    /** Accepts only an HTTP/1.1 status line with status code 101. */
    private boolean checkResponseLine(String responseLine) {
        return responseLine != null && responseLine.startsWith("HTTP/1.1 101");
    }

    /**
     * Checks the three headers the client relies on. {@code Upgrade} must be "websocket". {@code Connection}
     * must contain the "Upgrade" token; it may be a comma-separated list. {@code Sec-WebSocket-Accept} must
     * match the value derived from {@code key}.
     */
    private boolean checkResponseHeaders(Map<String, String> headers, String key) {
        String accept = headers.get(ACCEPT_HEADER.toLowerCase());
        return "websocket".equalsIgnoreCase(headers.get(UPGRADE_HEADER.toLowerCase()))
                && containsToken(headers.get(CONNECTION_HEADER.toLowerCase()), "Upgrade")
                && accept != null
                && acceptKey(key).equals(accept.trim());
    }

    /** Returns true if the comma-separated {@code headerValue} contains {@code token}, ignoring case. */
    private static boolean containsToken(String headerValue, String token) {
        if (headerValue == null) {
            return false;
        }
        for (String candidate : headerValue.split(",")) {
            if (candidate.trim().equalsIgnoreCase(token)) {
                return true;
            }
        }
        return false;
    }

    /** Computes base64(SHA-1(key + GUID)), the value expected in {@code Sec-WebSocket-Accept}. */
    private static String acceptKey(String key) {
        return new String(Base64.encodeBase64(DigestUtils.sha(key + SERVER_KEY_ADDON)), UTF8);
    }

    /** Generates a base64-encoded 16-byte nonce for {@code Sec-WebSocket-Key}. */
    private String generateKey() {
        byte[] nonce = new byte[16];
        KEY_RANDOM.nextBytes(nonce);
        return new String(Base64.encodeBase64(nonce), UTF8);
    }

    /**
     * Returns the second space-separated token of the request line, i.e. the request-target as sent, including
     * any query string. Expects a request line that {@link #matches} has already accepted.
     */
    public String getRealPath(String requestLine, Map<String, String> headers) {
        return requestLine.split(" ")[1];
    }

    /**
     * Returns true if the request is a version-6 WebSocket upgrade. That requires a
     * {@code GET <absolute-path> HTTP/1.1} request line and non-empty {@code Host}, {@code Sec-WebSocket-Key}
     * and {@code Sec-WebSocket-Version} headers. The key must decode to 16 bytes, and the version must equal
     * "6".
     */
    public boolean matches(String requestLine, Map<String, String> headers) {
        return matchesRequestLine(requestLine) && matchesHeaders(headers);
    }

    private boolean matchesRequestLine(String requestLine) {
        if (requestLine == null) {
            return false;
        }
        String[] tokens = requestLine.split(" ");
        return tokens.length == 3 && "GET".equals(tokens[0]) && tokens[1].startsWith("/")
                && "HTTP/1.1".equals(tokens[2]);
    }

    private boolean matchesHeaders(Map<String, String> headers) {
        for (String headerName : MANDATORY_HEADERS) {
            String value = headers.get(headerName.toLowerCase());
            if (value == null || value.isEmpty()) {
                return false;
            }
        }
        // The client nonce must be a base64 encoding of exactly 16 bytes.
        byte[] key = headers.get(KEY_HEADER.toLowerCase()).getBytes(UTF8);
        return Base64.decodeBase64(key).length == 16
                && headers.get(VERSION_HEADER.toLowerCase()).equals(VERSION);
    }

    /**
     * Writes the {@code 101 Switching Protocols} response for an already-matched request.
     *
     * <p>The application is asked, in order, to pick a subprotocol, to accept the origin, and to supply
     * extension response headers. A {@code Sec-WebSocket-Protocol} header is sent only when the application
     * returns a non-empty protocol.
     *
     * @throws WebSocketException if the application rejects the origin. Nothing has been written at that point.
     * @throws IOException if writing the response fails
     */
    public WebSocket<T> serverHandshake(T channel, String requestLine, Map<String, String> headers,
            WebSocketApplication<T> app) throws IOException {
        StringBuilder response = new StringBuilder();
        response.append("HTTP/1.1 101 Switching Protocols").append(CRLF);
        appendHeader(response, ACCEPT_HEADER, acceptKey(headers.get(KEY_HEADER.toLowerCase())));
        appendHeader(response, UPGRADE_HEADER, "websocket");
        appendHeader(response, CONNECTION_HEADER, "Upgrade");
        String acceptedProtocol = app.acceptProtocol(headers.get(PROTOCOL_HEADER.toLowerCase()));
        // Test the protocol the application chose, not the one the client offered. The offered value may be
        // null while the application still returns a protocol.
        appendOptionalHeader(response, PROTOCOL_HEADER, acceptedProtocol);
        if (!app.acceptOrigin(headers.get(ORIGIN_HEADER.toLowerCase()))) {
            throw new WebSocketException("Origin not accepted");
        }
        // TODO: add extension abstraction
        Map<String, String> extensionHeaders = app.acceptExtensions(headers);
        if (extensionHeaders != null) {
            for (Map.Entry<String, String> header : extensionHeaders.entrySet()) {
                appendHeader(response, header.getKey(), header.getValue());
            }
        }
        response.append(CRLF);
        writeFully(channel, ByteBuffer.wrap(response.toString().getBytes(UTF8)));
        // TODO determine that client received the response and didn't fail the
        // connection
        return new V06WebSocket<T>(channel, false);
    }

    /** Appends {@code name: value} followed by CRLF. */
    private static void appendHeader(StringBuilder message, String name, String value) {
        message.append(name).append(": ").append(value).append(CRLF);
    }

    /** Appends the header only when {@code value} is non-null and non-empty. */
    private static void appendOptionalHeader(StringBuilder message, String name, String value) {
        if (value != null && !value.isEmpty()) {
            appendHeader(message, name, value);
        }
    }

    /**
     * Writes every remaining byte of {@code buffer}. A single {@code write} call may write only part of the
     * buffer. On a non-blocking channel this loop spins until the socket accepts the data; handshake messages
     * are small, so this is expected to finish quickly.
     */
    private static void writeFully(GatheringByteChannel channel, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
            channel.write(buffer);
        }
    }
}