org.apache.nifi.toolkit.tls.util.TlsHelperTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.nifi.toolkit.tls.util.TlsHelperTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.toolkit.tls.util;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.KeyStoreSpi;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.Provider;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.security.util.CertificateUtils;
import org.apache.nifi.toolkit.tls.configuration.TlsConfig;
import org.bouncycastle.asn1.pkcs.Attribute;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.x509.Extension;
import org.bouncycastle.asn1.x509.Extensions;
import org.bouncycastle.asn1.x509.GeneralName;
import org.bouncycastle.asn1.x509.GeneralNames;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.AdditionalMatchers;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@RunWith(MockitoJUnitRunner.class)
public class TlsHelperTest {
    public static final Logger logger = LoggerFactory.getLogger(TlsHelperTest.class);

    private static final boolean originalUnlimitedCrypto = TlsHelper.isUnlimitedStrengthCryptographyEnabled();

    private int days;

    private int keySize;

    private String keyPairAlgorithm;

    private String signingAlgorithm;

    private KeyPairGenerator keyPairGenerator;

    private KeyStore keyStore;

    @Mock
    KeyStoreSpi keyStoreSpi;

    @Mock
    Provider keyStoreProvider;

    @Mock
    OutputStreamFactory outputStreamFactory;

    private ByteArrayOutputStream tmpFileOutputStream;

    private File file;

