com.digitalpetri.opcua.stack.core.channel.ChunkEncoder.java Source code

Java tutorial

Introduction

Here is the source code for com.digitalpetri.opcua.stack.core.channel.ChunkEncoder.java

Source

/*
 * Copyright 2015 Kevin Herron
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.digitalpetri.opcua.stack.core.channel;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.List;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import com.digitalpetri.opcua.stack.core.StatusCodes;
import com.digitalpetri.opcua.stack.core.UaException;
import com.digitalpetri.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import com.digitalpetri.opcua.stack.core.channel.headers.SecureMessageHeader;
import com.digitalpetri.opcua.stack.core.channel.headers.SequenceHeader;
import com.digitalpetri.opcua.stack.core.channel.headers.SymmetricSecurityHeader;
import com.digitalpetri.opcua.stack.core.channel.messages.MessageType;
import com.digitalpetri.opcua.stack.core.security.SecurityAlgorithm;
import com.digitalpetri.opcua.stack.core.util.BufferUtil;
import com.digitalpetri.opcua.stack.core.util.LongSequence;
import com.digitalpetri.opcua.stack.core.util.SignatureUtil;
import io.netty.buffer.ByteBuf;

/**
 * Splits an encoded message into secure-channel chunks. Each chunk gets a secure message header,
 * a security header, a sequence header, a slice of the body, and (depending on the channel)
 * padding, a signature, and encryption.
 * <p>
 * {@link #encodeAsymmetric} signs with the channel's private key and encrypts with the remote
 * certificate's public key. {@link #encodeSymmetric} uses the keys derived from the channel's current
 * {@link ChannelSecurity}. Both draw sequence numbers from the same counter.
 */
public class ChunkEncoder {

    private final Delegate asymmetricDelegate = new AsymmetricDelegate();
    private final Delegate symmetricDelegate = new SymmetricDelegate();

    // Wrap after UInt32.MAX - 1024
    private final LongSequence sequenceNumber = new LongSequence(1L, 4294966271L);

    // Request id of the most recent message whose chunks were all encoded successfully.
    private volatile long lastRequestId = 1L;

    private final ChannelParameters parameters;

    /**
     * @param parameters channel parameters; {@link ChannelParameters#getLocalSendBufferSize()} is the
     *                   upper bound on the size of every chunk this encoder produces.
     */
    public ChunkEncoder(ChannelParameters parameters) {
        this.parameters = parameters;
    }

    /**
     * Encodes a message into chunks secured with the channel's asymmetric algorithms.
     *
     * @param channel       the channel supplying the security policy, keys, certificates, and channel id.
     * @param messageType   the message type written into every chunk header.
     * @param messageBuffer the encoded message; its readable bytes are consumed.
     * @param requestId     the request id written into every chunk's sequence header.
     * @return the chunks in order; every chunk but the last is marked 'C', the last 'F'. An empty
     * {@code messageBuffer} yields an empty list.
     * @throws UaException if a chunk cannot be signed or encrypted, or the send buffer is too small to
     *                     hold any body bytes.
     */
    public List<ByteBuf> encodeAsymmetric(SecureChannel channel, MessageType messageType, ByteBuf messageBuffer,
            long requestId) throws UaException {

        return encode(asymmetricDelegate, channel, messageType, messageBuffer, requestId);
    }

    /**
     * Encodes a message into chunks secured with the channel's current symmetric keys.
     *
     * @param channel       the channel supplying the security policy, keys, token, and channel id.
     * @param messageType   the message type written into every chunk header.
     * @param messageBuffer the encoded message; its readable bytes are consumed.
     * @param requestId     the request id written into every chunk's sequence header.
     * @return the chunks in order; every chunk but the last is marked 'C', the last 'F'. An empty
     * {@code messageBuffer} yields an empty list.
     * @throws UaException if a chunk cannot be signed or encrypted, or the send buffer is too small to
     *                     hold any body bytes.
     */
    public List<ByteBuf> encodeSymmetric(SecureChannel channel, MessageType messageType, ByteBuf messageBuffer,
            long requestId) throws UaException {

        return encode(symmetricDelegate, channel, messageType, messageBuffer, requestId);
    }

