org.apache.nifi.processors.standard.util.crypto.scrypt.Scrypt.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.nifi.processors.standard.util.crypto.scrypt.Scrypt.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.nifi.processors.standard.util.crypto.scrypt;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.processors.standard.util.crypto.CipherUtility;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;

import static java.lang.Integer.MAX_VALUE;
import static java.lang.System.arraycopy;

/**
 * Copyright (C) 2011 - Will Glozer.  All rights reserved.
 * <p>
 * Taken from Will Glozer's port of Colin Percival's C implementation. Glozer's project located at <a href="https://github.com/wg/scrypt">https://github.com/wg/scrypt</a> was released under the ASF
 * 2.0 license and has not been updated since May 25, 2013 and there are outstanding issues which have been patched in this version.
 * <p>
 * An implementation of the <a href="http://www.tarsnap.com/scrypt/scrypt.pdf">scrypt</a>
 * key derivation function.
 * <p>
 * Allows for hashing passwords using the
 * <a href="http://www.tarsnap.com/scrypt.html">scrypt</a> key derivation function
 * and comparing a plain text password to a hashed one.
 */
public class Scrypt {
    private static final Logger logger = LoggerFactory.getLogger(Scrypt.class);

    /** Length in bytes of the salt generated by {@link #scrypt(String, int, int, int, int)}. */
    private static final int DEFAULT_SALT_LENGTH = 16;

    /**
     * Hash the supplied plaintext password and generate output in the format described
     * below:
     * <p>
     * The hashed output is an
     * extended implementation of the Modular Crypt Format that also includes the scrypt
     * algorithm parameters.
     * <p>
     * Format: <code>$s0$PARAMS$SALT$KEY</code>.
     * <dl>
     * <dt>PARAMS</dt><dd>32-bit hex integer containing log2(N) (16 bits), r (8 bits), and p (8 bits)</dd>
     * <dt>SALT</dt><dd>base64-encoded salt</dd>
     * <dt>KEY</dt><dd>base64-encoded derived key</dd>
     * </dl>
     * <p>
     * <code>s0</code> identifies version 0 of the scrypt format, using a 128-bit salt and 256-bit derived key.
     * <p>
     * This method generates a 16 byte random salt internally.
     *
     * @param password password
     * @param n        CPU cost parameter
     * @param r        memory cost parameter
     * @param p        parallelization parameter
     * @param dkLen    the desired key length in bits
     * @return the hashed password
     */
    public static String scrypt(String password, int n, int r, int p, int dkLen) {
        byte[] salt = new byte[DEFAULT_SALT_LENGTH];
        new SecureRandom().nextBytes(salt);

        return scrypt(password, salt, n, r, p, dkLen);
    }

    /**
     * Hash the supplied plaintext password and generate output in the format described
     * in {@link Scrypt#scrypt(String, int, int, int, int)}.
     *
     * @param password password
     * @param salt     the raw salt (16 bytes)
     * @param n        CPU cost parameter
     * @param r        memory cost parameter
     * @param p        parallelization parameter
     * @param dkLen    the desired key length in bits
     * @return the hashed password
     * @throws IllegalArgumentException if the password is null or empty, or a cost parameter is invalid
     * @throws IllegalStateException    if the JVM cannot provide the required HMAC algorithm
     */
    public static String scrypt(String password, byte[] salt, int n, int r, int p, int dkLen) {
        if (password == null) {
            // Same failure (type and message) that check() and deriveScryptKey() use, instead of a bare NPE
            throw new IllegalArgumentException("Password cannot be empty");
        }

        try {
            byte[] derived = deriveScryptKey(password.getBytes(StandardCharsets.UTF_8), salt, n, r, p, dkLen);

            return formatHash(salt, n, r, p, derived);
        } catch (GeneralSecurityException e) {
            // Keep the cause so the underlying provider failure is diagnosable
            throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?", e);
        }
    }

    /**
     * Formats the salt and cost parameters as the leading <code>$s0$PARAMS$SALT</code> portion of the
     * hash format described in {@link Scrypt#scrypt(String, int, int, int, int)}.
     *
     * @param salt the raw salt bytes
     * @param n    CPU cost parameter
     * @param r    memory cost parameter
     * @param p    parallelization parameter
     * @return the formatted salt string
     */
    public static String formatSalt(byte[] salt, int n, int r, int p) {
        String params = encodeParams(n, r, p);

        StringBuilder sb = new StringBuilder((salt.length) * 2);
        sb.append("$s0$").append(params).append('$');
        sb.append(CipherUtility.encodeBase64NoPadding(salt));

        return sb.toString();
    }