    private static void setUnlimitedCrypto(boolean value) {
        try {
            Field isUnlimitedStrengthCryptographyEnabled = TlsHelper.class
                    .getDeclaredField("isUnlimitedStrengthCryptographyEnabled");
            isUnlimitedStrengthCryptographyEnabled.setAccessible(true);
            isUnlimitedStrengthCryptographyEnabled.set(null, value);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public static KeyPair loadKeyPair(Reader reader) throws IOException {
        try (PEMParser pemParser = new PEMParser(reader)) {
            Object object = pemParser.readObject();
            assertEquals(PEMKeyPair.class, object.getClass());
            return new JcaPEMKeyConverter().getKeyPair((PEMKeyPair) object);
        }
    }

    public static KeyPair loadKeyPair(File file) throws IOException {
        return loadKeyPair(new FileReader(file));
    }

    public static X509Certificate loadCertificate(Reader reader) throws IOException, CertificateException {
        try (PEMParser pemParser = new PEMParser(reader)) {
            Object object = pemParser.readObject();
            assertEquals(X509CertificateHolder.class, object.getClass());
            return new JcaX509CertificateConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME)
                    .getCertificate((X509CertificateHolder) object);
        }
    }

    public static X509Certificate loadCertificate(File file) throws IOException, CertificateException {
        return loadCertificate(new FileReader(file));
    }

    @Before
    public void setup() throws Exception {
        days = 360;
        keySize = 2048;
        keyPairAlgorithm = "RSA";
        signingAlgorithm = "SHA1WITHRSA";
        keyPairGenerator = KeyPairGenerator.getInstance(keyPairAlgorithm);
        keyPairGenerator.initialize(keySize);
        Constructor<KeyStore> keyStoreConstructor = KeyStore.class.getDeclaredConstructor(KeyStoreSpi.class,
                Provider.class, String.class);
        keyStoreConstructor.setAccessible(true);
        keyStore = keyStoreConstructor.newInstance(keyStoreSpi, keyStoreProvider, "faketype");
        keyStore.load(null, null);
        file = File.createTempFile("keystore", "file");
        when(outputStreamFactory.create(file)).thenReturn(tmpFileOutputStream);
    }

    @After
    public void tearDown() {
        setUnlimitedCrypto(originalUnlimitedCrypto);
        file.delete();
    }

    private Date inFuture(int days) {
        return new Date(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(days));
    }

    @Test
    public void testGenerateSelfSignedCert()
            throws GeneralSecurityException, IOException, OperatorCreationException {
        String dn = "CN=testDN,O=testOrg";

        X509Certificate x509Certificate = CertificateUtils.generateSelfSignedX509Certificate(
                TlsHelper.generateKeyPair(keyPairAlgorithm, keySize), dn, signingAlgorithm, days);

        Date notAfter = x509Certificate.getNotAfter();
        assertTrue(notAfter.after(inFuture(days - 1)));
        assertTrue(notAfter.before(inFuture(days + 1)));

        Date notBefore = x509Certificate.getNotBefore();
        assertTrue(notBefore.after(inFuture(-1)));
        assertTrue(notBefore.before(inFuture(1)));

        assertEquals(dn, x509Certificate.getIssuerX500Principal().getName());
        assertEquals(signingAlgorithm, x509Certificate.getSigAlgName());
        assertEquals(keyPairAlgorithm, x509Certificate.getPublicKey().getAlgorithm());

        x509Certificate.checkValidity();
    }

    @Test
    public void testIssueCert() throws IOException, CertificateException, NoSuchAlgorithmException,
            OperatorCreationException, NoSuchProviderException, InvalidKeyException, SignatureException {
        X509Certificate issuer = loadCertificate(
                new InputStreamReader(getClass().getClassLoader().getResourceAsStream("rootCert.crt")));
        KeyPair issuerKeyPair = loadKeyPair(
                new InputStreamReader(getClass().getClassLoader().getResourceAsStream("rootCert.key")));

        String dn = "CN=testIssued, O=testOrg";

        KeyPair keyPair = TlsHelper.generateKeyPair(keyPairAlgorithm, keySize);
        X509Certificate x509Certificate = CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(),
                issuer, issuerKeyPair, signingAlgorithm, days);
        assertEquals(dn, x509Certificate.getSubjectX500Principal().toString());
        assertEquals(issuer.getSubjectX500Principal().toString(),
                x509Certificate.getIssuerX500Principal().toString());
        assertEquals(keyPair.getPublic(), x509Certificate.getPublicKey());

        Date notAfter = x509Certificate.getNotAfter();
        assertTrue(notAfter.after(inFuture(days - 1)));
        assertTrue(notAfter.before(inFuture(days + 1)));

        Date notBefore = x509Certificate.getNotBefore();
        assertTrue(notBefore.after(inFuture(-1)));
        assertTrue(notBefore.before(inFuture(1)));

        assertEquals(signingAlgorithm, x509Certificate.getSigAlgName());
        assertEquals(keyPairAlgorithm, x509Certificate.getPublicKey().getAlgorithm());

        x509Certificate.verify(issuerKeyPair.getPublic());
    }