    private List<ByteBuf> encode(Delegate delegate, SecureChannel channel, MessageType messageType,
            ByteBuf messageBuffer, long requestId) throws UaException {

        boolean encrypted = delegate.isEncryptionEnabled(channel);
        boolean signed = delegate.isSigningEnabled(channel);

        int securityHeaderSize = delegate.getSecurityHeaderSize(channel);
        int signatureSize = delegate.getSignatureSize(channel);

        // Block sizes only matter when the plain text is run through a cipher. Without encryption the
        // bytes on the wire are the plain text itself, so a block size of 1 keeps the size math exact.
        int cipherTextBlockSize = encrypted ? delegate.getCipherTextBlockSize(channel) : 1;
        int plainTextBlockSize = encrypted ? delegate.getPlainTextBlockSize(channel) : 1;

        // PaddingSize byte, plus an ExtraPaddingSize byte when the cipher text block exceeds 256 bytes.
        int paddingOverhead = encrypted ? (cipherTextBlockSize > 256 ? 2 : 1) : 0;

        int headerSizes = SecureMessageHeader.SECURE_MESSAGE_HEADER_SIZE + securityHeaderSize;
        int maxChunkSize = parameters.getLocalSendBufferSize();

        // Everything after the headers is emitted as whole cipher text blocks and must fit in
        // maxChunkSize. The corresponding plain text holds the sequence header, the body, the padding
        // overhead and the signature; the body gets whatever remains.
        int maxCipherTextBlocks = (maxChunkSize - headerSizes) / cipherTextBlockSize;
        int maxPlainTextSize = maxCipherTextBlocks * plainTextBlockSize;
        int maxBodySize = maxPlainTextSize - SequenceHeader.SEQUENCE_HEADER_SIZE - signatureSize - paddingOverhead;

        if (maxBodySize <= 0 && messageBuffer.readableBytes() > 0) {
            // A non-positive body size would otherwise emit empty chunks forever.
            throw new UaException(StatusCodes.Bad_EncodingLimitsExceeded, String.format(
                    "send buffer size %d too small for chunk body (headers=%d, signature=%d, block=%d)",
                    maxChunkSize, headerSizes, signatureSize, cipherTextBlockSize));
        }

        List<ByteBuf> chunks = new ArrayList<>();
        boolean success = false;

        try {
            while (messageBuffer.readableBytes() > 0) {
                int bodySize = Math.min(messageBuffer.readableBytes(), maxBodySize);

                int paddingSize = encrypted
                        ? computePaddingSize(
                                SequenceHeader.SEQUENCE_HEADER_SIZE + bodySize + signatureSize + paddingOverhead,
                                plainTextBlockSize)
                        : 0;

                int plainTextContentSize = SequenceHeader.SEQUENCE_HEADER_SIZE + bodySize + signatureSize
                        + paddingSize + paddingOverhead;

                assert (plainTextContentSize % plainTextBlockSize == 0);

                int chunkSize = headerSizes + (plainTextContentSize / plainTextBlockSize) * cipherTextBlockSize;

                ByteBuf chunkBuffer = BufferUtil.buffer(chunkSize);
                // Track the chunk right away so it is released if any later step throws.
                chunks.add(chunkBuffer);

                /* Message Header */
                SecureMessageHeader messageHeader = new SecureMessageHeader(messageType,
                        messageBuffer.readableBytes() > bodySize ? 'C' : 'F', chunkSize, channel.getChannelId());

                SecureMessageHeader.encode(messageHeader, chunkBuffer);

                /* Security Header */
                delegate.encodeSecurityHeader(channel, chunkBuffer);

                /* Sequence Header */
                SequenceHeader sequenceHeader = new SequenceHeader(sequenceNumber.getAndIncrement(), requestId);

                SequenceHeader.encode(sequenceHeader, chunkBuffer);

                /* Message Body */
                chunkBuffer.writeBytes(messageBuffer, bodySize);

                /* Padding and Signature */
                if (encrypted) {
                    writePadding(cipherTextBlockSize, paddingSize, chunkBuffer);
                }

                if (signed) {
                    // The signature covers everything written so far, headers included.
                    ByteBuffer chunkNioBuffer = chunkBuffer.nioBuffer(0, chunkBuffer.writerIndex());

                    byte[] signature = delegate.signChunk(channel, chunkNioBuffer);

                    chunkBuffer.writeBytes(signature);
                }

                /* Encryption */
                if (encrypted) {
                    encryptChunk(delegate, channel, chunkBuffer, headerSizes, plainTextBlockSize,
                            cipherTextBlockSize);
                }

                chunkBuffer.readerIndex(0).writerIndex(chunkSize);
            }

            success = true;
        } finally {
            if (!success) {
                for (ByteBuf chunk : chunks) {
                    chunk.release();
                }
            }
        }

        lastRequestId = requestId;

        return chunks;
    }