    /**
     * Packs log2(n), r and p into a single hexadecimal string: log2(n) above bit 16, r in bits 8-15
     * and p in bits 0-7.
     * <p>
     * The fields are simply OR-ed together, so an {@code r} or {@code p} outside [0..255] spills into the
     * neighbouring field and {@link #parseParameters(String)} will not recover the original values. The
     * encoded value is left as-is for compatibility, but a warning is logged.
     *
     * @param n CPU cost parameter
     * @param r memory cost parameter
     * @param p parallelization parameter
     * @return the hexadecimal parameter string
     */
    private static String encodeParams(int n, int r, int p) {
        if (r < 0 || r > 0xff || p < 0 || p > 0xff) {
            logger.warn("scrypt parameters r ({}) and p ({}) must each fit in 8 bits to be encoded losslessly", r, p);
        }
        return Long.toString(log2(n) << 16 | r << 8 | p, 16);
    }

    /**
     * Builds the complete <code>$s0$PARAMS$SALT$KEY</code> string.
     *
     * @param salt    the raw salt bytes
     * @param n       CPU cost parameter
     * @param r       memory cost parameter
     * @param p       parallelization parameter
     * @param derived the raw derived key bytes
     * @return the formatted hash
     */
    private static String formatHash(byte[] salt, int n, int r, int p, byte[] derived) {
        StringBuilder sb = new StringBuilder((salt.length + derived.length) * 2);
        sb.append(formatSalt(salt, n, r, p)).append('$');
        sb.append(CipherUtility.encodeBase64NoPadding(derived));

        return sb.toString();
    }

    /**
     * Returns the expected memory cost of the provided parameters in bytes, computed as
     * {@code 128 * r * n + 128 * r * p}.
     * <p>
     * For positive parameters the result is capped at {@link Integer#MAX_VALUE} instead of wrapping
     * around to a negative or misleadingly small number when the true cost does not fit in an {@code int}.
     *
     * @param n the N value, iterations >= 2
     * @param r the r value, block size >= 1
     * @param p the p value, parallelization factor >= 1
     * @return the memory cost in bytes, saturated at {@link Integer#MAX_VALUE}
     */
    public static int calculateExpectedMemory(int n, int r, int p) {
        if (n < 1 || r < 1 || p < 1) {
            // Out-of-contract input: keep the historical 32-bit arithmetic rather than inventing a result
            return 128 * r * n + 128 * r * p;
        }

        final long blockBytes = 128L * r;
        final long blockCount = (long) n + p;
        if (blockBytes > MAX_VALUE / blockCount) {
            return MAX_VALUE;
        }
        return (int) (blockBytes * blockCount);
    }

    /**
     * Compare the supplied plaintext password to a hashed password.
     *
     * @param password plaintext password
     * @param hashed   scrypt hashed password
     * @return true if password matches hashed value
     * @throws IllegalArgumentException if either argument is empty or the hash is malformed
     * @throws IllegalStateException    if the JVM cannot provide the required HMAC algorithm
     */
    public static boolean check(String password, String hashed) {
        try {
            if (StringUtils.isEmpty(password)) {
                throw new IllegalArgumentException("Password cannot be empty");
            }

            if (StringUtils.isEmpty(hashed)) {
                throw new IllegalArgumentException("Hash cannot be empty");
            }

            // The leading '$' yields an empty first element, hence 5 parts: "", "s0", PARAMS, SALT, KEY
            String[] parts = hashed.split("\\$");

            if (parts.length != 5 || !parts[1].equals("s0")) {
                throw new IllegalArgumentException("Hash is not properly formatted");
            }

            List<Integer> splitParams = parseParameters(parts[2]);
            int n = splitParams.get(0);
            int r = splitParams.get(1);
            int p = splitParams.get(2);

            byte[] salt = Base64.decodeBase64(parts[3]);
            byte[] derived0 = Base64.decodeBase64(parts[4]);

            // A KEY section that decodes to nothing would make the comparison below trivially succeed
            // for every password, so treat it as a malformed hash.
            if (derived0 == null || derived0.length == 0) {
                throw new IllegalArgumentException("Hash is not properly formatted");
            }

            // Previously this was hard-coded to 32 bytes but the publicly-available scrypt methods accept arbitrary bit lengths
            int hashLength = derived0.length * 8;
            byte[] derived1 = deriveScryptKey(password.getBytes(StandardCharsets.UTF_8), salt, n, r, p, hashLength);

            if (derived0.length != derived1.length) {
                return false;
            }

            // Accumulate differences over every byte rather than returning at the first mismatch
            int result = 0;
            for (int i = 0; i < derived0.length; i++) {
                result |= derived0[i] ^ derived1[i];
            }
            return result == 0;
        } catch (GeneralSecurityException e) {
            // Keep the cause so the underlying provider failure is diagnosable
            throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?", e);
        }
    }

