org.eclipse.milo.opcua.stack.server.transport.uasc.UascServerSymmetricHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.eclipse.milo.opcua.stack.server.transport.uasc.UascServerSymmetricHandler.java

Source

/*
 * Copyright (c) 2019 the Eclipse Milo Authors
 *
 * This program and the accompanying materials are made
 * available under the terms of the Eclipse Public License 2.0
 * which is available at https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 */

package org.eclipse.milo.opcua.stack.server.transport.uasc;

import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import org.eclipse.milo.opcua.stack.core.StatusCodes;
import org.eclipse.milo.opcua.stack.core.UaException;
import org.eclipse.milo.opcua.stack.core.UaSerializationException;
import org.eclipse.milo.opcua.stack.core.channel.ChannelSecurity;
import org.eclipse.milo.opcua.stack.core.channel.ChunkDecoder;
import org.eclipse.milo.opcua.stack.core.channel.ChunkEncoder;
import org.eclipse.milo.opcua.stack.core.channel.MessageAbortedException;
import org.eclipse.milo.opcua.stack.core.channel.SerializationQueue;
import org.eclipse.milo.opcua.stack.core.channel.ServerSecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.headers.HeaderDecoder;
import org.eclipse.milo.opcua.stack.core.channel.messages.MessageType;
import org.eclipse.milo.opcua.stack.core.serialization.UaRequestMessage;
import org.eclipse.milo.opcua.stack.core.serialization.UaResponseMessage;
import org.eclipse.milo.opcua.stack.core.types.builtin.DateTime;
import org.eclipse.milo.opcua.stack.core.types.builtin.StatusCode;
import org.eclipse.milo.opcua.stack.core.types.builtin.unsigned.UInteger;
import org.eclipse.milo.opcua.stack.core.types.structured.EndpointDescription;
import org.eclipse.milo.opcua.stack.core.types.structured.ResponseHeader;
import org.eclipse.milo.opcua.stack.core.types.structured.ServiceFault;
import org.eclipse.milo.opcua.stack.core.util.BufferUtil;
import org.eclipse.milo.opcua.stack.core.util.EndpointUtil;
import org.eclipse.milo.opcua.stack.server.UaStackServer;
import org.eclipse.milo.opcua.stack.server.services.ServiceRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.eclipse.milo.opcua.stack.core.types.builtin.unsigned.Unsigned.uint;

/**
 * Server-side handler for the symmetric phase of an OPC UA secure conversation.
 *
 * <p>Inbound bytes are framed into chunks using the message length from the header. {@code SecureMessage}
 * chunks are accumulated until a final ({@code 'F'}) chunk arrives, then handed to the
 * {@link SerializationQueue} for decoding. The decoded {@link UaRequestMessage} is dispatched to the
 * {@link UaStackServer} on the configured executor, and the resulting response (or a {@link ServiceFault})
 * is encoded and written back to the channel. Every other message type is passed along the pipeline
 * untouched.
 */
public class UascServerSymmetricHandler extends ByteToMessageDecoder implements HeaderDecoder {

    /**
     * Absolute offset of the token id within a chunk:
     * messageType (3) + chunkType (1) + messageSize (4) + secureChannelId (4).
     */
    private static final int TOKEN_ID_OFFSET = 3 + 1 + 4 + 4;

    /** Size in bytes of the token id field. */
    private static final int TOKEN_ID_LENGTH = 4;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Retained chunks of the message currently being received. The list is swapped for a fresh one before
     * the accumulated chunks are handed to the serialization queue, so the decoding task never shares a
     * list instance with this handler.
     */
    private List<ByteBuf> chunkBuffers;

    private final int maxChunkCount;
    private final int maxChunkSize;

    private final UaStackServer stackServer;
    private final SerializationQueue serializationQueue;
    private final ServerSecureChannel secureChannel;

    /**
     * @param stackServer        the server that decoded service requests are dispatched to.
     * @param serializationQueue the queue used to decode requests and encode responses; it also supplies
     *                           the chunk count and chunk size limits.
     * @param secureChannel      the secure channel this handler receives messages for.
     */
    UascServerSymmetricHandler(UaStackServer stackServer, SerializationQueue serializationQueue,
            ServerSecureChannel secureChannel) {

        this.stackServer = stackServer;
        this.serializationQueue = serializationQueue;
        this.secureChannel = secureChannel;

        maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
        maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();

        // Not presized from maxChunkCount: the limit is only enforced when it is > 0, and the list that
        // replaces this one after each message uses the default capacity as well.
        chunkBuffers = new ArrayList<>();
    }