    /**
     * Returns the number of padding bytes needed to round {@code unpaddedPlainTextSize} up to a multiple
     * of {@code plainTextBlockSize}. Returns 0 (not a full block) when it is already aligned.
     */
    private static int computePaddingSize(int unpaddedPlainTextSize, int plainTextBlockSize) {
        int remainder = unpaddedPlainTextSize % plainTextBlockSize;

        return remainder == 0 ? 0 : plainTextBlockSize - remainder;
    }

    /**
     * Encrypts the plain text of {@code chunkBuffer} (everything after the first
     * {@code plainTextOffset} bytes) in place.
     * <p>
     * When cipher text blocks are larger than plain text blocks, the cipher text extends past the current
     * writer index. The caller resets the writer index to the final chunk size afterwards.
     */
    private static void encryptChunk(Delegate delegate, SecureChannel channel, ByteBuf chunkBuffer,
            int plainTextOffset, int plainTextBlockSize, int cipherTextBlockSize) throws UaException {

        chunkBuffer.readerIndex(plainTextOffset);

        assert (chunkBuffer.readableBytes() % plainTextBlockSize == 0);

        int blockCount = chunkBuffer.readableBytes() / plainTextBlockSize;

        ByteBuffer chunkNioBuffer = chunkBuffer.nioBuffer(plainTextOffset, blockCount * cipherTextBlockSize);

        // The cipher text overwrites the plain text region, so encrypt from a copy of the plain text.
        ByteBuf copyBuffer = chunkBuffer.copy();

        try {
            ByteBuffer plainTextNioBuffer = copyBuffer.nioBuffer();

            Cipher cipher = delegate.getAndInitializeCipher(channel);

            if (delegate instanceof AsymmetricDelegate) {
                // The asymmetric cipher is driven one plain text block per doFinal call.
                for (int blockNumber = 0; blockNumber < blockCount; blockNumber++) {
                    int position = blockNumber * plainTextBlockSize;
                    plainTextNioBuffer.limit(position + plainTextBlockSize);
                    plainTextNioBuffer.position(position);

                    int bytesWritten = cipher.doFinal(plainTextNioBuffer, chunkNioBuffer);

                    assert (bytesWritten == cipherTextBlockSize);
                }
            } else {
                cipher.doFinal(plainTextNioBuffer, chunkNioBuffer);
            }
        } catch (GeneralSecurityException e) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed, e);
        } finally {
            copyBuffer.release();
        }
    }

    /**
     * Returns the request id of the most recent message whose chunks were all encoded successfully
     * (initially 1).
     */
    public long getLastRequestId() {
        return lastRequestId;
    }

    /**
     * Writes the padding section of an encrypted chunk, in this order:
     * <ul>
     * <li>a PaddingSize byte holding the low byte of {@code paddingSize};</li>
     * <li>{@code paddingSize} padding bytes, each equal to that low byte;</li>
     * <li>when {@code cipherTextBlockSize > 256}, an ExtraPaddingSize byte holding the high byte.</li>
     * </ul>
     * In total this writes {@code paddingSize + 1} bytes, or {@code paddingSize + 2} with
     * ExtraPaddingSize. Bytes are written one at a time so the layout does not depend on the buffer's
     * byte order.
     */
    private static void writePadding(int cipherTextBlockSize, int paddingSize, ByteBuf buffer) {
        // writeByte keeps only the low 8 bits of its argument.
        buffer.writeByte(paddingSize);

        for (int i = 0; i < paddingSize; i++) {
            buffer.writeByte(paddingSize);
        }

        if (cipherTextBlockSize > 256) {
            buffer.writeByte(paddingSize >> 8);
        }
    }

    /**
     * Security-mode specific operations used by {@link #encode}.
     */
    private interface Delegate {
        /** Signs {@code chunkNioBuffer} and returns the signature bytes. */
        byte[] signChunk(SecureChannel channel, ByteBuffer chunkNioBuffer) throws UaException;

        /** Writes this mode's security header into {@code buffer}. */
        void encodeSecurityHeader(SecureChannel channel, ByteBuf buffer) throws UaException;

        /** Returns a cipher initialized for encryption. */
        Cipher getAndInitializeCipher(SecureChannel channel) throws UaException;

        /** Returns the encoded size, in bytes, of the header written by {@link #encodeSecurityHeader}. */
        int getSecurityHeaderSize(SecureChannel channel) throws UaException;

        int getCipherTextBlockSize(SecureChannel channel);

        int getPlainTextBlockSize(SecureChannel channel);

        int getSignatureSize(SecureChannel channel);

        boolean isEncryptionEnabled(SecureChannel channel);

        boolean isSigningEnabled(SecureChannel channel);

    }

    /**
     * Signs with the local private key and encrypts with the remote certificate's public key.
     */
    private static class AsymmetricDelegate implements Delegate {

        @Override
        public byte[] signChunk(SecureChannel channel, ByteBuffer chunkNioBuffer) throws UaException {
            return SignatureUtil.sign(channel.getSecurityPolicy().getAsymmetricSignatureAlgorithm(),
                    channel.getKeyPair().getPrivate(), chunkNioBuffer);
        }

        @Override
        public Cipher getAndInitializeCipher(SecureChannel channel) throws UaException {
            Certificate remoteCertificate = channel.getRemoteCertificate();

            assert (remoteCertificate != null);

            try {
                String transformation = channel.getSecurityPolicy().getAsymmetricEncryptionAlgorithm()
                        .getTransformation();
                Cipher cipher = Cipher.getInstance(transformation);
                cipher.init(Cipher.ENCRYPT_MODE, remoteCertificate.getPublicKey());
                return cipher;
            } catch (GeneralSecurityException e) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed, e);
            }
        }

        @Override
        public void encodeSecurityHeader(SecureChannel channel, ByteBuf buffer) throws UaException {
            AsymmetricSecurityHeader header = new AsymmetricSecurityHeader(
                    channel.getSecurityPolicy().getSecurityPolicyUri(), channel.getLocalCertificateBytes(),
                    channel.getRemoteCertificateThumbprint());

            AsymmetricSecurityHeader.encode(header, buffer);
        }

        @Override
        public int getSecurityHeaderSize(SecureChannel channel) throws UaException {
            String securityPolicyUri = channel.getSecurityPolicy().getSecurityPolicyUri();
            byte[] localCertificateBytes = channel.getLocalCertificateBytes().bytes();
            byte[] remoteCertificateThumbprint = channel.getRemoteCertificateThumbprint().bytes();

            // 12 presumably covers three Int32 length prefixes (policy URI, sender certificate, receiver
            // thumbprint); a null array contributes only its prefix. The URI is counted in UTF-8 bytes,
            // not UTF-16 chars, so non-ASCII URIs are sized correctly.
            return 12 + securityPolicyUri.getBytes(StandardCharsets.UTF_8).length
                    + (localCertificateBytes != null ? localCertificateBytes.length : 0)
                    + (remoteCertificateThumbprint != null ? remoteCertificateThumbprint.length : 0);
        }

        @Override
        public int getCipherTextBlockSize(SecureChannel channel) {
            return channel.getRemoteAsymmetricCipherTextBlockSize();
        }

        @Override
        public int getPlainTextBlockSize(SecureChannel channel) {
            return channel.getRemoteAsymmetricPlainTextBlockSize();
        }

        @Override
        public int getSignatureSize(SecureChannel channel) {
            return channel.getLocalAsymmetricSignatureSize();
        }

        @Override
        public boolean isEncryptionEnabled(SecureChannel channel) {
            return channel.isAsymmetricEncryptionEnabled();
        }

        @Override
        public boolean isSigningEnabled(SecureChannel channel) {
            return channel.isAsymmetricSigningEnabled();
        }

    }

    /**
     * Signs (HMAC) and encrypts with keys derived from the channel's current security token.
     */
    private static class SymmetricDelegate implements Delegate {

        // Secrets captured in encodeSecurityHeader together with the token id written there, so the
        // signature and encryption of the same chunk use keys matching that token.
        // NOTE(review): shared across all encode calls on this encoder; concurrent symmetric encodes
        // could interleave — confirm the encoder is used by one thread at a time.
        private volatile ChannelSecurity.SecuritySecrets securitySecrets;

        @Override
        public void encodeSecurityHeader(SecureChannel channel, ByteBuf buffer) {
            ChannelSecurity channelSecurity = channel.getChannelSecurity();
            long tokenId = channelSecurity != null ? channelSecurity.getCurrentToken().getTokenId().longValue()
                    : 0L;

            SymmetricSecurityHeader.encode(new SymmetricSecurityHeader(tokenId), buffer);

            securitySecrets = channelSecurity != null ? channelSecurity.getCurrentKeys() : null;
        }

        @Override
        public byte[] signChunk(SecureChannel channel, ByteBuffer chunkNioBuffer) throws UaException {
            SecurityAlgorithm signatureAlgorithm = channel.getSecurityPolicy().getSymmetricSignatureAlgorithm();
            byte[] signatureKey = channel.getEncryptionKeys(securitySecrets).getSignatureKey();

            return SignatureUtil.hmac(signatureAlgorithm, signatureKey, chunkNioBuffer);
        }

        @Override
        public Cipher getAndInitializeCipher(SecureChannel channel) throws UaException {
            try {
                String transformation = channel.getSecurityPolicy().getSymmetricEncryptionAlgorithm()
                        .getTransformation();
                ChannelSecurity.SecretKeys secretKeys = channel.getEncryptionKeys(securitySecrets);

                SecretKeySpec keySpec = new SecretKeySpec(secretKeys.getEncryptionKey(), "AES");
                IvParameterSpec ivSpec = new IvParameterSpec(secretKeys.getInitializationVector());

                Cipher cipher = Cipher.getInstance(transformation);
                cipher.init(Cipher.ENCRYPT_MODE, keySpec, ivSpec);

                assert (cipher.getBlockSize() == channel.getSymmetricCipherTextBlockSize());

                return cipher;
            } catch (GeneralSecurityException e) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed, e);
            }
        }

        @Override
        public int getSecurityHeaderSize(SecureChannel channel) {
            return SymmetricSecurityHeader.SYMMETRIC_SECURITY_HEADER_SIZE;
        }

        @Override
        public int getCipherTextBlockSize(SecureChannel channel) {
            return channel.getSymmetricCipherTextBlockSize();
        }

        @Override
        public int getPlainTextBlockSize(SecureChannel channel) {
            return channel.getSymmetricPlainTextBlockSize();
        }

        @Override
        public int getSignatureSize(SecureChannel channel) {
            return channel.getSymmetricSignatureSize();
        }

        @Override
        public boolean isEncryptionEnabled(SecureChannel channel) {
            return channel.isSymmetricEncryptionEnabled();
        }

        @Override
        public boolean isSigningEnabled(SecureChannel channel) {
            return channel.isSymmetricSigningEnabled();
        }

    }

}