    /**
     * Parses the individual values from the encoded params value in the modified-mcrypt format for the salt & hash.
     * <p>
     * Example:
     * <p>
     * Hash: $s0$e0801$epIxT/h6HbbwHaehFnh/bw$7H0vsXlY8UxxyW/BWx/9GuY7jEvGjT71GFd6O4SZND0
     * Params:   e0801
     * <p>
     * N = 16384
     * r = 8
     * p = 1
     *
     * @param encodedParams the String representation of the second section of the mcrypt format hash
     * @return a list containing N, r, p
     * @throws NumberFormatException if {@code encodedParams} is not a hexadecimal number
     */
    public static List<Integer> parseParameters(String encodedParams) {
        long params = Long.parseLong(encodedParams, 16);

        List<Integer> paramsList = new ArrayList<>(3);

        // N is stored as its base-2 logarithm above bit 16; r and p occupy one byte each below it
        paramsList.add((int) Math.pow(2, (params >> 16) & 0xffff));
        paramsList.add(((int) params >> 8) & 0xff);
        paramsList.add(((int) params) & 0xff);

        return paramsList;
    }

    /**
     * Returns the position of the highest set bit of {@code n} (the floor of its base-2 logarithm for
     * positive {@code n}), found by successively halving the search window from 16 bits down to 1.
     *
     * @param n the value to take the logarithm of
     * @return the index of the highest set bit, or 0 when {@code n} is 0 or 1
     */
    private static int log2(int n) {
        int log = 0;
        if ((n & 0xffff0000) != 0) {
            n >>>= 16;
            log = 16;
        }
        if (n >= 256) {
            n >>>= 8;
            log += 8;
        }
        if (n >= 16) {
            n >>>= 4;
            log += 4;
        }
        if (n >= 4) {
            n >>>= 2;
            log += 2;
        }
        return log + (n >>> 1);
    }