    @Test
    public void testWriteKeyStoreSuccess() throws IOException, GeneralSecurityException {
        setUnlimitedCrypto(false);
        String testPassword = "testPassword";
        assertEquals(testPassword,
                TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, false));
        verify(keyStoreSpi, times(1)).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
    }

    @Test
    public void testWriteKeyStoreFailure() throws IOException, GeneralSecurityException {
        setUnlimitedCrypto(false);
        String testPassword = "testPassword";
        IOException ioException = new IOException("Fail");
        doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
        try {
            TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true);
            fail("Expected " + ioException);
        } catch (IOException e) {
            assertEquals(ioException, e);
        }
    }

    @Test
    public void testWriteKeyStoreTruncate() throws IOException, GeneralSecurityException {
        setUnlimitedCrypto(false);
        String testPassword = "testPassword";
        String truncatedPassword = testPassword.substring(0, 7);
        IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE);
        doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
        assertEquals(truncatedPassword,
                TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true));
        verify(keyStoreSpi, times(1)).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
        verify(keyStoreSpi, times(1)).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(truncatedPassword.toCharArray()));
    }

    @Test
    public void testWriteKeyStoreUnlimitedWontTruncate() throws GeneralSecurityException, IOException {
        setUnlimitedCrypto(true);
        String testPassword = "testPassword";
        IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE);
        doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
        try {
            TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true);
            fail("Expected " + ioException);
        } catch (IOException e) {
            assertEquals(ioException, e);
        }
    }

    @Test
    public void testWriteKeyStoreNoTruncate() throws IOException, GeneralSecurityException {
        setUnlimitedCrypto(false);
        String testPassword = "testPassword";
        IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE);
        doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
        try {
            TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, false);
            fail("Expected " + GeneralSecurityException.class);
        } catch (GeneralSecurityException e) {
            assertTrue("Expected exception to contain " + TlsHelper.JCE_URL,
                    e.getMessage().contains(TlsHelper.JCE_URL));
        }
    }

    @Test
    public void testWriteKeyStoreTruncateFailure() throws IOException, GeneralSecurityException {
        setUnlimitedCrypto(false);
        String testPassword = "testPassword";
        String truncatedPassword = testPassword.substring(0, 7);
        IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE);
        IOException ioException2 = new IOException(TlsHelper.ILLEGAL_KEY_SIZE);
        doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(testPassword.toCharArray()));
        doThrow(ioException2).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream),
                AdditionalMatchers.aryEq(truncatedPassword.toCharArray()));
        try {
            TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true);
            fail("Expected " + ioException2);
        } catch (IOException e) {
            assertEquals(ioException2, e);
        }
    }

    @Test
    public void testShouldIncludeSANFromCSR() throws Exception {
        // Arrange
        final List<String> SAN_ENTRIES = Arrays.asList("127.0.0.1", "nifi.nifi.apache.org");
        final String SAN = StringUtils.join(SAN_ENTRIES, ",");
        final int SAN_COUNT = SAN_ENTRIES.size();
        final String DN = "CN=localhost";
        KeyPair keyPair = keyPairGenerator.generateKeyPair();
        logger.info("Generating CSR with DN: " + DN);

        // Act
        JcaPKCS10CertificationRequest csrWithSan = TlsHelper.generateCertificationRequest(DN, SAN, keyPair,
                TlsConfig.DEFAULT_SIGNING_ALGORITHM);
        logger.info("Created CSR with SAN: " + SAN);
        String testCsrPem = TlsHelper.pemEncodeJcaObject(csrWithSan);
        logger.info("Encoded CSR as PEM: " + testCsrPem);

        // Assert
        String subjectName = csrWithSan.getSubject().toString();
        logger.info("CSR Subject Name: " + subjectName);
        assert subjectName.equals(DN);

        List<String> extractedSans = extractSanFromCsr(csrWithSan);
        assert extractedSans.size() == SAN_COUNT;
        List<String> formattedSans = SAN_ENTRIES.stream().map(s -> "DNS: " + s).collect(Collectors.toList());
        assert extractedSans.containsAll(formattedSans);
    }

    private List<String> extractSanFromCsr(JcaPKCS10CertificationRequest csr) {
        List<String> sans = new ArrayList<>();
        Attribute[] certAttributes = csr.getAttributes();
        for (Attribute attribute : certAttributes) {
            if (attribute.getAttrType().equals(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest)) {
                Extensions extensions = Extensions.getInstance(attribute.getAttrValues().getObjectAt(0));
                GeneralNames gns = GeneralNames.fromExtensions(extensions, Extension.subjectAlternativeName);
                GeneralName[] names = gns.getNames();
                for (GeneralName name : names) {
                    logger.info("Type: " + name.getTagNo() + " | Name: " + name.getName());
                    String title = "";
                    if (name.getTagNo() == GeneralName.dNSName) {
                        title = "DNS";
                    } else if (name.getTagNo() == GeneralName.iPAddress) {
                        title = "IP Address";
                        // name.toASN1Primitive();
                    } else if (name.getTagNo() == GeneralName.otherName) {
                        title = "Other Name";
                    }
                    sans.add(title + ": " + name.getName());
                }
            }
        }

        return sans;
    }
}