org.apache.hadoop.crypto.TestCryptoCodec.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.crypto.TestCryptoCodec.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.crypto;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.RandomDatum;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.util.NativeCodeLoader;
import org.apache.hadoop.util.ReflectionUtils;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;

import com.google.common.primitives.Longs;

/**
 * Tests for the {@link CryptoCodec} implementations named by
 * {@link #JCE_CODEC_CLASS} and {@link #OPENSSL_CODEC_CLASS}.
 *
 * <p>Two kinds of checks are performed:
 * <ul>
 *   <li>round trips: random records are encrypted through a
 *       {@link CryptoOutputStream} and decrypted through a
 *       {@link CryptoInputStream}, possibly with a different codec on each
 *       side, and compared with the original data;</li>
 *   <li>IV calculation: {@code calculateIV} is compared with a
 *       {@link BigInteger} based reference (see HADOOP-11343).</li>
 * </ul>
 */
public class TestCryptoCodec {
    private static final Log LOG = LogFactory.getLog(TestCryptoCodec.class);
    private static final int BUFFER_SIZE = 4096;
    /** Number of trailing IV bytes set to 0xff by the overflow scenarios. */
    private static final int OVERFLOW_BYTES = 8;
    private static final String JCE_CODEC_CLASS = "org.apache.hadoop.crypto.JceAesCtrCryptoCodec";
    private static final String OPENSSL_CODEC_CLASS = "org.apache.hadoop.crypto.OpensslAesCtrCryptoCodec";

    // Both arrays are refilled with fresh random bytes before every test (see setUp()).
    private static byte[] key = new byte[16];
    private static byte[] iv = new byte[16];

    private final Configuration conf = new Configuration();
    private final int count = 10000;
    // Random per instance; logged by cryptoCodecTest so a failing run can be reproduced.
    private final int seed = new Random().nextInt();

    /** Fills the shared key and IV with fresh random bytes. */
    @Before
    public void setUp() throws IOException {
        Random random = new SecureRandom();
        random.nextBytes(key);
        random.nextBytes(iv);
    }

    /**
     * Encrypts with the JCE codec and decrypts with both the JCE and the
     * OpenSSL codec. Skipped when OpenSSL support is unavailable because the
     * cross-codec scenarios need it.
     */
    @Test(timeout = 120000)
    public void testJceAesCtrCryptoCodec() throws Exception {
        assumeOpensslAvailable();
        runCodecScenarios(JCE_CODEC_CLASS, OPENSSL_CODEC_CLASS);
    }

    /**
     * Encrypts with the OpenSSL codec and decrypts with both the OpenSSL and
     * the JCE codec. Skipped when OpenSSL support is unavailable.
     */
    @Test(timeout = 120000)
    public void testOpensslAesCtrCryptoCodec() throws Exception {
        assumeOpensslAvailable();
        runCodecScenarios(OPENSSL_CODEC_CLASS, JCE_CODEC_CLASS);
    }

    /**
     * Skips the calling test unless it runs in the native profile with a
     * build that supports OpenSSL, then asserts that the OpenSSL cipher
     * reported no loading failure.
     */
    private void assumeOpensslAvailable() {
        GenericTestUtils.assumeInNativeProfile();
        if (!NativeCodeLoader.buildSupportsOpenssl()) {
            LOG.warn("Skipping test since openSSL library not loaded");
            Assume.assumeTrue(false);
        }
        Assert.assertNull(OpensslCipher.getLoadingFailureReason());
    }

    /**
     * Runs the common scenario sequence: an empty payload, a same-codec round
     * trip, a cross-codec round trip, and then both round trips again with
     * the trailing IV bytes set to 0xff.
     *
     * @param primaryCodecClass codec used for encryption in every scenario
     * @param otherCodecClass codec used for decryption in the cross-codec scenarios
     */
    private void runCodecScenarios(String primaryCodecClass, String otherCodecClass)
            throws IOException, GeneralSecurityException {
        cryptoCodecTest(conf, seed, 0, primaryCodecClass, primaryCodecClass, iv);
        cryptoCodecTest(conf, seed, count, primaryCodecClass, primaryCodecClass, iv);
        cryptoCodecTest(conf, seed, count, primaryCodecClass, otherCodecClass, iv);
        // Overflow test, IV: xx xx xx xx xx xx xx xx ff ff ff ff ff ff ff ff
        fillTrailingBytesWithOnes(iv);
        cryptoCodecTest(conf, seed, count, primaryCodecClass, primaryCodecClass, iv);
        cryptoCodecTest(conf, seed, count, primaryCodecClass, otherCodecClass, iv);
    }