    /**
     * Splits the cumulated bytes into whole messages. {@code SecureMessage} chunks are consumed by
     * {@link #onSecureMessage}; any other message type is retained and forwarded via {@code out}.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        while (buffer.readableBytes() >= HEADER_LENGTH) {
            int messageLength = getMessageLength(buffer, maxChunkSize);

            if (buffer.readableBytes() < messageLength) {
                // Wait for the rest of the message to arrive.
                break;
            }

            MessageType messageType = MessageType.fromMediumInt(buffer.getMediumLE(buffer.readerIndex()));

            switch (messageType) {
            case SecureMessage:
                onSecureMessage(ctx, buffer.readSlice(messageLength));
                break;

            default:
                out.add(buffer.readSlice(messageLength).retain());
            }
        }
    }

    /**
     * Releases chunks of a partially received message once this handler has been removed from the
     * pipeline; without this they would stay retained forever, e.g. when a chunk was rejected mid-message.
     */
    @Override
    protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
        releaseChunkBuffers();

        super.handlerRemoved0(ctx);
    }

    /** Releases and forgets every chunk accumulated for the message in progress. */
    private void releaseChunkBuffers() {
        // safeRelease so that one failing release cannot prevent the remaining buffers from being released.
        chunkBuffers.forEach(ReferenceCountUtil::safeRelease);
        chunkBuffers.clear();
    }

    /**
     * Handles a single {@code SecureMessage} chunk.
     *
     * <p>An abort ({@code 'A'}) chunk discards the chunks accumulated so far. Any other chunk is checked
     * against the secure channel id and the chunk size/count limits and retained; a final ({@code 'F'})
     * chunk additionally submits the accumulated chunks for decoding and dispatch.
     *
     * @param ctx    the handler context of the channel the chunk arrived on.
     * @param buffer a slice containing exactly one chunk, starting at its message header.
     * @throws UaException if the secure channel id does not match or a chunk limit is exceeded.
     */
    private void onSecureMessage(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        buffer.skipBytes(3); // Skip messageType

        char chunkType = (char) buffer.readByte();

        if (chunkType == 'A') {
            releaseChunkBuffers();
            return;
        }

        buffer.skipBytes(4); // Skip messageSize

        long secureChannelId = buffer.readUnsignedIntLE();
        if (secureChannelId != secureChannel.getChannelId()) {
            throw new UaException(StatusCodes.Bad_SecureChannelIdInvalid,
                    "invalid secure channel id: " + secureChannelId);
        }

        // Rewind so the retained chunk starts at its header again.
        int chunkSize = buffer.readerIndex(0).readableBytes();
        if (chunkSize > maxChunkSize) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                    String.format("max chunk size exceeded (%s)", maxChunkSize));
        }

        chunkBuffers.add(buffer.retain());

        if (maxChunkCount > 0 && chunkBuffers.size() > maxChunkCount) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                    String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType != 'F') {
            return;
        }

        final List<ByteBuf> buffersToDecode = chunkBuffers;
        chunkBuffers = new ArrayList<>();

        serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
            try {
                validateChunkHeaders(buffersToDecode);
            } catch (UaException e) {
                logger.error("Error validating chunk headers: {}", e.getMessage(), e);
                buffersToDecode.forEach(ReferenceCountUtil::safeRelease);
                ctx.fireExceptionCaught(e);
                return;
            }

            chunkDecoder.decodeSymmetric(secureChannel, buffersToDecode, new ChunkDecoder.Callback() {
                @Override
                public void onDecodingError(UaException ex) {
                    logger.error("Error decoding symmetric message: {}", ex.getMessage(), ex);

                    ctx.close();
                }

                @Override
                public void onMessageAborted(MessageAbortedException ex) {
                    logger.warn("Received message abort chunk; error={}, reason={}", ex.getStatusCode(),
                            ex.getMessage());
                }

                @Override
                public void onMessageDecoded(ByteBuf message, long requestId) {
                    UaRequestMessage decoded;

                    try {
                        decoded = (UaRequestMessage) binaryDecoder.setBuffer(message).readMessage(null);
                    } catch (Throwable t) {
                        logger.error("Error decoding UaRequestMessage", t);

                        // No task will be scheduled for this message, so clean up here.
                        message.release();
                        buffersToDecode.clear();

                        // The request header could not be read; report the fault with request handle 0.
                        sendServiceFault(ctx, requestId, uint(0), t);
                        return;
                    }

                    final UaRequestMessage request = decoded;

                    stackServer.getConfig().getExecutor().execute(() -> {
                        try {
                            dispatchRequest(ctx, request, requestId);
                        } catch (Throwable t) {
                            logger.error("Error dispatching UaRequestMessage", t);

                            sendServiceFault(ctx, requestId, uint(0), t);
                        } finally {
                            message.release();
                            buffersToDecode.clear();
                        }
                    });
                }
            });
        });
    }

    /**
     * Wraps a decoded request in a {@link ServiceRequest}, arranges for its outcome to be sent back to
     * the client and passes it to the stack server.
     *
     * @param ctx       the handler context of the channel the request arrived on.
     * @param request   the decoded request.
     * @param requestId the request id from the chunk headers; echoed in the response.
     * @throws Exception if building or dispatching the service request fails; the caller reports it to
     *                   the client as a ServiceFault.
     */
    private void dispatchRequest(ChannelHandlerContext ctx, UaRequestMessage request, long requestId)
            throws Exception {

        String endpointUrl = ctx.channel().attr(UascServerHelloHandler.ENDPOINT_URL_KEY).get();

        EndpointDescription endpoint = ctx.channel().attr(UascServerAsymmetricHandler.ENDPOINT_KEY).get();

        String path = EndpointUtil.getPath(endpointUrl);

        InetSocketAddress remoteSocketAddress = (InetSocketAddress) ctx.channel().remoteAddress();

        ServiceRequest serviceRequest = new ServiceRequest(stackServer, request, endpoint,
                secureChannel.getChannelId(), remoteSocketAddress.getAddress(),
                secureChannel.getRemoteCertificateBytes());

        serviceRequest.getFuture().whenComplete((response, fault) -> {
            if (response != null) {
                sendServiceResponse(ctx, requestId, request, response);
            } else {
                UInteger requestHandle = request.getRequestHeader().getRequestHandle();

                sendServiceFault(ctx, requestId, requestHandle, fault);
            }
        });

        stackServer.onServiceRequest(path, serviceRequest);
    }

    /**
     * Encodes {@code response} on the serialization queue and writes the resulting chunks to the channel.
     * If serialization or chunk encoding fails, a {@link ServiceFault} is sent instead.
     *
     * @param ctx       the handler context to write to.
     * @param requestId the id of the request being answered.
     * @param request   the request being answered; supplies the request handle for a fault.
     * @param response  the response to send.
     */
    private void sendServiceResponse(ChannelHandlerContext ctx, long requestId, UaRequestMessage request,
            UaResponseMessage response) {

        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.pooledBuffer();

            try {
                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.writeMessage(null, response);

                checkMessageSize(messageBuffer);

                chunkEncoder.encodeSymmetric(secureChannel, requestId, messageBuffer, MessageType.SecureMessage,
                        new ChunkEncoder.Callback() {
                            @Override
                            public void onEncodingError(UaException ex) {
                                logger.error("Error encoding {}: {}", response, ex.getMessage(), ex);

                                UInteger requestHandle = request.getRequestHeader().getRequestHandle();

                                sendServiceFault(ctx, requestId, requestHandle, ex);
                            }

                            @Override
                            public void onMessageEncoded(List<ByteBuf> messageChunks, long requestId) {
                                writeMessageChunks(ctx, messageChunks);
                            }
                        });
            } catch (UaSerializationException ex) {
                logger.error("Error encoding response: {}", ex.getStatusCode(), ex);

                UInteger requestHandle = request.getRequestHeader().getRequestHandle();

                sendServiceFault(ctx, requestId, requestHandle, ex);
            } finally {
                messageBuffer.release();
            }
        });
    }

    /**
     * Sends a {@link ServiceFault} for the given request. The status code is taken from a
     * {@link UaException} extracted from {@code fault}, falling back to {@link StatusCode#BAD}.
     * Failures while encoding the fault itself are only logged.
     *
     * @param ctx           the handler context to write to.
     * @param requestId     the id of the request being answered.
     * @param requestHandle the request handle to echo in the response header.
     * @param fault         the cause of the fault.
     */
    private void sendServiceFault(ChannelHandlerContext ctx, long requestId, UInteger requestHandle,
            Throwable fault) {

        StatusCode statusCode = UaException.extract(fault).map(UaException::getStatusCode).orElse(StatusCode.BAD);

        ServiceFault serviceFault = new ServiceFault(
                new ResponseHeader(DateTime.now(), requestHandle, statusCode, null, null, null));

        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.pooledBuffer();

            try {
                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.writeMessage(null, serviceFault);

                checkMessageSize(messageBuffer);

                chunkEncoder.encodeSymmetric(secureChannel, requestId, messageBuffer, MessageType.SecureMessage,
                        new ChunkEncoder.Callback() {
                            @Override
                            public void onEncodingError(UaException ex) {
                                logger.error("Error encoding {}: {}", serviceFault, ex.getMessage(), ex);
                            }

                            @Override
                            public void onMessageEncoded(List<ByteBuf> messageChunks, long requestId) {
                                writeMessageChunks(ctx, messageChunks);
                            }
                        });
            } catch (UaSerializationException ex) {
                logger.error("Error encoding ServiceFault: {}", ex.getStatusCode(), ex);
            } finally {
                messageBuffer.release();
            }
        });
    }

    /**
     * Combines the encoded chunks of one message into a single composite buffer and writes it to the
     * channel with a single flush.
     *
     * @param ctx           the handler context to write to.
     * @param messageChunks the encoded chunks, in the order they must be sent.
     */
    private static void writeMessageChunks(ChannelHandlerContext ctx, List<ByteBuf> messageChunks) {
        CompositeByteBuf chunkComposite = BufferUtil.compositeBuffer();

        for (ByteBuf chunk : messageChunks) {
            chunkComposite.addComponent(chunk);
            // addComponent(ByteBuf) does not advance the writer index; do it so the bytes become readable.
            chunkComposite.writerIndex(chunkComposite.writerIndex() + chunk.readableBytes());
        }

        ctx.writeAndFlush(chunkComposite, ctx.voidPromise());
    }

    /**
     * Verifies that an encoded message does not exceed the remote max message size. A remote limit that
     * is not greater than zero is not enforced.
     *
     * @param messageBuffer the buffer holding the encoded message.
     * @throws UaSerializationException with {@code Bad_ResponseTooLarge} if the limit is exceeded.
     */
    private void checkMessageSize(ByteBuf messageBuffer) throws UaSerializationException {
        int messageSize = messageBuffer.readableBytes();
        int remoteMaxMessageSize = serializationQueue.getParameters().getRemoteMaxMessageSize();

        if (remoteMaxMessageSize > 0 && messageSize > remoteMaxMessageSize) {
            throw new UaSerializationException(StatusCodes.Bad_ResponseTooLarge,
                    "response exceeds remote max message size: " + messageSize + " > " + remoteMaxMessageSize);
        }
    }

    /**
     * Verifies that every chunk carries the id of the current or the previous security token of the
     * secure channel.
     *
     * @param chunkBuffers the chunks of one message, each starting at its message header.
     * @throws UaException with {@code Bad_DecodingError} if a chunk is too short to contain a token id,
     *                     or with {@code Bad_SecureChannelTokenUnknown} if a token id matches neither token.
     */
    private void validateChunkHeaders(List<ByteBuf> chunkBuffers) throws UaException {
        ChannelSecurity channelSecurity = secureChannel.getChannelSecurity();
        long currentTokenId = channelSecurity.getCurrentToken().getTokenId().longValue();
        // -1 can never equal an unsigned 32-bit token id, so it stands for "no previous token".
        long previousTokenId = channelSecurity.getPreviousToken().map(t -> t.getTokenId().longValue()).orElse(-1L);

        for (ByteBuf chunkBuffer : chunkBuffers) {
            // Fail with a UaException rather than an IndexOutOfBoundsException on a truncated chunk, so
            // the caller's error handling releases the chunks.
            if (chunkBuffer.writerIndex() < TOKEN_ID_OFFSET + TOKEN_ID_LENGTH) {
                throw new UaException(StatusCodes.Bad_DecodingError,
                        "chunk too short to contain a secure channel token id: " + chunkBuffer.writerIndex());
            }

            long tokenId = chunkBuffer.getUnsignedIntLE(TOKEN_ID_OFFSET);

            if (tokenId != currentTokenId && tokenId != previousTokenId) {
                String message = String.format(
                        "received unknown secure channel token: "
                                + "tokenId=%s currentTokenId=%s previousTokenId=%s",
                        tokenId, currentTokenId, previousTokenId);

                throw new UaException(StatusCodes.Bad_SecureChannelTokenUnknown, message);
            }
        }
    }

}