com.amazonaws.encryptionsdk.internal.DecryptionHandler.java Source code

Java tutorial

Introduction

Here is the source code for com.amazonaws.encryptionsdk.internal.DecryptionHandler.java

Source

/*
 * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
 * in compliance with the License. A copy of the License is located at
 * 
 * http://aws.amazon.com/apache2.0
 * 
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

package com.amazonaws.encryptionsdk.internal;

import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;

import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
import org.bouncycastle.math.ec.ECPoint;

import com.amazonaws.encryptionsdk.CryptoAlgorithm;
import com.amazonaws.encryptionsdk.DataKey;
import com.amazonaws.encryptionsdk.MasterKey;
import com.amazonaws.encryptionsdk.MasterKeyProvider;
import com.amazonaws.encryptionsdk.exception.AwsCryptoException;
import com.amazonaws.encryptionsdk.exception.BadCiphertextException;
import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
import com.amazonaws.encryptionsdk.model.CiphertextFooters;
import com.amazonaws.encryptionsdk.model.CiphertextHeaders;
import com.amazonaws.encryptionsdk.model.CiphertextType;
import com.amazonaws.encryptionsdk.model.ContentType;
import com.amazonaws.util.Base64;

/**
 * This class implements the CryptoHandler interface by providing methods for
 * the decryption of ciphertext produced by the methods in
 * {@link EncryptionHandler}.
 * 
 * <p>
 * This class reads and parses the values in the ciphertext headers and
 * delegates the decryption of the ciphertext to the
 * {@link BlockDecryptionHandler} or {@link FrameDecryptionHandler} based on the
 * content type parsed in the ciphertext headers.
 * 
 * <p>
 * When the algorithm named in the headers carries a trailing signature, the
 * serialized headers and every content byte consumed are fed into a
 * {@link Signature} which is verified against the value read from the
 * ciphertext footer.
 */
public class DecryptionHandler<K extends MasterKey<K>> implements MessageCryptoHandler<K> {
    private final MasterKeyProvider<K> masterKeyProvider_;

    private final CiphertextHeaders ciphertextHeaders_;
    private final CiphertextFooters ciphertextFooters_;
    // true once the header fields have been read into this handler
    private boolean ciphertextHeadersParsed_;

    // created by readHeaderFields(); null until the headers are available
    private CryptoHandler contentCryptoHandler_;

    private DataKey<K> dataKey_;
    private SecretKey decryptionKey_;
    private CryptoAlgorithm cryptoAlgo_;
    // both are null when the algorithm has no trailing signature
    private PublicKey trailingPublicKey_;
    private Signature trailingSig_;

    private Map<String, String> encryptionContext_ = null;

    // header bytes received so far that could not yet be parsed
    private byte[] unparsedBytes_ = new byte[0];
    private boolean complete_ = false;

    /**
     * Create a decryption handler using the provided master key.
     * 
     * <p>
     * Note the methods in the provided master key are used in decrypting the
     * encrypted data key parsed from the ciphertext headers.
     * 
     * @param customerMasterKeyProvider
     *            the master key provider to use in picking a master key from
     *            the key blobs encoded in the provided ciphertext.
     * @throws AwsCryptoException
     *             if the master key is null.
     */
    public DecryptionHandler(final MasterKeyProvider<K> customerMasterKeyProvider) throws AwsCryptoException {
        Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider");
        masterKeyProvider_ = customerMasterKeyProvider;
        ciphertextHeaders_ = new CiphertextHeaders();
        ciphertextFooters_ = new CiphertextFooters();
    }