    /**
     * Implementation of the <a href="http://www.tarsnap.com/scrypt/scrypt.pdf">scrypt KDF</a>.
     * <p>
     * The key length is given in bits and converted with integer division by 8, so any remainder is
     * discarded. A null, empty, or unusually sized salt is tolerated here and only logged.
     *
     * @param password password
     * @param salt     salt
     * @param n        CPU cost parameter
     * @param r        memory cost parameter
     * @param p        parallelization parameter
     * @param dkLen    intended length of the derived key in bits
     * @return the derived key
     * @throws GeneralSecurityException when HMAC_SHA256 is not available
     * @throws IllegalArgumentException when a parameter is out of range or the password is empty
     */
    protected static byte[] deriveScryptKey(byte[] password, byte[] salt, int n, int r, int p, int dkLen)
            throws GeneralSecurityException {
        if (n < 2 || (n & (n - 1)) != 0) {
            throw new IllegalArgumentException("N must be a power of 2 greater than 1");
        }

        if (r < 1) {
            throw new IllegalArgumentException("Parameter r must be 1 or greater");
        }

        if (p < 1) {
            throw new IllegalArgumentException("Parameter p must be 1 or greater");
        }

        // The working arrays are sized 128 * r * n and 128 * r * p; both must fit in an int
        if (n > MAX_VALUE / 128 / r) {
            throw new IllegalArgumentException("Parameter N is too large");
        }

        // Must be enforced before r check
        if (p > MAX_VALUE / 128) {
            throw new IllegalArgumentException("Parameter p is too large");
        }

        if (r > MAX_VALUE / 128 / p) {
            throw new IllegalArgumentException("Parameter r is too large");
        }

        if (dkLen < 0) {
            // Fail before the expensive derivation instead of afterwards with NegativeArraySizeException
            throw new IllegalArgumentException("Parameter dkLen must not be negative");
        }

        if (password == null || password.length == 0) {
            throw new IllegalArgumentException("Password cannot be empty");
        }

        int saltLength = salt == null ? 0 : salt.length;
        if (salt == null || saltLength == 0) {
            // Do not enforce this check here. According to the scrypt spec, the salt can be empty. However, in the user-facing ScryptCipherProvider, enforce an arbitrary check to avoid empty salts
            logger.warn("An empty salt was used for scrypt key derivation");
            // No exception is thrown, so substitute an empty array to avoid an NPE further down
            if (salt == null) {
                salt = new byte[] {};
            }
        }

        if (saltLength < 8 || saltLength > 32) {
            // Do not enforce this check here. According to the scrypt spec, the salt can be empty. However, in the user-facing ScryptCipherProvider, enforce an arbitrary check of [8..32] bytes
            logger.warn("A salt of length {} was used for scrypt key derivation", saltLength);
        }

        Mac mac = Mac.getInstance("HmacSHA256");
        mac.init(new SecretKeySpec(password, "HmacSHA256"));

        byte[] b = new byte[128 * r * p];
        byte[] xy = new byte[256 * r];
        byte[] v = new byte[128 * r * n];
        int i;

        // Expand the salt into p blocks of 128 * r bytes with a single PBKDF2 iteration
        pbkdf2(mac, salt, 1, b, p * 128 * r);

        // Mix each of the p blocks in place; v and xy are scratch space reused between blocks
        for (i = 0; i < p; i++) {
            smix(b, i * 128 * r, r, n, v, xy);
        }

        // Compress the mixed blocks down to the requested key length
        byte[] dk = new byte[dkLen / 8];
        pbkdf2(mac, b, 1, dk, dkLen / 8);
        return dk;
    }

    /**
     * Implementation of PBKDF2 (RFC2898).
     *
     * @param alg   the HMAC algorithm to use
     * @param p     the password
     * @param s     the salt
     * @param c     the iteration count
     * @param dkLen the intended length, in octets, of the derived key
     * @return The derived key
     * @throws GeneralSecurityException if the algorithm is unavailable or the key length is too long
     */
    private static byte[] pbkdf2(String alg, byte[] p, byte[] s, int c, int dkLen) throws GeneralSecurityException {
        Mac mac = Mac.getInstance(alg);
        mac.init(new SecretKeySpec(p, alg));
        byte[] dk = new byte[dkLen];
        pbkdf2(mac, s, c, dk, dkLen);
        return dk;
    }

    /**
     * Implementation of PBKDF2 (RFC2898).
     *
     * @param mac   the pre-initialized {@link Mac} instance to use
     * @param s     the salt
     * @param c     the iteration count
     * @param dk    the byte array that derived key will be placed in
     * @param dkLen the intended length, in octets, of the derived key
     * @throws GeneralSecurityException if the key length is too long
     */
    private static void pbkdf2(Mac mac, byte[] s, int c, byte[] dk, int dkLen) throws GeneralSecurityException {
        int hLen = mac.getMacLength();

        if (dkLen > (Math.pow(2, 32) - 1) * hLen) {
            throw new GeneralSecurityException("Requested key length too long");
        }

        byte[] U = new byte[hLen];
        byte[] T = new byte[hLen];
        // Salt followed by a 4-byte big-endian block counter
        byte[] block1 = new byte[s.length + 4];

        // l = number of hLen-sized output blocks, r = number of bytes taken from the last one
        int l = (int) Math.ceil((double) dkLen / hLen);
        int r = dkLen - (l - 1) * hLen;

        arraycopy(s, 0, block1, 0, s.length);

        for (int i = 1; i <= l; i++) {
            block1[s.length + 0] = (byte) (i >> 24 & 0xff);
            block1[s.length + 1] = (byte) (i >> 16 & 0xff);
            block1[s.length + 2] = (byte) (i >> 8 & 0xff);
            block1[s.length + 3] = (byte) (i >> 0 & 0xff);

            mac.update(block1);
            mac.doFinal(U, 0);
            arraycopy(U, 0, T, 0, hLen);

            // Each further iteration MACs the previous output and XORs it into the running block
            for (int j = 1; j < c; j++) {
                mac.update(U);
                mac.doFinal(U, 0);

                for (int k = 0; k < hLen; k++) {
                    T[k] ^= U[k];
                }
            }

            arraycopy(T, 0, dk, (i - 1) * hLen, (i == l ? r : hLen));
        }
    }

