Java tutorial
/* * Copyright 2015 Twitter, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.twitter.http2; import java.nio.charset.StandardCharsets; import io.netty.buffer.ByteBuf; import static com.twitter.http2.HttpCodecUtil.HTTP_CONTINUATION_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_DATA_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_FLAG_ACK; import static com.twitter.http2.HttpCodecUtil.HTTP_FLAG_END_HEADERS; import static com.twitter.http2.HttpCodecUtil.HTTP_FLAG_END_SEGMENT; import static com.twitter.http2.HttpCodecUtil.HTTP_FLAG_END_STREAM; import static com.twitter.http2.HttpCodecUtil.HTTP_FLAG_PADDED; import static com.twitter.http2.HttpCodecUtil.HTTP_FLAG_PRIORITY; import static com.twitter.http2.HttpCodecUtil.HTTP_FRAME_HEADER_SIZE; import static com.twitter.http2.HttpCodecUtil.HTTP_GOAWAY_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_HEADERS_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_MAX_LENGTH; import static com.twitter.http2.HttpCodecUtil.HTTP_PING_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_PRIORITY_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_PUSH_PROMISE_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_RST_STREAM_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_SETTINGS_FRAME; import static com.twitter.http2.HttpCodecUtil.HTTP_WINDOW_UPDATE_FRAME; import static com.twitter.http2.HttpCodecUtil.getSignedInt; import static com.twitter.http2.HttpCodecUtil.getSignedLong; import static com.twitter.http2.HttpCodecUtil.getUnsignedInt; import static com.twitter.http2.HttpCodecUtil.getUnsignedMedium; import static com.twitter.http2.HttpCodecUtil.getUnsignedShort; /** * Decodes {@link ByteBuf}s into HTTP/2 Frames. */ public class HttpFrameDecoder { private static final byte[] CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" .getBytes(StandardCharsets.US_ASCII); private final int maxChunkSize; private final HttpFrameDecoderDelegate delegate; private State state; // HTTP/2 frame header fields private int length; private short type; private byte flags; private int streamId; // HTTP/2 frame padding length private int paddingLength; private enum State { READ_CONNECTION_HEADER, READ_FRAME_HEADER, READ_PADDING_LENGTH, READ_DATA_FRAME, READ_DATA, READ_HEADERS_FRAME, READ_PRIORITY_FRAME, READ_RST_STREAM_FRAME, READ_SETTINGS_FRAME, READ_SETTING, READ_PUSH_PROMISE_FRAME, READ_PING_FRAME, READ_GOAWAY_FRAME, READ_WINDOW_UPDATE_FRAME, READ_CONTINUATION_FRAME_HEADER, READ_HEADER_BLOCK, SKIP_FRAME_PADDING, SKIP_FRAME_PADDING_CONTINUATION, FRAME_ERROR } /** * Creates a new instance with the specified @{code HttpFrameDecoderDelegate} * and the default {@code maxChunkSize (8192)}. */ public HttpFrameDecoder(boolean server, HttpFrameDecoderDelegate delegate) { this(server, delegate, 8192); } /** * Creates a new instance with the specified parameters. */ public HttpFrameDecoder(boolean server, HttpFrameDecoderDelegate delegate, int maxChunkSize) { if (delegate == null) { throw new NullPointerException("delegate"); } if (maxChunkSize <= 0) { throw new IllegalArgumentException("maxChunkSize must be a positive integer: " + maxChunkSize); } this.delegate = delegate; this.maxChunkSize = maxChunkSize; if (server) { state = State.READ_CONNECTION_HEADER; } else { state = State.READ_FRAME_HEADER; } } /** * Decode the byte buffer. */ public void decode(ByteBuf buffer) { boolean endStream; boolean endSegment; int minLength; int dependency; int weight; boolean exclusive; int errorCode; while (true) { switch (state) { case READ_CONNECTION_HEADER: while (buffer.isReadable()) { byte b = buffer.readByte(); if (b != CLIENT_CONNECTION_PREFACE[length++]) { state = State.FRAME_ERROR; delegate.readFrameError("Invalid Connection Header"); return; } if (length == CLIENT_CONNECTION_PREFACE.length) { state = State.READ_FRAME_HEADER; break; } } if (buffer.isReadable()) { break; } else { return; } case READ_FRAME_HEADER: // Wait until entire header is readable if (buffer.readableBytes() < HTTP_FRAME_HEADER_SIZE) { return; } // Read frame header fields readFrameHeader(buffer); // TODO(jpinner) Sec 4.2 FRAME_SIZE_ERROR if (!isValidFrameHeader(length, type, flags, streamId)) { state = State.FRAME_ERROR; delegate.readFrameError("Invalid Frame Header"); } else if (frameHasPadding(type, flags)) { state = State.READ_PADDING_LENGTH; } else { paddingLength = 0; state = getNextState(length, type); } break; case READ_PADDING_LENGTH: if (buffer.readableBytes() < 1) { return; } paddingLength = buffer.readUnsignedByte(); --length; if (!isValidPaddingLength(length, type, flags, paddingLength)) { state = State.FRAME_ERROR; delegate.readFrameError("Invalid Frame Padding Length"); } else { state = getNextState(length, type); } break; case READ_DATA_FRAME: endStream = hasFlag(flags, HTTP_FLAG_END_STREAM); state = State.READ_DATA; if (hasFlag(flags, HTTP_FLAG_PADDED)) { delegate.readDataFramePadding(streamId, endStream, paddingLength + 1); } break; case READ_DATA: // Generate data frames that do not exceed maxChunkSize // maxChunkSize must be > 0 so we cannot infinitely loop int dataLength = Math.min(maxChunkSize, length - paddingLength); // Wait until entire frame is readable if (buffer.readableBytes() < dataLength) { return; } ByteBuf data = buffer.readBytes(dataLength); length -= dataLength; if (length == paddingLength) { if (paddingLength == 0) { state = State.READ_FRAME_HEADER; } else { state = State.SKIP_FRAME_PADDING; } } endStream = length == paddingLength && hasFlag(flags, HTTP_FLAG_END_STREAM); endSegment = length == paddingLength && hasFlag(flags, HTTP_FLAG_END_SEGMENT); delegate.readDataFrame(streamId, endStream, endSegment, data); break; case READ_HEADERS_FRAME: minLength = 0; if (hasFlag(flags, HTTP_FLAG_PRIORITY)) { minLength = 5; } if (buffer.readableBytes() < minLength) { return; } endStream = hasFlag(flags, HTTP_FLAG_END_STREAM); endSegment = hasFlag(flags, HTTP_FLAG_END_SEGMENT); exclusive = false; dependency = 0; weight = 16; if (hasFlag(flags, HTTP_FLAG_PRIORITY)) { dependency = getSignedInt(buffer, buffer.readerIndex()); buffer.skipBytes(4); weight = buffer.readUnsignedByte() + 1; if (dependency < 0) { dependency = dependency & 0x7FFFFFFF; exclusive = true; } length -= 5; } state = State.READ_HEADER_BLOCK; delegate.readHeadersFrame(streamId, endStream, endSegment, exclusive, dependency, weight); break; case READ_PRIORITY_FRAME: if (buffer.readableBytes() < length) { return; } exclusive = false; dependency = getSignedInt(buffer, buffer.readerIndex()); buffer.skipBytes(4); weight = buffer.readUnsignedByte() + 1; if (dependency < 0) { dependency = dependency & 0x7FFFFFFF; exclusive = true; } state = State.READ_FRAME_HEADER; delegate.readPriorityFrame(streamId, exclusive, dependency, weight); break; case READ_RST_STREAM_FRAME: if (buffer.readableBytes() < length) { return; } errorCode = getSignedInt(buffer, buffer.readerIndex()); buffer.skipBytes(length); state = State.READ_FRAME_HEADER; delegate.readRstStreamFrame(streamId, errorCode); break; case READ_SETTINGS_FRAME: boolean ack = hasFlag(flags, HTTP_FLAG_ACK); state = State.READ_SETTING; delegate.readSettingsFrame(ack); break; case READ_SETTING: if (length == 0) { state = State.READ_FRAME_HEADER; delegate.readSettingsEnd(); break; } if (buffer.readableBytes() < 6) { return; } int id = getUnsignedShort(buffer, buffer.readerIndex()); int value = getSignedInt(buffer, buffer.readerIndex() + 2); buffer.skipBytes(6); length -= 6; delegate.readSetting(id, value); break; case READ_PUSH_PROMISE_FRAME: if (buffer.readableBytes() < 4) { return; } int promisedStreamId = getUnsignedInt(buffer, buffer.readerIndex()); buffer.skipBytes(4); length -= 4; if (promisedStreamId == 0) { state = State.FRAME_ERROR; delegate.readFrameError("Invalid Promised-Stream-ID"); } else { state = State.READ_HEADER_BLOCK; delegate.readPushPromiseFrame(streamId, promisedStreamId); } break; case READ_PING_FRAME: if (buffer.readableBytes() < length) { return; } long ping = getSignedLong(buffer, buffer.readerIndex()); buffer.skipBytes(length); boolean pong = hasFlag(flags, HTTP_FLAG_ACK); state = State.READ_FRAME_HEADER; delegate.readPingFrame(ping, pong); break; case READ_GOAWAY_FRAME: if (buffer.readableBytes() < 8) { return; } int lastStreamId = getUnsignedInt(buffer, buffer.readerIndex()); errorCode = getSignedInt(buffer, buffer.readerIndex() + 4); buffer.skipBytes(8); length -= 8; if (length == 0) { state = State.READ_FRAME_HEADER; } else { paddingLength = length; state = State.SKIP_FRAME_PADDING; } delegate.readGoAwayFrame(lastStreamId, errorCode); break; case READ_WINDOW_UPDATE_FRAME: // Wait until entire frame is readable if (buffer.readableBytes() < length) { return; } int windowSizeIncrement = getUnsignedInt(buffer, buffer.readerIndex()); buffer.skipBytes(length); if (windowSizeIncrement == 0) { state = State.FRAME_ERROR; delegate.readFrameError("Invalid Window Size Increment"); } else { state = State.READ_FRAME_HEADER; delegate.readWindowUpdateFrame(streamId, windowSizeIncrement); } break; case READ_CONTINUATION_FRAME_HEADER: // Wait until entire frame header is readable if (buffer.readableBytes() < HTTP_FRAME_HEADER_SIZE) { return; } // Read and validate continuation frame header fields int prevStreamId = streamId; readFrameHeader(buffer); // TODO(jpinner) Sec 4.2 FRAME_SIZE_ERROR // TODO(jpinner) invalid flags if (type != HTTP_CONTINUATION_FRAME || streamId != prevStreamId) { state = State.FRAME_ERROR; delegate.readFrameError("Invalid Continuation Frame"); } else { paddingLength = 0; state = State.READ_HEADER_BLOCK; } break; case READ_HEADER_BLOCK: if (length == paddingLength) { boolean endHeaders = hasFlag(flags, HTTP_FLAG_END_HEADERS); if (endHeaders) { state = State.SKIP_FRAME_PADDING; delegate.readHeaderBlockEnd(); } else { state = State.SKIP_FRAME_PADDING_CONTINUATION; } break; } if (!buffer.isReadable()) { return; } int readableBytes = Math.min(buffer.readableBytes(), length - paddingLength); ByteBuf headerBlockFragment = buffer.readBytes(readableBytes); length -= readableBytes; delegate.readHeaderBlock(headerBlockFragment); break; case SKIP_FRAME_PADDING: int numBytes = Math.min(buffer.readableBytes(), length); buffer.skipBytes(numBytes); length -= numBytes; if (length == 0) { state = State.READ_FRAME_HEADER; break; } return; case SKIP_FRAME_PADDING_CONTINUATION: int numPaddingBytes = Math.min(buffer.readableBytes(), length); buffer.skipBytes(numPaddingBytes); length -= numPaddingBytes; if (length == 0) { state = State.READ_CONTINUATION_FRAME_HEADER; break; } return; case FRAME_ERROR: buffer.skipBytes(buffer.readableBytes()); return; default: throw new Error("Shouldn't reach here."); } } } /** * Reads the HTTP/2 Frame Header and sets the length, type, flags, and streamId member variables. * * @param buffer input buffer containing the entire 9-octet header */ private void readFrameHeader(ByteBuf buffer) { int frameOffset = buffer.readerIndex(); length = getUnsignedMedium(buffer, frameOffset); type = buffer.getUnsignedByte(frameOffset + 3); flags = buffer.getByte(frameOffset + 4); streamId = getUnsignedInt(buffer, frameOffset + 5); buffer.skipBytes(HTTP_FRAME_HEADER_SIZE); } private static boolean hasFlag(byte flags, byte flag) { return (flags & flag) != 0; } private static boolean frameHasPadding(int type, byte flags) { switch (type) { case HTTP_DATA_FRAME: case HTTP_HEADERS_FRAME: case HTTP_PUSH_PROMISE_FRAME: return hasFlag(flags, HTTP_FLAG_PADDED); default: return false; } } private static State getNextState(int length, int type) { switch (type) { case HTTP_DATA_FRAME: return State.READ_DATA_FRAME; case HTTP_HEADERS_FRAME: return State.READ_HEADERS_FRAME; case HTTP_PRIORITY_FRAME: return State.READ_PRIORITY_FRAME; case HTTP_RST_STREAM_FRAME: return State.READ_RST_STREAM_FRAME; case HTTP_SETTINGS_FRAME: return State.READ_SETTINGS_FRAME; case HTTP_PUSH_PROMISE_FRAME: return State.READ_PUSH_PROMISE_FRAME; case HTTP_PING_FRAME: return State.READ_PING_FRAME; case HTTP_GOAWAY_FRAME: return State.READ_GOAWAY_FRAME; case HTTP_WINDOW_UPDATE_FRAME: return State.READ_WINDOW_UPDATE_FRAME; case HTTP_CONTINUATION_FRAME: throw new Error("Shouldn't reach here."); default: if (length != 0) { return State.SKIP_FRAME_PADDING; } else { return State.READ_FRAME_HEADER; } } } private static boolean isValidFrameHeader(int length, short type, byte flags, int streamId) { if (length > HTTP_MAX_LENGTH) { return false; } int minLength; switch (type) { case HTTP_DATA_FRAME: if (hasFlag(flags, HTTP_FLAG_PADDED)) { minLength = 1; } else { minLength = 0; } return length >= minLength && streamId != 0; case HTTP_HEADERS_FRAME: if (hasFlag(flags, HTTP_FLAG_PADDED)) { minLength = 1; } else { minLength = 0; } if (hasFlag(flags, HTTP_FLAG_PRIORITY)) { minLength += 5; } return length >= minLength && streamId != 0; case HTTP_PRIORITY_FRAME: return length == 5 && streamId != 0; case HTTP_RST_STREAM_FRAME: return length == 4 && streamId != 0; case HTTP_SETTINGS_FRAME: boolean lengthValid = hasFlag(flags, HTTP_FLAG_ACK) ? length == 0 : (length % 6) == 0; return lengthValid && streamId == 0; case HTTP_PUSH_PROMISE_FRAME: if (hasFlag(flags, HTTP_FLAG_PADDED)) { minLength = 5; } else { minLength = 4; } return length >= minLength && streamId != 0; case HTTP_PING_FRAME: return length == 8 && streamId == 0; case HTTP_GOAWAY_FRAME: return length >= 8 && streamId == 0; case HTTP_WINDOW_UPDATE_FRAME: return length == 4; case HTTP_CONTINUATION_FRAME: return false; default: return true; } } private static boolean isValidPaddingLength(int length, short type, byte flags, int paddingLength) { switch (type) { case HTTP_DATA_FRAME: return length >= paddingLength; case HTTP_HEADERS_FRAME: if (hasFlag(flags, HTTP_FLAG_PRIORITY)) { return length >= paddingLength + 5; } else { return length >= paddingLength; } case HTTP_PUSH_PROMISE_FRAME: return length >= paddingLength + 4; default: throw new Error("Shouldn't reach here."); } } }