io.crate.protocols.postgres.Messages.java Source code

Java tutorial

Introduction

Here is the source code for io.crate.protocols.postgres.Messages.java

Source

/*
 * Licensed to Crate under one or more contributor license agreements.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.  Crate licenses this file
 * to you under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.  You may
 * obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.  See the License for the specific language governing
 * permissions and limitations under the License.
 *
 * However, if you have executed another commercial license agreement
 * with Crate these terms will supersede the license and you may use the
 * software solely pursuant to the terms of the relevant commercial
 * agreement.
 */

package io.crate.protocols.postgres;

import io.crate.expression.symbol.Field;
import io.crate.data.Row;
import io.crate.exceptions.SQLExceptions;
import io.crate.protocols.postgres.types.PGType;
import io.crate.protocols.postgres.types.PGTypes;
import io.crate.types.DataType;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.logging.Loggers;

import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.SortedSet;

/**
 * Regular data packet is in the following format:
 * <p>
 * +----------+-----------+----------+
 * | char tag | int32 len | payload  |
 * +----------+-----------+----------+
 * <p>
 * The tag indicates the message type, the second field is the length of the packet
 * (excluding the tag, but including the length itself)
 * <p>
 * <p>
 * See https://www.postgresql.org/docs/9.2/static/protocol-message-formats.html
 */
public class Messages {

    private static final Logger LOGGER = Loggers.getLogger(Messages.class);

    private static final byte[] SEVERITY_FATAL = "FATAL".getBytes(StandardCharsets.UTF_8);
    private static final byte[] SEVERITY_ERROR = "ERROR".getBytes(StandardCharsets.UTF_8);
    private static final byte[] ERROR_CODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000"
            .getBytes(StandardCharsets.UTF_8);
    private static final byte[] ERROR_CODE_FEATURE_NOT_SUPPORTED = "0A000".getBytes(StandardCharsets.UTF_8);
    private static final byte[] ERROR_CODE_INTERNAL_ERROR = "XX000".getBytes(StandardCharsets.UTF_8);
    private static final byte[] METHOD_NAME_CLIENT_AUTH = "ClientAuthentication".getBytes(StandardCharsets.UTF_8);

    public static ChannelFuture sendAuthenticationOK(Channel channel) {
        ByteBuf buffer = channel.alloc().buffer(9);
        buffer.writeByte('R');
        buffer.writeInt(8); // size excluding char
        buffer.writeInt(0);
        ChannelFuture channelFuture = channel.writeAndFlush(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener((ChannelFutureListener) future -> LOGGER.trace("sentAuthenticationOK"));
        }
        return channelFuture;
    }

    /**
     * | 'C' | int32 len | str commandTag
     * @param query    :the query
     * @param rowCount : number of rows in the result set or number of rows affected by the DML statement
     */
    static ChannelFuture sendCommandComplete(Channel channel, String query, long rowCount) {
        query = query.trim().split(" ", 2)[0].toUpperCase(Locale.ENGLISH);
        String commandTag;
        /*
         * from https://www.postgresql.org/docs/current/static/protocol-message-formats.html:
         *
         * For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows inserted.
         * oid is the object ID of the inserted row if rows is 1 and the target table has OIDs; otherwise oid is 0.
         */
        if ("BEGIN".equals(query)) {
            commandTag = "BEGIN";
        } else if ("INSERT".equals(query)) {
            commandTag = "INSERT 0 " + rowCount;
        } else {
            commandTag = query + " " + rowCount;
        }

        byte[] commandTagBytes = commandTag.getBytes(StandardCharsets.UTF_8);
        int length = 4 + commandTagBytes.length + 1;
        ByteBuf buffer = channel.alloc().buffer(length + 1);
        buffer.writeByte('C');
        buffer.writeInt(length);
        writeCString(buffer, commandTagBytes);
        ChannelFuture channelFuture = channel.writeAndFlush(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener((ChannelFutureListener) future -> LOGGER.trace("sentCommandComplete"));
        }
        return channelFuture;
    }