    /**
     * Mixes one {@code 128 * r}-byte block of {@code b} in place.
     * <p>
     * The block is copied into the first half of {@code xy}; {@code n} successive states are stored in
     * {@code v} while being mixed, then {@code n} further rounds XOR in a stored state selected by
     * {@link #integerify(byte[], int, int)} before mixing again. The result is copied back into {@code b}.
     *
     * @param b  the buffer holding the block to mix
     * @param bi the offset of the block within {@code b}
     * @param r  the block size parameter
     * @param n  the number of states to store and rounds to run (a power of 2)
     * @param v  scratch array of at least {@code 128 * r * n} bytes
     * @param xy scratch array of at least {@code 256 * r} bytes
     */
    private static void smix(byte[] b, int bi, int r, int n, byte[] v, byte[] xy) {
        int xi = 0;
        int yi = 128 * r;
        int i;

        arraycopy(b, bi, xy, xi, 128 * r);

        for (i = 0; i < n; i++) {
            arraycopy(xy, xi, v, i * (128 * r), 128 * r);
            blockmix_salsa8(xy, xi, yi, r);
        }

        for (i = 0; i < n; i++) {
            // n is a power of 2, so masking with n - 1 reduces the value modulo n
            int j = integerify(xy, xi, r) & (n - 1);
            blockxor(v, j * (128 * r), xy, xi, 128 * r);
            blockmix_salsa8(xy, xi, yi, r);
        }

        arraycopy(xy, xi, b, bi, 128 * r);
    }

    /**
     * Mixes the {@code 2 * r} 64-byte sub-blocks stored in {@code by}.
     * <p>
     * Starting from the last sub-block, each sub-block is XOR-ed into a running 64-byte value which is
     * then passed through {@link #salsa20_8(byte[])} and written to the area at {@code yi}. Finally the
     * even-numbered outputs are copied to the first half of the area at {@code bi} and the odd-numbered
     * outputs to the second half.
     *
     * @param by the working array holding both the input area and the output area
     * @param bi the offset of the input area
     * @param yi the offset of the output area
     * @param r  the block size parameter
     */
    private static void blockmix_salsa8(byte[] by, int bi, int yi, int r) {
        byte[] X = new byte[64];
        int i;

        arraycopy(by, bi + (2 * r - 1) * 64, X, 0, 64);

        for (i = 0; i < 2 * r; i++) {
            blockxor(by, i * 64, X, 0, 64);
            salsa20_8(X);
            arraycopy(X, 0, by, yi + (i * 64), 64);
        }

        for (i = 0; i < r; i++) {
            arraycopy(by, yi + (i * 2) * 64, by, bi + (i * 64), 64);
        }

        for (i = 0; i < r; i++) {
            arraycopy(by, yi + (i * 2 + 1) * 64, by, bi + (i + r) * 64, 64);
        }
    }

    /**
     * Rotates the 32-bit value {@code a} left by {@code b} bits.
     *
     * @param a the value to rotate
     * @param b the rotation distance in bits
     * @return the rotated value
     */
    private static int r(int a, int b) {
        return (a << b) | (a >>> (32 - b));
    }