    /** Sets the last {@link #OVERFLOW_BYTES} bytes of {@code bytes} to 0xff. */
    private static void fillTrailingBytesWithOnes(byte[] bytes) {
        Arrays.fill(bytes, bytes.length - OVERFLOW_BYTES, bytes.length, (byte) 0xff);
    }

    /**
     * Instantiates the codec with the given class name through reflection.
     *
     * @param conf configuration used to resolve and configure the class
     * @param codecClass fully qualified codec class name
     * @return the new codec instance
     * @throws IOException if the class cannot be found; the original
     *         {@link ClassNotFoundException} is kept as the cause
     */
    private static CryptoCodec createCodec(Configuration conf, String codecClass) throws IOException {
        CryptoCodec codec;
        try {
            codec = (CryptoCodec) ReflectionUtils.newInstance(conf.getClassByName(codecClass), conf);
        } catch (ClassNotFoundException cnfe) {
            throw new IOException("Illegal crypto codec!", cnfe);
        }
        LOG.info("Created a Codec object of type: " + codecClass);
        return codec;
    }

    /**
     * Encrypts {@code count} random key/value records with one codec,
     * decrypts them with another and verifies the result three ways: record
     * by record, byte by byte, and byte by byte after seeking to one third of
     * the data. Finally checks the secure random generator of the encrypting
     * codec.
     *
     * @param conf configuration used to create the codecs
     * @param seed seed for the record generator
     * @param count number of key/value records to generate
     * @param encCodecClass class name of the codec used for encryption
     * @param decCodecClass class name of the codec used for decryption
     * @param iv initialization vector used for both directions
     */
    private void cryptoCodecTest(Configuration conf, int seed, int count, String encCodecClass,
            String decCodecClass, byte[] iv) throws IOException, GeneralSecurityException {
        CryptoCodec encCodec = createCodec(conf, encCodecClass);

        // Generate data
        LOG.info("Generating records with seed " + seed);
        DataOutputBuffer data = new DataOutputBuffer();
        RandomDatum.Generator generator = new RandomDatum.Generator(seed);
        for (int i = 0; i < count; ++i) {
            generator.next();
            // Named so that they do not shadow the static encryption key field.
            RandomDatum datumKey = generator.getKey();
            RandomDatum datumValue = generator.getValue();

            datumKey.write(data);
            datumValue.write(data);
        }
        LOG.info("Generated " + count + " records");

        // Encrypt data
        DataOutputBuffer encryptedDataBuffer = new DataOutputBuffer();
        CryptoOutputStream out = new CryptoOutputStream(encryptedDataBuffer, encCodec, BUFFER_SIZE, key, iv);
        out.write(data.getData(), 0, data.getLength());
        out.flush();
        out.close();
        LOG.info("Finished encrypting data");

        CryptoCodec decCodec = createCodec(conf, decCodecClass);

        // Decrypt data
        DataInputBuffer encryptedInput = new DataInputBuffer();
        encryptedInput.reset(encryptedDataBuffer.getData(), 0, encryptedDataBuffer.getLength());
        CryptoInputStream in = new CryptoInputStream(encryptedInput, decCodec, BUFFER_SIZE, key, iv);
        DataInputStream dataIn = new DataInputStream(new BufferedInputStream(in));

        // Check record by record
        DataInputBuffer originalData = new DataInputBuffer();
        originalData.reset(data.getData(), 0, data.getLength());
        DataInputStream originalIn = new DataInputStream(new BufferedInputStream(originalData));

        for (int i = 0; i < count; ++i) {
            RandomDatum k1 = new RandomDatum();
            RandomDatum v1 = new RandomDatum();
            k1.readFields(originalIn);
            v1.readFields(originalIn);

            RandomDatum k2 = new RandomDatum();
            RandomDatum v2 = new RandomDatum();
            k2.readFields(dataIn);
            v2.readFields(dataIn);
            assertTrue("original and encrypted-then-decrypted-output not equal", k1.equals(k2) && v1.equals(v2));

            // original and encrypted-then-decrypted-output have the same hashCode:
            // looking up the decrypted datum must find the entry stored for the original.
            Map<RandomDatum, String> m = new HashMap<RandomDatum, String>();
            m.put(k1, k1.toString());
            m.put(v1, v1.toString());
            assertEquals("k1 and k2 hashcode not equal", k1.toString(), m.get(k2));
            assertEquals("v1 and v2 hashcode not equal", v1.toString(), m.get(v2));
        }

        // Decrypt data byte-at-a-time
        originalData.reset(data.getData(), 0, data.getLength());
        encryptedInput.reset(encryptedDataBuffer.getData(), 0, encryptedDataBuffer.getLength());
        in = new CryptoInputStream(encryptedInput, decCodec, BUFFER_SIZE, key, iv);

        // Check; the loop also compares the end-of-stream marker (-1) of both streams.
        originalIn = new DataInputStream(new BufferedInputStream(originalData));
        int expected;
        do {
            expected = originalIn.read();
            assertEquals("Decrypted stream read by byte does not match", expected, in.read());
        } while (expected != -1);

        // Seek to a certain position and decrypt
        originalData.reset(data.getData(), 0, data.getLength());
        encryptedInput.reset(encryptedDataBuffer.getData(), 0, encryptedDataBuffer.getLength());
        in = new CryptoInputStream(new TestCryptoStreams.FakeInputStream(encryptedInput), decCodec, BUFFER_SIZE,
                key, iv);
        int seekPos = data.getLength() / 3;
        in.seek(seekPos);

        // Check
        TestCryptoStreams.FakeInputStream originalInput = new TestCryptoStreams.FakeInputStream(originalData);
        originalInput.seek(seekPos);
        do {
            expected = originalInput.read();
            assertEquals("Decrypted stream read by byte does not match", expected, in.read());
        } while (expected != -1);

        LOG.info("SUCCESS! Completed checking " + count + " records");

        // Check secure random generator
        testSecureRandom(encCodec);
    }