    /**
     * Create a decryption handler using the provided master key and already parsed {@code headers}.
     * 
     * <p>
     * Note the methods in the provided master key are used in decrypting the encrypted data key
     * parsed from the ciphertext headers.
     * 
     * @param customerMasterKeyProvider
     *            the master key provider to use in picking a master key from the key blobs encoded
     *            in the provided ciphertext.
     * @param headers
     *            already parsed headers which will not be passed into
     *            {@link #processBytes(byte[], int, int, byte[], int)}
     * @throws AwsCryptoException
     *             if the master key is null.
     */
    public DecryptionHandler(final MasterKeyProvider<K> customerMasterKeyProvider, final CiphertextHeaders headers)
            throws AwsCryptoException {
        Utils.assertNonNull(customerMasterKeyProvider, "customerMasterKeyProvider");
        masterKeyProvider_ = customerMasterKeyProvider;
        ciphertextHeaders_ = headers;
        ciphertextFooters_ = new CiphertextFooters();
        readHeaderFields(headers);
        updateTrailingSignature(headers);
        // The headers are already consumed: without this flag the first call to
        // processBytes() would hand content bytes to the header parser and read
        // the header fields (including unwrapping the data key) a second time.
        ciphertextHeadersParsed_ = true;
    }

    /**
     * Decrypt the ciphertext bytes provided in {@code in} and copy the plaintext bytes to
     * {@code out}.
     * 
     * <p>
     * This method consumes and parses the ciphertext headers. The decryption of the actual content
     * is delegated to {@link BlockDecryptionHandler} or {@link FrameDecryptionHandler} based on the
     * content type parsed in the ciphertext header. Once the content handler reports completion,
     * the footer is parsed and the trailing signature verified when the algorithm uses one.
     * 
     * <p>
     * While the headers are incomplete all of {@code len} is reported as processed and the bytes
     * are buffered internally. Afterwards only the bytes actually consumed are reported, so the
     * caller is expected to supply any unconsumed bytes again.
     * 
     * @param in
     *            the input byte array.
     * @param off
     *            the offset into the in array where the data to be decrypted starts.
     * @param len
     *            the number of bytes to be decrypted.
     * @param out
     *            the output buffer the decrypted plaintext bytes go into.
     * @param outOff
     *            the offset into the output byte array the decrypted data starts at.
     * @return the number of bytes written to {@code out} and processed.
     * 
     * @throws BadCiphertextException
     *             if the ciphertext header contains invalid entries, if the header integrity
     *             check fails or if the trailing signature does not verify.
     * @throws AwsCryptoException
     *             if any of the offset or length arguments are negative or if the total bytes to
     *             decrypt exceeds the maximum allowed value.
     */
    @Override
    public ProcessingSummary processBytes(final byte[] in, final int off, final int len, final byte[] out,
            final int outOff) throws BadCiphertextException, AwsCryptoException {
        if (len < 0 || off < 0) {
            throw new AwsCryptoException(
                    String.format("Invalid values for input offset: %d and length: %d", off, len));
        }

        if (in.length == 0 || len == 0) {
            return ProcessingSummary.ZERO;
        }

        final int leftoverBytes = unparsedBytes_.length;
        final long totalBytesToParse = leftoverBytes + (long) len;
        // check for integer overflow
        if (totalBytesToParse > Integer.MAX_VALUE) {
            throw new AwsCryptoException(
                    "Size of the total bytes to parse and decrypt exceeded allowed maximum:" + Integer.MAX_VALUE);
        }

        // If there were previously unparsed bytes, add them as the first
        // set of bytes to be parsed in this call.
        final byte[] bytesToParse = new byte[(int) totalBytesToParse];
        System.arraycopy(unparsedBytes_, 0, bytesToParse, 0, leftoverBytes);
        System.arraycopy(in, off, bytesToParse, leftoverBytes, len);

        int totalParsedBytes = 0;
        if (!ciphertextHeadersParsed_) {
            totalParsedBytes += ciphertextHeaders_.deserialize(bytesToParse, 0);
            if (!ciphertextHeaders_.isComplete()) {
                // Not enough bytes to finish the headers: keep whatever was
                // not parsed for the next call and report the input as taken.
                unparsedBytes_ = Arrays.copyOfRange(bytesToParse, totalParsedBytes, bytesToParse.length);
                return new ProcessingSummary(0, len);
            }

            // When ciphertext headers are complete, we have the data
            // key and cipher mode to initialize the underlying cipher
            readHeaderFields(ciphertextHeaders_);
            updateTrailingSignature(ciphertextHeaders_);
            ciphertextHeadersParsed_ = true;
            // reset unparsed bytes as parsing of ciphertext headers is
            // complete.
            unparsedBytes_ = new byte[0];
        }

        int actualOutLen = 0;
        if (!contentCryptoHandler_.isComplete()) {
            // if there are bytes to parse further, pass it off to underlying
            // content cryptohandler.
            final int remaining = bytesToParse.length - totalParsedBytes;
            if (remaining > 0) {
                final ProcessingSummary contentResult = contentCryptoHandler_.processBytes(bytesToParse,
                        totalParsedBytes, remaining, out, outOff);
                // the signature covers exactly the content bytes that were consumed
                updateTrailingSignature(bytesToParse, totalParsedBytes, contentResult.getBytesProcessed());
                actualOutLen = contentResult.getBytesWritten();
                totalParsedBytes += contentResult.getBytesProcessed();
            }
            if (contentCryptoHandler_.isComplete()) {
                actualOutLen += contentCryptoHandler_.doFinal(out, outOff + actualOutLen);
            }
        }

        // Skip this once the message is complete: Signature.verify() resets the
        // signature object, so verifying a second time would fail spuriously.
        if (contentCryptoHandler_.isComplete() && !complete_) {
            if (trailingSig_ == null) {
                // No trailing signature for this algorithm, hence no footer to
                // wait for; the message is done as soon as the content is.
                complete_ = true;
            } else {
                totalParsedBytes += ciphertextFooters_.deserialize(bytesToParse, totalParsedBytes);
                if (ciphertextFooters_.isComplete()) {
                    try {
                        if (!trailingSig_.verify(ciphertextFooters_.getMAuth())) {
                            throw new BadCiphertextException("Bad trailing signature");
                        }
                        complete_ = true;
                    } catch (final SignatureException ex) {
                        throw new BadCiphertextException("Bad trailing signature", ex);
                    }
                }
            }
        }

        // bytes carried over from earlier calls were already reported as processed
        return new ProcessingSummary(actualOutLen, totalParsedBytes - leftoverBytes);
    }