    /**
     * Applies eight Salsa20 rounds (four double-rounds) to the 64-byte block {@code b} in place, then adds
     * the original words back in.
     *
     * @param b the 64-byte block, read and written as sixteen little-endian 32-bit words
     */
    private static void salsa20_8(byte[] b) {
        int[] b32 = new int[16];
        int[] x = new int[16];
        int i;

        // Decode the bytes into sixteen little-endian words
        for (i = 0; i < 16; i++) {
            b32[i] = (b[i * 4 + 0] & 0xff) << 0;
            b32[i] |= (b[i * 4 + 1] & 0xff) << 8;
            b32[i] |= (b[i * 4 + 2] & 0xff) << 16;
            b32[i] |= (b[i * 4 + 3] & 0xff) << 24;
        }

        arraycopy(b32, 0, x, 0, 16);

        // Each pass performs two rounds, so four passes give eight rounds
        for (i = 8; i > 0; i -= 2) {
            x[4] ^= r(x[0] + x[12], 7);
            x[8] ^= r(x[4] + x[0], 9);
            x[12] ^= r(x[8] + x[4], 13);
            x[0] ^= r(x[12] + x[8], 18);
            x[9] ^= r(x[5] + x[1], 7);
            x[13] ^= r(x[9] + x[5], 9);
            x[1] ^= r(x[13] + x[9], 13);
            x[5] ^= r(x[1] + x[13], 18);
            x[14] ^= r(x[10] + x[6], 7);
            x[2] ^= r(x[14] + x[10], 9);
            x[6] ^= r(x[2] + x[14], 13);
            x[10] ^= r(x[6] + x[2], 18);
            x[3] ^= r(x[15] + x[11], 7);
            x[7] ^= r(x[3] + x[15], 9);
            x[11] ^= r(x[7] + x[3], 13);
            x[15] ^= r(x[11] + x[7], 18);
            x[1] ^= r(x[0] + x[3], 7);
            x[2] ^= r(x[1] + x[0], 9);
            x[3] ^= r(x[2] + x[1], 13);
            x[0] ^= r(x[3] + x[2], 18);
            x[6] ^= r(x[5] + x[4], 7);
            x[7] ^= r(x[6] + x[5], 9);
            x[4] ^= r(x[7] + x[6], 13);
            x[5] ^= r(x[4] + x[7], 18);
            x[11] ^= r(x[10] + x[9], 7);
            x[8] ^= r(x[11] + x[10], 9);
            x[9] ^= r(x[8] + x[11], 13);
            x[10] ^= r(x[9] + x[8], 18);
            x[12] ^= r(x[15] + x[14], 7);
            x[13] ^= r(x[12] + x[15], 9);
            x[14] ^= r(x[13] + x[12], 13);
            x[15] ^= r(x[14] + x[13], 18);
        }

        for (i = 0; i < 16; ++i) {
            b32[i] = x[i] + b32[i];
        }

        // Encode the words back into little-endian bytes
        for (i = 0; i < 16; i++) {
            b[i * 4 + 0] = (byte) (b32[i] >> 0 & 0xff);
            b[i * 4 + 1] = (byte) (b32[i] >> 8 & 0xff);
            b[i * 4 + 2] = (byte) (b32[i] >> 16 & 0xff);
            b[i * 4 + 3] = (byte) (b32[i] >> 24 & 0xff);
        }
    }

    /**
     * XORs {@code len} bytes of {@code s} starting at {@code si} into {@code d} starting at {@code di}.
     *
     * @param s   the source array
     * @param si  the offset into the source array
     * @param d   the destination array, modified in place
     * @param di  the offset into the destination array
     * @param len the number of bytes to XOR
     */
    private static void blockxor(byte[] s, int si, byte[] d, int di, int len) {
        for (int i = 0; i < len; i++) {
            d[di + i] ^= s[si + i];
        }
    }

    /**
     * Reads the first four bytes of the last 64-byte sub-block of a {@code 2 * r} sub-block area as a
     * little-endian 32-bit integer.
     *
     * @param b  the array holding the area
     * @param bi the offset of the area within {@code b}
     * @param r  the block size parameter
     * @return the decoded integer (may be negative)
     */
    private static int integerify(byte[] b, int bi, int r) {
        int n;

        bi += (2 * r - 1) * 64;

        n = (b[bi + 0] & 0xff) << 0;
        n |= (b[bi + 1] & 0xff) << 8;
        n |= (b[bi + 2] & 0xff) << 16;
        n |= (b[bi + 3] & 0xff) << 24;

        return n;
    }

    /**
     * Returns the length in bytes of the salt generated when none is supplied.
     *
     * @return the default salt length in bytes
     */
    public static int getDefaultSaltLength() {
        return DEFAULT_SALT_LENGTH;
    }
}