    /** Test secure random generator */
    private void testSecureRandom(CryptoCodec codec) {
        int[] lengths = { 16, 32, 128 };
        for (int len : lengths) {
            checkSecureRandom(codec, len);
        }
    }

    /**
     * Asks the codec for two random byte sequences of length {@code len} and
     * asserts that they differ.
     */
    private void checkSecureRandom(CryptoCodec codec, int len) {
        byte[] rand = new byte[len];
        byte[] rand1 = new byte[len];
        codec.generateSecureRandom(rand);
        codec.generateSecureRandom(rand1);

        Assert.assertEquals(len, rand.length);
        Assert.assertEquals(len, rand1.length);
        Assert.assertFalse(Arrays.equals(rand, rand1));
    }

    /**
     * Regression test for IV calculation, see HADOOP-11343
     */
    @Test(timeout = 120000)
    public void testCalculateIV() throws Exception {
        JceAesCtrCryptoCodec codec = new JceAesCtrCryptoCodec();
        codec.setConf(conf);

        SecureRandom sr = new SecureRandom();
        byte[] initIV = new byte[16];
        byte[] calculatedIV = new byte[16];

        long iterations = 1000;
        long counter = 10000;

        // Overflow test, IV: 00 00 00 00 00 00 00 00 ff ff ff ff ff ff ff ff
        fillTrailingBytesWithOnes(initIV);

        for (long j = 0; j < counter; j++) {
            assertIVCalculation(codec, initIV, j, calculatedIV);
        }

        // Random IV and counter sequence test
        for (long i = 0; i < iterations; i++) {
            sr.nextBytes(initIV);

            for (long j = 0; j < counter; j++) {
                assertIVCalculation(codec, initIV, j, calculatedIV);
            }
        }

        // Random IV and random counter test
        for (long i = 0; i < iterations; i++) {
            sr.nextBytes(initIV);

            for (long j = 0; j < counter; j++) {
                long c = sr.nextLong();
                assertIVCalculation(codec, initIV, c, calculatedIV);
            }
        }
    }

    /**
     * Asserts that {@code codec.calculateIV(initIV, counter, IV)} writes the
     * value produced by {@link #calculateRef(byte[], long)} into {@code IV}.
     *
     * @param codec codec under test
     * @param initIV initial IV
     * @param counter counter combined with the initial IV
     * @param IV output buffer receiving the calculated IV
     */
    private void assertIVCalculation(CryptoCodec codec, byte[] initIV, long counter, byte[] IV) {
        codec.calculateIV(initIV, counter, IV);

        BigInteger actual = new BigInteger(1, IV);
        BigInteger ref = calculateRef(initIV, counter);

        // assertEquals (rather than assertTrue on equals) reports both values on failure.
        assertEquals("Calculated IV don't match with the reference", ref, actual);
    }

    /**
     * Computes the reference IV: {@code initIV} plus {@code counter}, both
     * read as unsigned big-endian integers, reduced to the width of
     * {@code initIV}.
     *
     * <p>The reduction matters because the value under test is read from a
     * buffer of fixed size, so a carry out of its top byte cannot be
     * represented there; without it the reference could never match in that
     * case.
     */
    private static BigInteger calculateRef(byte[] initIV, long counter) {
        byte[] cb = Longs.toByteArray(counter);
        BigInteger sum = new BigInteger(1, initIV).add(new BigInteger(1, cb));
        int ivBits = initIV.length * Byte.SIZE;
        // Only pay for the modulo in the rare case where the sum exceeds the IV width.
        return sum.bitLength() > ivBits ? sum.mod(BigInteger.ONE.shiftLeft(ivBits)) : sum;
    }
}