    /**
     * Finish processing of the bytes.
     * 
     * @param out
     *            space for any resulting output data.
     * @param outOff
     *            offset into {@code out} to start copying the data at.
     * @return
     *         number of bytes written into {@code out}; 0 if the content
     *         handler has not been created yet.
     * @throws BadCiphertextException
     *             if the bytes do not decrypt correctly.
     */
    @Override
    public int doFinal(final byte[] out, final int outOff) throws BadCiphertextException {
        // check if cryptohandler for content has been created. There are cases
        // when it might not have been created such as when doFinal() is called
        // before the ciphertext headers are fully received and parsed.
        if (contentCryptoHandler_ == null) {
            return 0;
        }
        return contentCryptoHandler_.doFinal(out, outOff);
    }

    /**
     * Return the size of the output buffer required for a
     * <code>processBytes</code> plus a <code>doFinal</code> with an input of
     * inLen bytes.
     * 
     * <p>
     * Before the content handler exists the estimate is {@code inLen} itself,
     * or 0 when {@code inLen} is negative.
     * 
     * @param inLen
     *            the length of the input.
     * @return
     *         the space required to accommodate a call to processBytes and
     *         doFinal with input of size {@code inLen} bytes.
     */
    @Override
    public int estimateOutputSize(final int inLen) {
        if (contentCryptoHandler_ != null) {
            return contentCryptoHandler_.estimateOutputSize(inLen);
        }
        return Math.max(inLen, 0);
    }

    /**
     * Return the encryption context. This value is parsed from the ciphertext.
     * 
     * @return
     *         the key-value map containing the encryption context, or
     *         {@code null} if the headers have not been read yet.
     */
    @Override
    public Map<String, String> getEncryptionContext() {
        return encryptionContext_;
    }