    /**
     * ReadyForQuery (B)
     * <p>
     * Byte1('Z')
     * Identifies the message type. ReadyForQuery is sent whenever the
     * backend is ready for a new query cycle.
     * <p>
     * Int32(5)
     * Length of message contents in bytes, including self.
     * <p>
     * Byte1
     * Current backend transaction status indicator. Possible values are
     * 'I' if idle (not in a transaction block); 'T' if in a transaction
     * block; or 'E' if in a failed transaction block (queries will be
     * rejected until block is ended).
     */
    static void sendReadyForQuery(Channel channel) {
        ByteBuf buffer = channel.alloc().buffer(6);
        buffer.writeByte('Z');
        buffer.writeInt(5);
        buffer.writeByte('I');
        ChannelFuture channelFuture = channel.writeAndFlush(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener((ChannelFutureListener) future -> LOGGER.trace("sentReadyForQuery"));
        }
    }

    /**
     * | 'S' | int32 len | str name | str value
     * <p>
     * See https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC
     * <p>
     * > At present there is a hard-wired set of parameters for which ParameterStatus will be generated: they are
     * <p>
     * - server_version,
     * - server_encoding,
     * - client_encoding,
     * - application_name,
     * - is_superuser,
     * - session_authorization,
     * - DateStyle,
     * - IntervalStyle,
     * - TimeZone,
     * - integer_datetimes,
     * - standard_conforming_string
     */
    static void sendParameterStatus(Channel channel, final String name, final String value) {
        byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8);
        byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8);

        int length = 4 + nameBytes.length + 1 + valueBytes.length + 1;
        ByteBuf buffer = channel.alloc().buffer(length + 1);
        buffer.writeByte('S');
        buffer.writeInt(length);
        writeCString(buffer, nameBytes);
        writeCString(buffer, valueBytes);
        ChannelFuture channelFuture = channel.write(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener(
                    (ChannelFutureListener) future -> LOGGER.trace("sentParameterStatus {}={}", name, value));
        }
    }

    static void sendAuthenticationError(Channel channel, String message) {
        LOGGER.warn(message);
        byte[] msg = message.getBytes(StandardCharsets.UTF_8);
        sendErrorResponse(channel, message, msg, SEVERITY_FATAL, null, null, METHOD_NAME_CLIENT_AUTH,
                ERROR_CODE_INVALID_AUTHORIZATION_SPECIFICATION);
    }

    static ChannelFuture sendErrorResponse(Channel channel, Throwable throwable) {
        final String message = SQLExceptions.messageOf(throwable);
        byte[] msg = message.getBytes(StandardCharsets.UTF_8);
        byte[] lineNumber = null;
        byte[] fileName = null;
        byte[] methodName = null;

        StackTraceElement[] stackTrace = throwable.getStackTrace();
        if (stackTrace != null && stackTrace.length > 0) {
            StackTraceElement stackTraceElement = stackTrace[0];
            lineNumber = String.valueOf(stackTraceElement.getLineNumber()).getBytes(StandardCharsets.UTF_8);
            if (stackTraceElement.getFileName() != null) {
                fileName = stackTraceElement.getFileName().getBytes(StandardCharsets.UTF_8);
            }
            if (stackTraceElement.getMethodName() != null) {
                methodName = stackTraceElement.getMethodName().getBytes(StandardCharsets.UTF_8);
            }
        }

        // See https://www.postgresql.org/docs/9.2/static/errcodes-appendix.html
        // need to add a throwable -> error code mapping later on
        byte[] errorCode;
        if (throwable instanceof IllegalArgumentException || throwable instanceof UnsupportedOperationException) {
            // feature_not_supported
            errorCode = ERROR_CODE_FEATURE_NOT_SUPPORTED;
        } else {
            // internal_error
            errorCode = ERROR_CODE_INTERNAL_ERROR;
        }
        return sendErrorResponse(channel, message, msg, SEVERITY_ERROR, lineNumber, fileName, methodName,
                errorCode);
    }

    /**
     * 'E' | int32 len | char code | str value | \0 | char code | str value | \0 | ... | \0
     * <p>
     * char code / str value -> key-value fields
     * example error fields are: message, detail, hint, error position
     * <p>
     * See https://www.postgresql.org/docs/9.2/static/protocol-error-fields.html for a list of error codes
     */
    private static ChannelFuture sendErrorResponse(Channel channel, String message, byte[] msg, byte[] severity,
            byte[] lineNumber, byte[] fileName, byte[] methodName, byte[] errorCode) {
        int length = 4 + 1 + (severity.length + 1) + 1 + (msg.length + 1) + 1 + (errorCode.length + 1)
                + (fileName != null ? 1 + (fileName.length + 1) : 0)
                + (lineNumber != null ? 1 + (lineNumber.length + 1) : 0)
                + (methodName != null ? 1 + (methodName.length + 1) : 0) + 1;
        ByteBuf buffer = channel.alloc().buffer(length + 1);
        buffer.writeByte('E');
        buffer.writeInt(length);
        buffer.writeByte('S');
        writeCString(buffer, severity);
        buffer.writeByte('M');
        writeCString(buffer, msg);
        buffer.writeByte(('C'));
        writeCString(buffer, errorCode);
        if (fileName != null) {
            buffer.writeByte('F');
            writeCString(buffer, fileName);
        }
        if (lineNumber != null) {
            buffer.writeByte('L');
            writeCString(buffer, lineNumber);
        }
        if (methodName != null) {
            buffer.writeByte('R');
            writeCString(buffer, methodName);
        }
        buffer.writeByte(0);
        ChannelFuture channelFuture = channel.writeAndFlush(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener(
                    (ChannelFutureListener) future -> LOGGER.trace("sentErrorResponse msg={}", message));
        }
        return channelFuture;
    }

    /**
     * Byte1('D')
     * Identifies the message as a data row.
     * <p>
     * Int32
     * Length of message contents in bytes, including self.
     * <p>
     * Int16
     * The number of column values that follow (possibly zero).
     * <p>
     * Next, the following pair of fields appear for each column:
     * <p>
     * Int32
     * The length of the column value, in bytes (this count does not include itself).
     * Can be zero. As a special case, -1 indicates a NULL column value. No value bytes follow in the NULL case.
     * <p>
     * ByteN
     * The value of the column, in the format indicated by the associated format code. n is the above length.
     */
    static void sendDataRow(Channel channel, Row row, List<? extends DataType> columnTypes,
            @Nullable FormatCodes.FormatCode[] formatCodes) {
        int length = 4 + 2;
        assert columnTypes.size() == row
                .numColumns() : "Number of columns in the row must match number of columnTypes. Row: " + row
                        + " types: " + columnTypes;

        ByteBuf buffer = channel.alloc().buffer();
        buffer.writeByte('D');
        buffer.writeInt(0); // will be set at the end
        buffer.writeShort(row.numColumns());

        for (int i = 0; i < row.numColumns(); i++) {
            DataType dataType = columnTypes.get(i);
            PGType pgType = PGTypes.get(dataType);
            Object value = row.get(i);
            if (value == null) {
                buffer.writeInt(-1);
                length += 4;
            } else {
                FormatCodes.FormatCode formatCode = FormatCodes.getFormatCode(formatCodes, i);
                switch (formatCode) {
                case TEXT:
                    length += pgType.writeAsText(buffer, value);
                    break;
                case BINARY:
                    length += pgType.writeAsBinary(buffer, value);
                    break;

                default:
                    throw new AssertionError("Unrecognized formatCode: " + formatCode);
                }
            }
        }

        buffer.setInt(1, length);
        channel.write(buffer);
    }

    static void writeCString(ByteBuf buffer, byte[] valBytes) {
        buffer.writeBytes(valBytes);
        buffer.writeByte(0);
    }

    /**
     * ParameterDescription (B)
     *
     *     Byte1('t')
     *
     *         Identifies the message as a parameter description.
     *     Int32
     *
     *         Length of message contents in bytes, including self.
     *     Int16
     *
     *         The number of parameters used by the statement (can be zero).
     *
     *     Then, for each parameter, there is the following:
     *
     *     Int32
     *
     *         Specifies the object ID of the parameter data type.
     *
     * @param channel The channel to write the parameter description to.
     * @param parameters A {@link SortedSet} containing the parameters from index 1 upwards.
     */
    static void sendParameterDescription(Channel channel, DataType[] parameters) {
        final int messageByteSize = 4 + 2 + parameters.length * 4;
        ByteBuf buffer = channel.alloc().buffer(messageByteSize);
        buffer.writeByte('t');
        buffer.writeInt(messageByteSize);
        if (parameters.length > Short.MAX_VALUE) {
            throw new IllegalArgumentException("Too many parameters. Max supported: " + Short.MAX_VALUE);
        }
        buffer.writeShort(parameters.length);
        for (DataType dataType : parameters) {
            int pgTypeId = PGTypes.get(dataType).oid();
            buffer.writeInt(pgTypeId);
        }
        channel.write(buffer);
    }

    /**
     * RowDescription (B)
     * <p>
     * | 'T' | int32 len | int16 numCols
     * <p>
     * For each field:
     * <p>
     * | string name | int32 table_oid | int16 attr_num | int32 oid | int16 typlen | int32 type_modifier | int16 format_code
     * <p>
     * See https://www.postgresql.org/docs/current/static/protocol-message-formats.html
     */
    static void sendRowDescription(Channel channel, Collection<Field> columns,
            @Nullable FormatCodes.FormatCode[] formatCodes) {
        int length = 4 + 2;
        int columnSize = 4 + 2 + 4 + 2 + 4 + 2;
        ByteBuf buffer = channel.alloc().buffer(length + (columns.size() * (10 + columnSize))); // use 10 as an estimate for columnName length

        buffer.writeByte('T');
        buffer.writeInt(0); // will be set at the end
        buffer.writeShort(columns.size());

        int idx = 0;
        for (Field column : columns) {
            byte[] nameBytes = column.path().outputName().getBytes(StandardCharsets.UTF_8);
            length += nameBytes.length + 1;
            length += columnSize;

            writeCString(buffer, nameBytes);
            buffer.writeInt(0); // table_oid
            buffer.writeShort(0); // attr_num

            PGType pgType = PGTypes.get(column.valueType());
            buffer.writeInt(pgType.oid());
            buffer.writeShort(pgType.typeLen());
            buffer.writeInt(pgType.typeMod());
            buffer.writeShort(FormatCodes.getFormatCode(formatCodes, idx).ordinal());

            idx++;
        }

        buffer.setInt(1, length);
        ChannelFuture channelFuture = channel.writeAndFlush(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener((ChannelFutureListener) future -> LOGGER.trace("sentRowDescription"));
        }
    }

    /**
     * ParseComplete
     * | '1' | int32 len |
     */
    static void sendParseComplete(Channel channel) {
        sendShortMsg(channel, '1', "sentParseComplete");
    }

    /**
     * BindComplete
     * | '2' | int32 len |
     */
    static void sendBindComplete(Channel channel) {
        sendShortMsg(channel, '2', "sentBindComplete");
    }

    /**
     * EmptyQueryResponse
     * | 'I' | int32 len |
     */
    static void sendEmptyQueryResponse(Channel channel) {
        sendShortMsg(channel, 'I', "sentEmptyQueryResponse");
    }

    /**
     * NoData
     * | 'n' | int32 len |
     */
    static void sendNoData(Channel channel) {
        sendShortMsg(channel, 'n', "sentNoData");
    }

    /**
     * Send a message that just contains the msgType and the msg length
     */
    private static void sendShortMsg(Channel channel, char msgType, final String traceLogMsg) {
        ByteBuf buffer = channel.alloc().buffer(5);
        buffer.writeByte(msgType);
        buffer.writeInt(4);

        ChannelFuture channelFuture = channel.write(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener((ChannelFutureListener) future -> LOGGER.trace(traceLogMsg));
        }
    }

    static void sendPortalSuspended(Channel channel) {
        sendShortMsg(channel, 's', "sentPortalSuspended");
    }

    /**
     * CloseComplete
     * | '3' | int32 len |
     */
    static void sendCloseComplete(Channel channel) {
        sendShortMsg(channel, '3', "sentCloseComplete");
    }

    /**
     * AuthenticationCleartextPassword (B)
     *
     * Byte1('R')
     * Identifies the message as an authentication request.
     *
     * Int32(8)
     * Length of message contents in bytes, including self.
     *
     * Int32(3)
     * Specifies that a clear-text password is required.
     *
     * @param channel The channel to write to.
     */
    static void sendAuthenticationCleartextPassword(Channel channel) {
        ByteBuf buffer = channel.alloc().buffer(9);
        buffer.writeByte('R');
        buffer.writeInt(8);
        buffer.writeInt(3);
        ChannelFuture channelFuture = channel.writeAndFlush(buffer);
        if (LOGGER.isTraceEnabled()) {
            channelFuture.addListener(
                    (ChannelFutureListener) future -> LOGGER.trace("sentAuthenticationCleartextPassword"));
        }
    }
}