com.digitalpetri.opcua.stack.client.handlers.UaTcpClientMessageHandler.java Source code

Java tutorial

Introduction

Here is the source code for com.digitalpetri.opcua.stack.client.handlers.UaTcpClientMessageHandler.java

Source

/*
 * Copyright 2016 Kevin Herron
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.digitalpetri.opcua.stack.client.handlers;

import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import com.digitalpetri.opcua.stack.client.UaTcpStackClient;
import com.digitalpetri.opcua.stack.core.StatusCodes;
import com.digitalpetri.opcua.stack.core.UaException;
import com.digitalpetri.opcua.stack.core.UaRuntimeException;
import com.digitalpetri.opcua.stack.core.UaServiceFaultException;
import com.digitalpetri.opcua.stack.core.channel.ChannelSecurity;
import com.digitalpetri.opcua.stack.core.channel.ClientSecureChannel;
import com.digitalpetri.opcua.stack.core.channel.MessageAbortedException;
import com.digitalpetri.opcua.stack.core.channel.SerializationQueue;
import com.digitalpetri.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import com.digitalpetri.opcua.stack.core.channel.headers.HeaderDecoder;
import com.digitalpetri.opcua.stack.core.channel.messages.ErrorMessage;
import com.digitalpetri.opcua.stack.core.channel.messages.MessageType;
import com.digitalpetri.opcua.stack.core.channel.messages.TcpMessageDecoder;
import com.digitalpetri.opcua.stack.core.security.SecurityAlgorithm;
import com.digitalpetri.opcua.stack.core.serialization.UaRequestMessage;
import com.digitalpetri.opcua.stack.core.serialization.UaResponseMessage;
import com.digitalpetri.opcua.stack.core.types.builtin.ByteString;
import com.digitalpetri.opcua.stack.core.types.builtin.DateTime;
import com.digitalpetri.opcua.stack.core.types.builtin.StatusCode;
import com.digitalpetri.opcua.stack.core.types.builtin.unsigned.UInteger;
import com.digitalpetri.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import com.digitalpetri.opcua.stack.core.types.structured.ChannelSecurityToken;
import com.digitalpetri.opcua.stack.core.types.structured.CloseSecureChannelRequest;
import com.digitalpetri.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import com.digitalpetri.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import com.digitalpetri.opcua.stack.core.types.structured.RequestHeader;
import com.digitalpetri.opcua.stack.core.types.structured.ServiceFault;
import com.digitalpetri.opcua.stack.core.util.BufferUtil;
import com.digitalpetri.opcua.stack.core.util.LongSequence;
import com.digitalpetri.opcua.stack.core.util.NonceUtil;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.util.AttributeKey;
import io.netty.util.Timeout;
import org.jooq.lambda.tuple.Tuple2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.digitalpetri.opcua.stack.core.types.builtin.unsigned.Unsigned.uint;

public class UaTcpClientMessageHandler extends ByteToMessageCodec<UaRequestFuture> implements HeaderDecoder {

    /** Attribute key under which the secure channel keeps its in-flight requests, keyed by request id. */
    public static final AttributeKey<Map<Long, UaRequestFuture>> KEY_PENDING_REQUEST_FUTURES = AttributeKey
            .valueOf("pending-request-futures");

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Retained chunks of the message currently being received. The list is handed off to the decoder and
     * replaced with a fresh one once a final or abort chunk has been accumulated.
     */
    private List<ByteBuf> chunkBuffers = new LinkedList<>();

    /** First asymmetric security header received; later OpenSecureChannel chunks must carry an equal header. */
    private final AtomicReference<AsymmetricSecurityHeader> headerRef = new AtomicReference<>();

    /** Scheduled token renewal, if any; cancelled when the channel goes inactive. */
    private ScheduledFuture<?> renewFuture;
    /** Timer guarding the wait for the OpenSecureChannel response. */
    private Timeout secureChannelTimeout;

    /** In-flight requests, shared through the {@link #KEY_PENDING_REQUEST_FUTURES} attribute of the secure channel. */
    private final Map<Long, UaRequestFuture> pending;
    /** Source of request ids, shared through an attribute of the secure channel. */
    private final LongSequence requestIdSequence;

    private final UaTcpStackClient client;
    private final ClientSecureChannel secureChannel;
    private final SerializationQueue serializationQueue;
    /** Completed with the secure channel once it has been opened, or exceptionally if opening fails. */
    private final CompletableFuture<ClientSecureChannel> handshakeFuture;

    /**
     * Creates the handler and wires it to the per-secure-channel shared state.
     *
     * <p>The pending-request map and the request id sequence are stored as attributes of
     * {@code secureChannel}; they are created here only if not already present. Once
     * {@code handshakeFuture} completes successfully, any requests found under
     * {@link UaTcpClientAcknowledgeHandler#KEY_AWAITING_HANDSHAKE} on the channel are written out
     * from the channel's event loop.
     *
     * @param client             the owning stack client.
     * @param secureChannel      the secure channel this handler opens and serves.
     * @param serializationQueue queue on which messages are encoded and decoded.
     * @param handshakeFuture    future to complete when the secure channel is open (or has failed).
     */
    public UaTcpClientMessageHandler(UaTcpStackClient client, ClientSecureChannel secureChannel,
            SerializationQueue serializationQueue, CompletableFuture<ClientSecureChannel> handshakeFuture) {

        this.client = client;
        this.secureChannel = secureChannel;
        this.serializationQueue = serializationQueue;
        this.handshakeFuture = handshakeFuture;

        // Install the shared state only when nobody has done so yet, then read back whatever is there.
        secureChannel.attr(KEY_PENDING_REQUEST_FUTURES).setIfAbsent(Maps.newConcurrentMap());
        this.pending = secureChannel.attr(KEY_PENDING_REQUEST_FUTURES).get();

        secureChannel.attr(ClientSecureChannel.KEY_REQUEST_ID_SEQUENCE)
                .setIfAbsent(new LongSequence(1L, UInteger.MAX_VALUE));
        this.requestIdSequence = secureChannel.attr(ClientSecureChannel.KEY_REQUEST_ID_SEQUENCE).get();

        handshakeFuture.thenAccept(openedChannel -> {
            Channel nettyChannel = openedChannel.getChannel();

            nettyChannel.eventLoop().execute(() -> {
                List<UaRequestFuture> queuedRequests = nettyChannel
                        .attr(UaTcpClientAcknowledgeHandler.KEY_AWAITING_HANDSHAKE).get();

                if (queuedRequests == null) {
                    return;
                }

                nettyChannel.attr(UaTcpClientAcknowledgeHandler.KEY_AWAITING_HANDSHAKE).remove();

                logger.debug("{} message(s) queued before handshake completed; sending now.",
                        queuedRequests.size());

                for (UaRequestFuture queuedRequest : queuedRequests) {
                    nettyChannel.writeAndFlush(queuedRequest);
                }
            });
        });
    }

    /**
     * Starts the secure channel handshake as soon as this handler joins the pipeline.
     *
     * <p>A channel id of 0 results in an {@code Issue} request, any other id in a {@code Renew}.
     * A 5 second timer is armed first; if it fires without having been cancelled, the handshake
     * future is failed with {@code Bad_Timeout} and the channel is closed.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        final SecurityTokenRequestType requestType;
        if (secureChannel.getChannelId() == 0) {
            requestType = SecurityTokenRequestType.Issue;
        } else {
            requestType = SecurityTokenRequestType.Renew;
        }

        secureChannelTimeout = client.getConfig().getWheelTimer().newTimeout(expired -> {
            if (expired.isCancelled()) {
                return;
            }

            handshakeFuture.completeExceptionally(
                    new UaException(StatusCodes.Bad_Timeout, "timed out waiting for secure channel"));
            ctx.close();
        }, 5, TimeUnit.SECONDS);

        logger.debug("OpenSecureChannel timeout scheduled for +5s");

        sendOpenSecureChannelRequest(ctx, requestType);
    }

    /**
     * Cleans up when the connection goes away.
     *
     * <p>Cancels the scheduled token renewal and the OpenSecureChannel timeout, fails the handshake
     * future with {@code Bad_ConnectionClosed} (a no-op if it has already completed), and releases any
     * chunks that were retained for a message that will now never be finished.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (renewFuture != null) {
            renewFuture.cancel(false);
        }

        // The connection is gone; there is no point in letting the handshake timer fire later.
        if (secureChannelTimeout != null) {
            secureChannelTimeout.cancel();
            secureChannelTimeout = null;
        }

        handshakeFuture
                .completeExceptionally(new UaException(StatusCodes.Bad_ConnectionClosed, "connection closed"));

        try {
            super.channelInactive(ctx);
        } finally {
            // Chunks still sitting here were retained by accumulateChunk() and never handed to the
            // decoder; release them so their memory is not leaked. Done after super so that any
            // last decode pass triggered by the superclass has already run.
            chunkBuffers.forEach(ByteBuf::release);
            chunkBuffers.clear();
        }
    }

    /**
     * Logs any exception that reaches this handler and closes the channel.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        final String reason = cause.getMessage();

        logger.error("Exception caught: {}", reason, cause);

        ctx.close();
    }

    /**
     * Handles user events fired through the pipeline.
     *
     * <p>A {@link CloseSecureChannelRequest} event is sent to the server as a CloseSecureChannel message.
     * Every other event is forwarded unchanged so that handlers further along the pipeline can still
     * react to it.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof CloseSecureChannelRequest) {
            sendCloseSecureChannelRequest(ctx, (CloseSecureChannelRequest) evt);
        } else {
            // Not ours: pass it on instead of silently swallowing it.
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Builds an OpenSecureChannelRequest, encodes it and writes the resulting chunks to the channel.
     *
     * <p>A fresh client nonce is generated when symmetric signing is enabled; otherwise the null
     * ByteString is used. The nonce is stored on the secure channel before the request is built.
     * If encoding fails the channel is closed.
     *
     * @param ctx         the handler context used for writing.
     * @param requestType whether a new token is being issued or an existing one renewed.
     */
    private void sendOpenSecureChannelRequest(ChannelHandlerContext ctx, SecurityTokenRequestType requestType) {
        SecurityAlgorithm encryptionAlgorithm = secureChannel.getSecurityPolicy().getSymmetricEncryptionAlgorithm();
        int nonceLength = NonceUtil.getNonceLength(encryptionAlgorithm);

        final ByteString localNonce;
        if (secureChannel.isSymmetricSigningEnabled()) {
            localNonce = NonceUtil.generateNonce(nonceLength);
        } else {
            localNonce = ByteString.NULL_VALUE;
        }
        secureChannel.setLocalNonce(localNonce);

        RequestHeader requestHeader = new RequestHeader(null, DateTime.now(), uint(0), uint(0), null, uint(0), null);

        OpenSecureChannelRequest request = new OpenSecureChannelRequest(
                requestHeader,
                uint(PROTOCOL_VERSION),
                requestType,
                secureChannel.getMessageSecurityMode(),
                secureChannel.getLocalNonce(),
                client.getChannelLifetime());

        encodeMessage(request, MessageType.OpenSecureChannel).whenComplete((encoded, failure) -> {
            if (failure != null) {
                ctx.close();
                return;
            }

            List<ByteBuf> encodedChunks = encoded.v2();

            // Writes are handed to the context's executor.
            ctx.executor().execute(() -> {
                for (ByteBuf chunk : encodedChunks) {
                    ctx.write(chunk, ctx.voidPromise());
                }
                ctx.flush();
            });

            // Token ids are gathered for the debug log only; -1 stands for "no such token".
            ChannelSecurity channelSecurity = secureChannel.getChannelSecurity();

            long currentTokenId = -1L;
            long previousTokenId = -1L;

            if (channelSecurity != null) {
                currentTokenId = channelSecurity.getCurrentToken().getTokenId().longValue();
                previousTokenId = channelSecurity.getPreviousToken()
                        .map(token -> token.getTokenId().longValue())
                        .orElse(-1L);
            }

            logger.debug("Sent OpenSecureChannelRequest ({}, id={}, currentToken={}, previousToken={}).",
                    request.getRequestType(), secureChannel.getChannelId(), currentTokenId, previousTokenId);
        });
    }

    /**
     * Encodes a CloseSecureChannelRequest, writes it and then closes the channel.
     *
     * <p>The secure channel id is reset to 0 as soon as encoding has succeeded. If encoding fails the
     * channel is closed without sending anything.
     *
     * @param ctx     the handler context used for writing.
     * @param request the close request to send.
     */
    private void sendCloseSecureChannelRequest(ChannelHandlerContext ctx, CloseSecureChannelRequest request) {
        encodeMessage(request, MessageType.CloseSecureChannel).whenComplete((encoded, failure) -> {
            if (failure != null) {
                ctx.close();
                return;
            }

            List<ByteBuf> encodedChunks = encoded.v2();

            ctx.executor().execute(() -> {
                for (ByteBuf chunk : encodedChunks) {
                    ctx.write(chunk, ctx.voidPromise());
                }
                ctx.flush();
                ctx.close();
            });

            secureChannel.setChannelId(0);
        });
    }

    /**
     * Encodes an outgoing service request into chunks and writes them to the channel.
     *
     * <p>The request is registered in the pending map under its request id before the chunks are
     * written, so that the response can be matched to it later. If encoding fails, the request's
     * future is completed exceptionally with the cause and the channel is closed.
     *
     * @param ctx     the handler context used for writing.
     * @param request the request, paired with the future that will receive its response.
     * @param buffer  unused; chunks are written directly through {@code ctx}.
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, UaRequestFuture request, ByteBuf buffer) throws Exception {
        encodeMessage(request.getRequest(), MessageType.SecureMessage).whenComplete((t2, ex) -> {
            if (ex != null) {
                // Tell the caller why the request failed instead of leaving it to time out.
                request.getFuture().completeExceptionally(ex);
                ctx.close();
                return;
            }

            long requestId = t2.v1();
            List<ByteBuf> chunks = t2.v2();

            pending.put(requestId, request);

            // No matter how we complete, make sure the entry in pending is removed.
            // This covers the case where the request fails due to a timeout in the
            // upper layers as well as normal completion.
            request.getFuture().whenComplete((r, x) -> pending.remove(requestId));

            ctx.executor().execute(() -> {
                chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
                ctx.flush();
            });
        });
    }

    /**
     * Serializes {@code request} and splits it into chunks on the serialization queue.
     *
     * <p>OpenSecureChannel messages are chunked with the asymmetric encoder; every other message type
     * uses the symmetric encoder. The intermediate message buffer is always released.
     *
     * @param request     the message to encode.
     * @param messageType the type of message being sent; selects the chunk encoding.
     * @return a future completed with the request id reported by the chunk encoder together with the
     *         encoded chunks, or completed exceptionally if encoding fails for any reason.
     */
    private CompletableFuture<Tuple2<Long, List<ByteBuf>>> encodeMessage(UaRequestMessage request,
            MessageType messageType) {

        CompletableFuture<Tuple2<Long, List<ByteBuf>>> future = new CompletableFuture<>();

        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = null;

            try {
                messageBuffer = BufferUtil.buffer();
                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.encodeMessage(null, request);

                List<ByteBuf> chunks;

                if (messageType == MessageType.OpenSecureChannel) {
                    chunks = chunkEncoder.encodeAsymmetric(secureChannel, messageType, messageBuffer,
                            requestIdSequence.getAndIncrement());
                } else {
                    chunks = chunkEncoder.encodeSymmetric(secureChannel, messageType, messageBuffer,
                            requestIdSequence.getAndIncrement());
                }

                future.complete(new Tuple2<>(chunkEncoder.getLastRequestId(), chunks));
            } catch (Throwable t) {
                // Catch everything, not only UaException: an unchecked failure would otherwise leave
                // the future incomplete and every caller waiting on it stuck.
                logger.error("Error encoding {}: {}", request, t.getMessage(), t);

                future.completeExceptionally(t);
            } finally {
                if (messageBuffer != null) {
                    messageBuffer.release();
                }
            }
        });

        return future;
    }

    /**
     * Dispatches one complete message from the inbound buffer, if one is available.
     *
     * <p>Nothing is consumed until at least a full header and then the full message, as given by the
     * length in the header, are readable. Nothing is ever added to {@code out}; messages are handled
     * directly by this handler.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        if (buffer.readableBytes() < HEADER_LENGTH) {
            return; // header not complete yet
        }

        ByteBuf littleEndian = buffer.order(ByteOrder.LITTLE_ENDIAN);

        if (littleEndian.readableBytes() < getMessageLength(littleEndian)) {
            return; // body not complete yet
        }

        decodeMessage(ctx, littleEndian);
    }

    /**
     * Slices one whole message off {@code buffer} and routes it by message type.
     *
     * @param ctx    the handler context.
     * @param buffer inbound buffer positioned at the start of a complete message.
     * @throws UaException with {@code Bad_TcpMessageTypeInvalid} for any type other than
     *                     OpenSecureChannel, SecureMessage or Error; nothing is consumed in that case.
     */
    private void decodeMessage(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        final int length = getMessageLength(buffer);
        final int typeCode = buffer.getMedium(buffer.readerIndex());
        final MessageType type = MessageType.fromMediumInt(typeCode);

        switch (type) {
        case OpenSecureChannel: {
            ByteBuf message = buffer.readSlice(length);
            onOpenSecureChannel(ctx, message);
            return;
        }

        case SecureMessage: {
            ByteBuf message = buffer.readSlice(length);
            onSecureMessage(ctx, message);
            return;
        }

        case Error: {
            ByteBuf message = buffer.readSlice(length);
            onError(ctx, message);
            return;
        }

        default:
            throw new UaException(StatusCodes.Bad_TcpMessageTypeInvalid, "unexpected MessageType: " + type);
        }
    }

    /**
     * Adds a chunk to the message currently being assembled, enforcing the local size and count limits.
     *
     * <p>The buffer's reader index is reset to 0 and the buffer is retained before being stored.
     *
     * @param buffer a slice holding exactly one chunk.
     * @return {@code true} if this chunk ends the message, i.e. its chunk type byte (offset 3) is
     *         {@code 'F'} or {@code 'A'}; {@code false} if more chunks are expected.
     * @throws UaException with {@code Bad_TcpMessageTooLarge} if the chunk is larger than the local
     *                     receive buffer size or the number of chunks exceeds the local maximum.
     */
    private boolean accumulateChunk(ByteBuf buffer) throws UaException {
        final int chunkCountLimit = serializationQueue.getParameters().getLocalMaxChunkCount();
        final int chunkSizeLimit = serializationQueue.getParameters().getLocalReceiveBufferSize();

        buffer.readerIndex(0);
        final int chunkSize = buffer.readableBytes();

        if (chunkSize > chunkSizeLimit) {
            String message = String.format("max chunk size exceeded (%s)", chunkSizeLimit);
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge, message);
        }

        chunkBuffers.add(buffer.retain());

        if (chunkBuffers.size() > chunkCountLimit) {
            String message = String.format("max chunk count exceeded (%s)", chunkCountLimit);
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge, message);
        }

        final char chunkType = (char) buffer.getByte(3);
        final boolean aborted = chunkType == 'A';
        final boolean finished = chunkType == 'F';

        return aborted || finished;
    }

    /**
     * Handles one chunk of an OpenSecureChannel response.
     *
     * <p>The handshake timeout is cancelled first; if it has already expired the handshake is failed
     * with {@code Bad_Timeout} and the channel closed. The asymmetric security header of every chunk
     * must equal the first one seen. Once the final chunk has arrived, the message is decoded on the
     * serialization queue: a good response installs the new security token and completes the handshake
     * future; any failure (abort chunk, service fault, decoding error) completes the handshake future
     * exceptionally with that cause and closes the channel.
     *
     * @param ctx    the handler context.
     * @param buffer a slice holding exactly one OpenSecureChannel chunk.
     * @throws UaException if the security header does not match the earlier one, or a chunk limit is
     *                     exceeded.
     */
    private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        if (secureChannelTimeout != null) {
            if (secureChannelTimeout.cancel()) {
                logger.debug("OpenSecureChannel timeout canceled");

                secureChannelTimeout = null;
            } else {
                logger.warn("timed out waiting for secure channel");

                handshakeFuture.completeExceptionally(
                        new UaException(StatusCodes.Bad_Timeout, "timed out waiting for secure channel"));
                ctx.close();
                return;
            }
        }

        buffer.skipBytes(3 + 1 + 4 + 4); // skip messageType, chunkType, messageSize, secureChannelId

        AsymmetricSecurityHeader securityHeader = AsymmetricSecurityHeader.decode(buffer);
        if (!headerRef.compareAndSet(null, securityHeader)) {
            if (!securityHeader.equals(headerRef.get())) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                        "subsequent AsymmetricSecurityHeader did not match");
            }
        }

        if (accumulateChunk(buffer)) {
            // Hand the completed set of chunks to the decoder and start a fresh list for the next message.
            final List<ByteBuf> buffersToDecode = ImmutableList.copyOf(chunkBuffers);
            chunkBuffers = new LinkedList<>();

            serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
                ByteBuf decodedBuffer = null;

                try {
                    decodedBuffer = chunkDecoder.decodeAsymmetric(secureChannel, buffersToDecode);

                    UaResponseMessage responseMessage = binaryDecoder.setBuffer(decodedBuffer).decodeMessage(null);

                    StatusCode serviceResult = responseMessage.getResponseHeader().getServiceResult();

                    if (serviceResult.isGood()) {
                        OpenSecureChannelResponse response = (OpenSecureChannelResponse) responseMessage;

                        secureChannel.setChannelId(response.getSecurityToken().getChannelId().longValue());
                        logger.debug("Received OpenSecureChannelResponse.");

                        installSecurityToken(ctx, response);

                        handshakeFuture.complete(secureChannel);
                    } else {
                        ServiceFault serviceFault = (responseMessage instanceof ServiceFault)
                                ? (ServiceFault) responseMessage
                                : new ServiceFault(responseMessage.getResponseHeader());

                        throw new UaServiceFaultException(serviceFault);
                    }
                } catch (MessageAbortedException e) {
                    logger.error("Received message abort chunk; error={}, reason={}", e.getStatusCode(),
                            e.getMessage());

                    // Report the real cause; otherwise callers only ever see "connection closed".
                    handshakeFuture.completeExceptionally(e);
                    ctx.close();
                } catch (Throwable t) {
                    logger.error("Error decoding OpenSecureChannelResponse: {}", t.getMessage(), t);

                    // Report the real cause; otherwise callers only ever see "connection closed".
                    handshakeFuture.completeExceptionally(t);
                    ctx.close();
                } finally {
                    if (decodedBuffer != null) {
                        decodedBuffer.release();
                    }
                }
            });
        }
    }

    /**
     * Applies the security token from an OpenSecureChannelResponse to the secure channel.
     *
     * <p>New symmetric keys are derived from the local and server nonces when symmetric signing is
     * enabled. The keys and token that were current before become the "previous" pair of the new
     * {@link ChannelSecurity}. If the revised lifetime is positive, a renewal is scheduled at 75% of
     * it (in milliseconds). The acknowledge handler is removed from the pipeline if it is still there.
     *
     * @param ctx      the handler context.
     * @param response the successful response carrying the new token.
     * @throws UaRuntimeException with {@code Bad_ProtocolVersionUnsupported} if the server's protocol
     *                            version is lower than {@code PROTOCOL_VERSION}.
     */
    private void installSecurityToken(ChannelHandlerContext ctx, OpenSecureChannelResponse response) {
        if (response.getServerProtocolVersion().longValue() < PROTOCOL_VERSION) {
            throw new UaRuntimeException(StatusCodes.Bad_ProtocolVersionUnsupported,
                    "server protocol version unsupported: " + response.getServerProtocolVersion());
        }

        ChannelSecurityToken issuedToken = response.getSecurityToken();

        ChannelSecurity.SecuritySecrets issuedKeys = null;
        if (secureChannel.isSymmetricSigningEnabled()) {
            secureChannel.setRemoteNonce(response.getServerNonce());

            issuedKeys = ChannelSecurity.generateKeyPair(secureChannel, secureChannel.getLocalNonce(),
                    secureChannel.getRemoteNonce());
        }

        // Whatever was current until now is kept as the previous keys/token.
        ChannelSecurity priorSecurity = secureChannel.getChannelSecurity();
        ChannelSecurity.SecuritySecrets priorKeys = null;
        ChannelSecurityToken priorToken = null;
        if (priorSecurity != null) {
            priorKeys = priorSecurity.getCurrentKeys();
            priorToken = priorSecurity.getCurrentToken();
        }

        secureChannel.setChannelSecurity(new ChannelSecurity(issuedKeys, issuedToken, priorKeys, priorToken));

        DateTime createdAt = response.getSecurityToken().getCreatedAt();
        long revisedLifetime = response.getSecurityToken().getRevisedLifetime().longValue();

        if (revisedLifetime <= 0) {
            logger.warn("Server revised secure channel lifetime to 0; renewal will not occur.");
        } else {
            long renewDelayMillis = (long) (revisedLifetime * 0.75);

            renewFuture = ctx.executor().schedule(
                    () -> sendOpenSecureChannelRequest(ctx, SecurityTokenRequestType.Renew),
                    renewDelayMillis,
                    TimeUnit.MILLISECONDS);
        }

        ctx.executor().execute(() -> {
            // SecureChannel is ready; the acknowledge handler is no longer needed.
            if (ctx.pipeline().get(UaTcpClientAcknowledgeHandler.class) == null) {
                return;
            }
            ctx.pipeline().remove(UaTcpClientAcknowledgeHandler.class);
        });

        ChannelSecurity installed = secureChannel.getChannelSecurity();

        long currentTokenId = installed.getCurrentToken().getTokenId().longValue();
        long previousTokenId = installed.getPreviousToken()
                .map(token -> token.getTokenId().longValue())
                .orElse(-1L);

        logger.debug("SecureChannel id={}, currentTokenId={}, previousTokenId={}, lifetime={}ms, createdAt={}",
                secureChannel.getChannelId(), currentTokenId, previousTokenId, revisedLifetime, createdAt);
    }

    /**
     * Handles one chunk of a symmetric (service response) message.
     *
     * <p>Once the final chunk has been accumulated, the message is decoded on the serialization queue
     * and the matching pending request, looked up by the decoder's last request id, is completed on
     * the client's executor service. An abort chunk completes the request exceptionally instead. Any
     * other decoding failure pauses the serialization queue and closes the channel.
     *
     * @param ctx    the handler context.
     * @param buffer a slice holding exactly one SecureMessage chunk.
     * @throws UaException with {@code Bad_SecureChannelIdInvalid} if the chunk's secure channel id
     *                     differs from this channel's id, or if a chunk limit is exceeded.
     */
    private void onSecureMessage(ChannelHandlerContext ctx, ByteBuf buffer) throws UaException {
        buffer.skipBytes(3 + 1 + 4); // messageType + chunkType + messageSize

        final long receivedChannelId = buffer.readUnsignedInt();
        if (receivedChannelId != secureChannel.getChannelId()) {
            throw new UaException(StatusCodes.Bad_SecureChannelIdInvalid,
                    "invalid secure channel id: " + receivedChannelId);
        }

        if (!accumulateChunk(buffer)) {
            return; // more chunks to come
        }

        // Hand the completed set of chunks to the decoder and start a fresh list for the next message.
        final List<ByteBuf> completeMessage = ImmutableList.copyOf(chunkBuffers);
        chunkBuffers = new LinkedList<>();

        serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
            ByteBuf plainBuffer = null;

            try {
                plainBuffer = chunkDecoder.decodeSymmetric(secureChannel, completeMessage);

                binaryDecoder.setBuffer(plainBuffer);
                UaResponseMessage response = binaryDecoder.decodeMessage(null);

                UaRequestFuture waiting = pending.remove(chunkDecoder.getLastRequestId());

                if (waiting == null) {
                    logger.warn("No UaRequestFuture for requestId={}", chunkDecoder.getLastRequestId());
                } else {
                    client.getExecutorService().execute(() -> waiting.getFuture().complete(response));
                }
            } catch (MessageAbortedException aborted) {
                logger.debug("Received message abort chunk; error={}, reason={}", aborted.getStatusCode(),
                        aborted.getMessage());

                UaRequestFuture waiting = pending.remove(chunkDecoder.getLastRequestId());

                if (waiting == null) {
                    logger.warn("No UaRequestFuture for requestId={}", chunkDecoder.getLastRequestId());
                } else {
                    client.getExecutorService()
                            .execute(() -> waiting.getFuture().completeExceptionally(aborted));
                }
            } catch (Throwable failure) {
                logger.error("Error decoding symmetric message: {}", failure.getMessage(), failure);
                serializationQueue.pause();
                ctx.close();
            } finally {
                if (plainBuffer != null) {
                    plainBuffer.release();
                }
            }
        });
    }

    /**
     * Handles an Error message from the server.
     *
     * <p>The error is decoded and logged, and the handshake future is completed exceptionally with its
     * status code and reason. For {@code Bad_SecurityChecksFailed}, {@code Bad_TcpSecureChannelUnknown}
     * and {@code Bad_SecureChannelIdInvalid} the secure channel id is also reset to 0. The channel is
     * always closed afterwards, including when the error message itself cannot be decoded.
     *
     * @param ctx    the handler context.
     * @param buffer a slice holding exactly one Error message.
     */
    private void onError(ChannelHandlerContext ctx, ByteBuf buffer) {
        try {
            ErrorMessage errorMessage = TcpMessageDecoder.decodeError(buffer);
            StatusCode statusCode = errorMessage.getError();
            long errorCode = statusCode.getValue();

            boolean secureChannelError = errorCode == StatusCodes.Bad_SecurityChecksFailed
                    || errorCode == StatusCodes.Bad_TcpSecureChannelUnknown
                    || errorCode == StatusCodes.Bad_SecureChannelIdInvalid;

            if (secureChannelError) {
                secureChannel.setChannelId(0);
            }

            // Parameterized, like every other log call in this class.
            logger.error("Received error message: {}", errorMessage);

            handshakeFuture.completeExceptionally(new UaException(statusCode, errorMessage.getReason()));
        } catch (UaException e) {
            logger.error("An exception occurred while decoding an error message: {}", e.getMessage(), e);

            handshakeFuture.completeExceptionally(e);
        } finally {
            ctx.close();
        }
    }

}