    /**
     * Check integrity of the header bytes by processing the parsed MAC tag in
     * the headers through the cipher.
     * 
     * @param ciphertextHeaders
     *            the ciphertext headers object whose integrity needs to be
     *            checked.
     * @throws BadCiphertextException
     *             if the integrity check of the header fails.
     */
    private void verifyHeaderIntegrity(final CiphertextHeaders ciphertextHeaders) throws BadCiphertextException {
        final CipherHandler cipherHandler = new CipherHandler(decryptionKey_, ciphertextHeaders.getHeaderNonce(),
                ciphertextHeaders.serializeAuthenticatedFields(), Cipher.DECRYPT_MODE, cryptoAlgo_);

        try {
            final byte[] headerTag = ciphertextHeaders.getHeaderTag();
            cipherHandler.cipherData(headerTag, 0, headerTag.length);
        } catch (final BadCiphertextException e) {
            throw new BadCiphertextException("Header integrity check failed.", e);
        }
    }

    /**
     * Retrieve the data key from the ciphertext headers. This method calls the
     * decryptDataKey() method of the master key provider to decrypt the
     * encrypted bytes of the data key read from the ciphertext headers.
     * 
     * @param ciphertextHeaders
     *            the ciphertext headers object from where the encrypted data
     *            key is to be read.
     * @return
     *         the data key object containing the key in cleartext and encrypted
     *         form.
     * @throws CannotUnwrapDataKeyException
     *             if the provider returns no data key.
     */
    private DataKey<K> getDataKey(final CiphertextHeaders ciphertextHeaders) {
        final DataKey<K> result = masterKeyProvider_.decryptDataKey(cryptoAlgo_,
                ciphertextHeaders.getEncryptedKeyBlobs(), ciphertextHeaders.getEncryptionContextMap());

        if (result == null) {
            throw new CannotUnwrapDataKeyException("Could not decrypt any data keys");
        }

        return result;
    }

    /**
     * Read the fields in the ciphertext headers to populate the corresponding
     * instance variables used during decryption.
     * 
     * <p>
     * This validates the version and type, sets up trailing signature
     * verification when the algorithm has one, unwraps the data key, checks
     * the header integrity and finally creates the content handler matching
     * the content type.
     * 
     * @param ciphertextHeaders
     *            the ciphertext headers object to read.
     * @throws BadCiphertextException
     *             if the version, type or content type is not supported or
     *             the header integrity check fails.
     */
    private void readHeaderFields(final CiphertextHeaders ciphertextHeaders) {
        final byte version = ciphertextHeaders.getVersion();
        if (version != VersionInfo.CURRENT_CIPHERTEXT_VERSION) {
            throw new BadCiphertextException("Invalid version in ciphertext.");
        }

        cryptoAlgo_ = ciphertextHeaders.getCryptoAlgoId();

        final CiphertextType ciphertextType = ciphertextHeaders.getType();
        if (ciphertextType != CiphertextType.CUSTOMER_AUTHENTICATED_ENCRYPTED_DATA) {
            throw new BadCiphertextException("Invalid type in ciphertext.");
        }

        final byte[] messageId = ciphertextHeaders.getMessageId();

        encryptionContext_ = ciphertextHeaders.getEncryptionContextMap();
        if (cryptoAlgo_.getTrailingSignatureLength() > 0) {
            // the verification key travels in the encryption context
            try {
                trailingPublicKey_ = deserializeTrailingKeyFromEc(
                        encryptionContext_.get(Constants.EC_PUBLIC_KEY_FIELD));
                trailingSig_ = Signature.getInstance(cryptoAlgo_.getTrailingSignatureAlgo(), "BC");
                trailingSig_.initVerify(trailingPublicKey_);
            } catch (final GeneralSecurityException ex) {
                throw new AwsCryptoException(ex);
            }
        } else {
            trailingPublicKey_ = null;
            trailingSig_ = null;
        }

        final ContentType contentType = ciphertextHeaders.getContentType();

        final short nonceLen = ciphertextHeaders.getNonceLength();
        final int frameLen = ciphertextHeaders.getFrameLength();

        dataKey_ = getDataKey(ciphertextHeaders);
        try {
            decryptionKey_ = cryptoAlgo_.getEncryptionKeyFromDataKey(dataKey_.getKey(), ciphertextHeaders);
        } catch (final InvalidKeyException ex) {
            throw new AwsCryptoException(ex);
        }

        verifyHeaderIntegrity(ciphertextHeaders);

        switch (contentType) {
        case FRAME:
            contentCryptoHandler_ = new FrameDecryptionHandler(decryptionKey_, (byte) nonceLen, cryptoAlgo_,
                    messageId, frameLen);
            break;
        case SINGLEBLOCK:
            contentCryptoHandler_ = new BlockDecryptionHandler(decryptionKey_, (byte) nonceLen, cryptoAlgo_,
                    messageId);
            break;
        default:
            // An invalid content type should already be detected when
            // parsing. Fail here rather than leave the content handler unset,
            // which would only surface later as a NullPointerException.
            throw new BadCiphertextException("Invalid content type in ciphertext.");
        }
    }

    /**
     * Rebuild the public key used to verify the trailing signature from its
     * Base64 encoded point.
     * 
     * @param pubKey
     *            the Base64 encoded elliptic curve point.
     * @return
     *         the public key on the curve matching the current algorithm.
     * @throws GeneralSecurityException
     *             declared for the caller; the visible code does not throw it
     *             directly.
     * @throws IllegalStateException
     *             if the current algorithm does not have a trailing signature.
     */
    private PublicKey deserializeTrailingKeyFromEc(final String pubKey) throws GeneralSecurityException {
        final ECNamedCurveParameterSpec ecSpec;

        switch (cryptoAlgo_) {
        case ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256:
            ecSpec = ECNamedCurveTable.getParameterSpec("secp256r1");
            break;
        case ALG_AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384:
        case ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384:
            ecSpec = ECNamedCurveTable.getParameterSpec("secp384r1");
            break;
        default:
            throw new IllegalStateException("Algorithm does not support trailing signature");
        }
        final ECPoint q = ecSpec.getCurve().decodePoint(Base64.decode(pubKey));
        final ECDomainParameters domain = new ECDomainParameters(ecSpec.getCurve(), ecSpec.getG(), ecSpec.getN(),
                ecSpec.getH());
        final ECPublicKeyParameters keyParams = new ECPublicKeyParameters(q, domain);
        return new BCECPublicKey("ECDSA", keyParams, ecSpec, BouncyCastleProvider.CONFIGURATION);
    }

    /**
     * Feed the serialized form of {@code headers} into the trailing
     * signature. Does nothing when no trailing signature is in use.
     * 
     * @param headers
     *            the headers covered by the signature.
     */
    private void updateTrailingSignature(final CiphertextHeaders headers) {
        if (trailingSig_ != null) {
            final byte[] reserializedHeaders = headers.toByteArray();
            updateTrailingSignature(reserializedHeaders, 0, reserializedHeaders.length);
        }
    }

    /**
     * Feed {@code len} bytes of {@code input} starting at {@code offset} into
     * the trailing signature. Does nothing when no trailing signature is in
     * use.
     * 
     * @throws AwsCryptoException
     *             if the signature object rejects the update.
     */
    private void updateTrailingSignature(final byte[] input, final int offset, final int len) {
        if (trailingSig_ != null) {
            try {
                trailingSig_.update(input, offset, len);
            } catch (final SignatureException ex) {
                throw new AwsCryptoException(ex);
            }
        }
    }

    /**
     * Return the ciphertext headers this handler parses into, or the headers
     * it was constructed with.
     */
    @Override
    public CiphertextHeaders getHeaders() {
        return ciphertextHeaders_;
    }

    /**
     * Return the master key of the data key used for decryption.
     * 
     * <p>
     * The data key is only set once the header fields have been read; calling
     * this earlier fails with a {@link NullPointerException}.
     */
    @Override
    public List<K> getMasterKeys() {
        return Collections.singletonList(dataKey_.getMasterKey());
    }

    /**
     * Return true once the content has been fully decrypted and, for
     * algorithms with a trailing signature, that signature has been verified.
     */
    @Override
    public boolean isComplete() {
        return complete_;
    }
}