com.comcast.cmb.common.util.AuthUtil.java Source code

Java tutorial

Introduction

Here is the source code for com.comcast.cmb.common.util.AuthUtil.java

Source

/**
 * Copyright 2012 Comcast Corporation
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.comcast.cmb.common.util;

import org.apache.commons.codec.binary.Base64;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.http.message.BasicNameValuePair;
import org.apache.log4j.Logger;

import com.amazonaws.AmazonClientException;
import com.amazonaws.auth.SigningAlgorithm;
import com.amazonaws.util.BinaryUtils;
import com.amazonaws.util.HttpUtils;
import com.comcast.cqs.controller.CQSHttpServletRequest;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.servlet.http.HttpServletRequest;

import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.net.URL;
import java.net.URLEncoder;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.Map.Entry;

/**
 * Utility functions for authentication: salted password hashing, generation of random
 * access keys and secrets, request timestamp / expiration checks, and computation of
 * AWS style request signatures (signature versions 1, 2 and 4).
 *
 * <p>All members are static and the class keeps no per-request state.
 *
 * @author michael, bwolf
 *
 */
public class AuthUtil {

    private static final Logger logger = Logger.getLogger(AuthUtil.class);

    /** Maximum accepted distance (past or future) between a request timestamp and server time. */
    private static final int REQUEST_VALIDITY_PERIOD_MS = 900000; //15 mins

    private static final Random rand = new SecureRandom();

    protected static final String DEFAULT_ENCODING = "UTF-8";

    /** Number of salt characters prepended to a password before hashing. */
    private static final int SALT_LENGTH = 4;

    private static final String CREDENTIAL_PREFIX = "Credential=";
    private static final String SIGNED_HEADERS_PREFIX = "SignedHeaders=";

    private static final String KEY_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    private static final String SECRET_CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ+/";

    /**
     * Hashes a password with a fresh random salt.
     *
     * <p>The result is the lower-case hex encoding of the 4 salt characters followed by the
     * MD5 digest of {@code salt + password} (UTF-8), i.e. 40 hex characters.
     *
     * <p>NOTE(review): a single round of salted MD5 is weak for password storage. The format is
     * kept unchanged here because replacing it would invalidate every stored hash; migrating to
     * a dedicated password hashing scheme needs a coordinated data migration.
     *
     * @param password the clear text password
     * @return hex encoded salt followed by hex encoded digest
     * @throws Exception if the MD5 algorithm or the UTF-8 encoding is unavailable
     */
    public static String hashPassword(String password) throws Exception {

        MessageDigest digest = MessageDigest.getInstance("MD5");

        String salt = getRandomString(SALT_LENGTH, SECRET_CHARS);
        byte[] hashed = digest.digest((salt + password).getBytes(DEFAULT_ENCODING));

        StringBuilder sb = new StringBuilder((SALT_LENGTH + hashed.length) * 2);

        // salt characters are plain ASCII, so each one encodes to exactly one byte
        appendHex(sb, salt.getBytes(DEFAULT_ENCODING));
        appendHex(sb, hashed);

        return sb.toString();
    }

    /**
     * Checks a clear text password against a value produced by {@link #hashPassword(String)}.
     *
     * <p>The stored value is decoded from hex, its first 4 bytes are taken as the salt and the
     * remainder as the expected digest. Values that are null, not valid hex, or not exactly
     * salt plus digest in length never match; they yield {@code false} rather than an exception.
     * The digest comparison uses {@link MessageDigest#isEqual(byte[], byte[])} so that it does
     * not stop at the first differing byte.
     *
     * @param password the clear text password to check
     * @param hashedPassword the stored hex string (salt + digest)
     * @return true if the password matches the stored hash
     * @throws Exception if the MD5 algorithm or the UTF-8 encoding is unavailable
     */
    public static boolean verifyPassword(String password, String hashedPassword) throws Exception {

        MessageDigest digest = MessageDigest.getInstance("MD5");

        byte[] stored = decodeHex(hashedPassword);

        if (stored == null || stored.length <= SALT_LENGTH) {
            return false;
        }

        String salt = new String(stored, 0, SALT_LENGTH, DEFAULT_ENCODING);
        byte[] expected = digest.digest((salt + password).getBytes(DEFAULT_ENCODING));
        byte[] actual = Arrays.copyOfRange(stored, SALT_LENGTH, stored.length);

        // isEqual also rejects a length mismatch, so truncated or padded hashes cannot match
        return MessageDigest.isEqual(expected, actual);
    }

    /**
     * Appends the two-digit lower-case hex representation of every byte to the builder.
     */
    private static void appendHex(StringBuilder sb, byte[] bytes) {

        for (byte b : bytes) {

            String hex = Integer.toHexString(0xFF & b);

            if (hex.length() == 1) {
                sb.append('0');
            }

            sb.append(hex);
        }
    }

    /**
     * Decodes a hex string into bytes.
     *
     * @return the decoded bytes, or null if the input is null, has odd length or contains a
     *         character that is not a hex digit
     */
    private static byte[] decodeHex(String hex) {

        if (hex == null || hex.length() % 2 != 0) {
            return null;
        }

        byte[] decoded = new byte[hex.length() / 2];

        for (int i = 0; i < decoded.length; i++) {

            int high = Character.digit(hex.charAt(2 * i), 16);
            int low = Character.digit(hex.charAt(2 * i + 1), 16);

            if (high < 0 || low < 0) {
                return null;
            }

            decoded[i] = (byte) ((high << 4) | low);
        }

        return decoded;
    }

    /**
     * @return a random 20 character access key made of digits and upper-case letters
     */
    public static String generateRandomAccessKey() {
        return getRandomString(20, KEY_CHARS);
    }

    /**
     * @return a random 40 character access secret made of digits, letters, '+' and '/'
     */
    public static String generateRandomAccessSecret() {
        return getRandomString(40, SECRET_CHARS);
    }

    /**
     * Builds a string of {@code len} characters drawn from {@code validChars} using the
     * shared {@link SecureRandom} instance.
     */
    private static String getRandomString(int len, final String validChars) {

        StringBuilder sb = new StringBuilder(len);

        for (int i = 0; i < len; i++) {
            sb.append(validChars.charAt(rand.nextInt(validChars.length())));
        }

        return sb.toString();
    }

    /**
     * Validates a timestamp in the format {@code yyyy-MM-dd'T'HH:mm:ss.SSS'Z'} (UTC).
     *
     * @param ts the timestamp to check
     * @throws AuthenticationException if the timestamp cannot be parsed or is out of range
     * @see #checkTimeStampWithFormat(String, String)
     */
    public static void checkTimeStamp(String ts) throws AuthenticationException {

        checkTimeStampWithFormat(ts, "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'");
    }

    /**
     * Validates a signature version 4 timestamp in the format {@code yyyyMMdd'T'HHmmss'Z'} (UTC).
     *
     * @param ts the timestamp to check
     * @throws AuthenticationException if the timestamp cannot be parsed or is out of range
     * @see #checkTimeStampWithFormat(String, String)
     */
    public static void checkTimeStampV4(String ts) throws AuthenticationException {
        checkTimeStampWithFormat(ts, "yyyyMMdd'T'HHmmss'Z'");
    }

    /**
     * Parses a timestamp as UTC with the given pattern and verifies that it lies strictly
     * within the request validity period (15 minutes) before or after the current server time.
     *
     * @param ts the timestamp to check; a null value is reported as an invalid timestamp
     * @param format a {@link SimpleDateFormat} pattern
     * @throws AuthenticationException with {@code InvalidParameterValue} if the timestamp is
     *         missing or unparsable, or {@code RequestExpired} if it is out of range
     */
    public static void checkTimeStampWithFormat(String ts, String format) throws AuthenticationException {

        // SimpleDateFormat is not thread-safe, hence a new instance per call
        SimpleDateFormat dateFormat = new SimpleDateFormat(format);
        dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));

        Date timeStamp;

        try {

            if (ts == null) {
                // report a missing value the same way as a malformed one instead of failing with an NPE
                throw new ParseException("timestamp is missing", 0);
            }

            timeStamp = dateFormat.parse(ts);

        } catch (ParseException ex) {
            logger.error("event=checking_timestamp timestamp=" + ts + " error_code=invalid_format", ex);
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Timestamp=" + ts + " is not valid");
        }

        Date now = new Date();
        long requestTime = timeStamp.getTime();

        if (now.getTime() - REQUEST_VALIDITY_PERIOD_MS < requestTime
                && now.getTime() + REQUEST_VALIDITY_PERIOD_MS > requestTime) {
            return;
        }

        logger.error("event=checking_timestamp timestamp=" + ts + " serverTime=" + dateFormat.format(now)
                + " error_code=timestamp_out_of_range");
        throw new AuthenticationException(CMBErrorCodes.RequestExpired, "Request timestamp " + ts
                + " must be within " + (REQUEST_VALIDITY_PERIOD_MS / 1000) + " seconds of the server time");
    }

    /**
     * Verifies that an expiration date in the format {@code yyyy-MM-dd'T'HH:mm:ss.SSS'Z'} (UTC)
     * lies in the future.
     *
     * @param expiration the expiration date; a null value is reported as an invalid expiration
     * @throws AuthenticationException with {@code InvalidParameterValue} if the value is missing
     *         or unparsable, or {@code RequestExpired} if it is not after the current server time
     */
    public static void checkExpiration(String expiration) throws AuthenticationException {

        SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'");
        dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));

        Date expirationDate;

        try {

            if (expiration == null) {
                throw new ParseException("expiration is missing", 0);
            }

            expirationDate = dateFormat.parse(expiration);

        } catch (ParseException e) {
            logger.error("event=checking_expiration expiration=" + expiration + " error_code=invalid_format", e);
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Expiration " + expiration + " is not valid");
        }

        Date now = new Date();

        if (now.getTime() < expirationDate.getTime()) {
            return;
        }

        logger.error("event=checking_expiration expiration=" + expiration + " server_time="
                + dateFormat.format(now) + " error_code=request_expired");
        throw new AuthenticationException(CMBErrorCodes.RequestExpired,
                "Request with expiration " + expiration + " already expired");
    }

    /**
     * Computes the Base64 encoded HMAC signature of a request for signature version 1 or 2.
     *
     * <p>For version 2 the entry {@code SignatureMethod=algorithm} is written into the supplied
     * {@code parameters} map before signing, i.e. the caller's map is modified. The version 2
     * string to sign always uses the HTTP verb {@code POST}.
     *
     * @param url the request URL (host, port and path are used for version 2)
     * @param parameters the request parameters to sign
     * @param version the signature version, "1" or "2"
     * @param algorithm the JCA Mac algorithm name, also used as key algorithm
     * @param accessSecret the secret used as HMAC key
     * @return the Base64 encoded signature, or null if the version is neither "1" nor "2"
     * @throws Exception if the algorithm is unknown or the key is invalid
     */
    public static String generateSignature(URL url, Map<String, String> parameters, String version,
            String algorithm, String accessSecret) throws Exception {

        String data;

        // constant-first comparison: a null version is handled like any unsupported version
        if ("1".equals(version)) {
            data = constructV1DataToSign(parameters);
        } else if ("2".equals(version)) {
            parameters.put("SignatureMethod", algorithm);
            data = constructV2DataToSign(url, parameters);
        } else {
            return null;
        }

        Mac mac = Mac.getInstance(algorithm);
        mac.init(new SecretKeySpec(accessSecret.getBytes(DEFAULT_ENCODING), algorithm));
        byte[] bytes = mac.doFinal(data.getBytes(DEFAULT_ENCODING));

        return new String(Base64.encodeBase64(bytes), DEFAULT_ENCODING);
    }

    /**
     * Computes the hex encoded signature version 4 signature of a request.
     *
     * <p>Date stamp, region, service and credential scope are taken from the {@code Credential}
     * element of the request's {@code Authorization} header; the request time is taken from the
     * {@code X-Amz-Date} header. The payload hash is computed over the encoded POST parameters
     * of the request (see {@link #encodeParameters(HttpServletRequest)}).
     *
     * <p>The parameters {@code url} and {@code version} are not used by this method.
     *
     * @param request the incoming request; must be a {@link CQSHttpServletRequest}
     * @param url unused
     * @param parameters the request parameters (passed on to the canonical query string builder)
     * @param headers the request headers, including the Authorization header
     * @param version unused
     * @param algorithm algorithm suffix for the string to sign, e.g. "HMAC-SHA256"
     * @param accessSecret the secret the signing key is derived from
     * @return the lower-case hex encoded signature
     * @throws Exception if the Authorization header is missing or malformed
     *         ({@link AuthenticationException}) or the signature cannot be computed
     */
    public static String generateSignatureV4(HttpServletRequest request, URL url, Map<String, String> parameters,
            Map<String, String> headers, String version, String algorithm, String accessSecret) throws Exception {

        /* Example of authorization header value
         * AWS4-HMAC-SHA256 Credential=XK1MWJAYYGQ41ECH06WG/20131126/us-east-1/us-east-1/aws4_request, SignedHeaders=host;user-agent;x-amz-date, Signature=18541c4db00d098414c0bae7394450d1deada902699a45de02849dbcb336f9e3
        */
        String authorizationHeader = request.getHeader("authorization");

        if (authorizationHeader == null) {
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Authorization header is missing");
        }

        int credentialStart = authorizationHeader.indexOf(CREDENTIAL_PREFIX);

        if (credentialStart < 0) {
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Authorization header does not contain a Credential element");
        }

        // the credential value runs up to the next comma (or to the end of the header)
        String credential = authorizationHeader.substring(credentialStart + CREDENTIAL_PREFIX.length());
        int credentialEnd = credential.indexOf(',');

        if (credentialEnd >= 0) {
            credential = credential.substring(0, credentialEnd);
        }

        // expected layout: accessKey/dateStamp/region/service/aws4_request
        String[] credentialParts = credential.split("/");

        if (credentialParts.length < 5) {
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Credential element of Authorization header is not valid");
        }

        String dateStamp = credentialParts[1];
        String regionName = credentialParts[2];
        String serviceName = credentialParts[3];

        String dateTime = request.getHeader("X-Amz-Date");

        // the scope is everything after the access key
        String scope = credential.substring(credential.indexOf('/') + 1);

        String payloadString = getPayload(request);
        String contentSha256 = BinaryUtils.toHex(hash(payloadString));
        Map<String, String> filteredHeaders = filterHeader(headers);

        String stringToSign = getStringToSign("AWS4-" + algorithm, dateTime, scope,
                getCanonicalRequest(request, contentSha256, parameters, filteredHeaders));

        // explicit charset: the derived key must not depend on the platform default encoding
        byte[] secret = ("AWS4" + accessSecret).getBytes(DEFAULT_ENCODING);
        byte[] date = sign(dateStamp, secret, SigningAlgorithm.HmacSHA256);
        byte[] region = sign(regionName, date, SigningAlgorithm.HmacSHA256);
        byte[] service = sign(serviceName, region, SigningAlgorithm.HmacSHA256);
        byte[] signing = sign("aws4_request", service, SigningAlgorithm.HmacSHA256);

        byte[] signatureBytes = sign(stringToSign, signing, SigningAlgorithm.HmacSHA256);

        return BinaryUtils.toHex(signatureBytes);
    }

    /**
     * Computes the MAC of the UTF-8 bytes of {@code stringData}.
     *
     * @param stringData the text to sign
     * @param key the raw key bytes
     * @param algorithm the MAC algorithm
     * @return the raw MAC bytes
     * @throws AmazonClientException if the text cannot be encoded or the MAC cannot be computed
     */
    public static byte[] sign(String stringData, byte[] key, SigningAlgorithm algorithm)
            throws AmazonClientException {

        byte[] data;

        try {
            data = stringData.getBytes(DEFAULT_ENCODING);
        } catch (Exception e) {
            throw new AmazonClientException("Unable to calculate a request signature: " + e.getMessage(), e);
        }

        // called outside the try block so that its own AmazonClientException is not wrapped a second time
        return sign(data, key, algorithm);
    }

    /**
     * Computes the MAC of {@code data} with the given key and algorithm.
     *
     * @param data the bytes to sign
     * @param key the raw key bytes
     * @param algorithm the MAC algorithm; its {@code toString()} is used as the JCA algorithm name
     * @return the raw MAC bytes
     * @throws AmazonClientException if the MAC cannot be computed
     */
    protected static byte[] sign(byte[] data, byte[] key, SigningAlgorithm algorithm) throws AmazonClientException {
        try {
            Mac mac = Mac.getInstance(algorithm.toString());
            mac.init(new SecretKeySpec(key, algorithm.toString()));
            return mac.doFinal(data);
        } catch (Exception e) {
            throw new AmazonClientException("Unable to calculate a request signature: " + e.getMessage(), e);
        }
    }

    /**
     * Assembles the signature version 4 string to sign: algorithm, request time, credential
     * scope and the hex encoded SHA-256 hash of the canonical request, separated by newlines.
     */
    protected static String getStringToSign(String algorithm, String dateTime, String scope,
            String canonicalRequest) {
        String stringToSign = algorithm + "\n" + dateTime + "\n" + scope + "\n"
                + BinaryUtils.toHex(hash(canonicalRequest));

        if (logger.isDebugEnabled()) {
            logger.debug("AWS4 String to Sign: '\"" + stringToSign + "\"");
        }

        return stringToSign;
    }

    /**
     * Computes the SHA-256 hash of the UTF-8 bytes of {@code text}.
     *
     * @param text the text to hash
     * @return the raw 32 byte digest
     * @throws AmazonClientException if the hash cannot be computed (including a null text)
     */
    public static byte[] hash(String text) throws AmazonClientException {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            md.update(text.getBytes(DEFAULT_ENCODING));
            return md.digest();
        } catch (Exception e) {
            throw new AmazonClientException("Unable to compute hash while signing request: " + e.getMessage(), e);
        }
    }

    /**
     * Assembles the signature version 4 canonical request: HTTP method, request URI, canonical
     * query string, canonical headers, signed header names and payload hash, separated by
     * newlines.
     *
     * @param request the incoming request
     * @param contentSha256 hex encoded SHA-256 hash of the payload
     * @param parameters the request parameters
     * @param headers the signed headers (original name to value)
     */
    protected static String getCanonicalRequest(HttpServletRequest request, String contentSha256,
            Map<String, String> parameters, Map<String, String> headers) {

        String canonicalRequest = request.getMethod() + "\n" + getResourcePath(request) + "\n"
                + getCanonicalizedQueryString(request, parameters) + "\n" + getCanonicalizedHeaderString(headers)
                + "\n" + getSignedHeadersString(headers) + "\n" + contentSha256;

        if (logger.isDebugEnabled()) {
            logger.debug("AWS4 Canonical Request: '\"" + canonicalRequest + "\"");
        }

        return canonicalRequest;
    }

    /**
     * Returns the canonical query string used in the canonical request. This implementation
     * always returns the empty string; both arguments are ignored.
     *
     * <p>NOTE(review): requests that carry parameters in the URL query string would need them
     * here to verify correctly - presumably all version 4 requests are expected to send their
     * parameters in the POST body; confirm against the callers.
     */
    protected static String getCanonicalizedQueryString(HttpServletRequest request,
            Map<String, String> parameters) {
        return "";
    }

    /**
     * @return the request URI of the request, used unmodified as the canonical resource path
     */
    private static String getResourcePath(HttpServletRequest request) {
        return request.getRequestURI();
    }

    /**
     * Builds the signature version 1 string to sign: all parameters sorted case-insensitively
     * by name, concatenated as name directly followed by value without separators.
     */
    private static String constructV1DataToSign(Map<String, String> parameters) {

        SortedMap<String, String> sortedParameters = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);
        sortedParameters.putAll(parameters);

        StringBuilder data = new StringBuilder();

        for (Entry<String, String> parameter : sortedParameters.entrySet()) {
            data.append(parameter.getKey());
            data.append(parameter.getValue());
        }

        return data.toString();
    }

    /**
     * Reduces the request headers to those listed in the {@code SignedHeaders} element of the
     * Authorization header.
     *
     * <p>The Authorization header itself is located case-insensitively. The names listed in
     * {@code SignedHeaders} are matched against the lower-cased names of the supplied headers;
     * listed names without a matching header are skipped.
     *
     * @param headers all request headers (name to value)
     * @return map from original header name to value containing only the signed headers
     * @throws AuthenticationException if the Authorization header or its SignedHeaders element
     *         is missing
     */
    private static Map<String, String> filterHeader(Map<String, String> headers) throws AuthenticationException {

        // lower-case name -> name as it appears in the headers map
        Map<String, String> originalNames = new HashMap<String, String>();

        for (String name : headers.keySet()) {
            originalNames.put(name.toLowerCase(Locale.ENGLISH), name);
        }

        String authorizationName = originalNames.get("authorization");
        String authorizationString = authorizationName == null ? null : headers.get(authorizationName);

        if (authorizationString == null) {
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Authorization header is missing");
        }

        int signedHeadersStart = authorizationString.indexOf(SIGNED_HEADERS_PREFIX);

        if (signedHeadersStart < 0) {
            throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue,
                    "Authorization header does not contain a SignedHeaders element");
        }

        signedHeadersStart += SIGNED_HEADERS_PREFIX.length();

        // the list ends at the next comma, whether or not a space precedes the following element
        int signedHeadersEnd = authorizationString.indexOf(',', signedHeadersStart);

        if (signedHeadersEnd < 0) {
            signedHeadersEnd = authorizationString.length();
        }

        String signedHeadersString = authorizationString.substring(signedHeadersStart, signedHeadersEnd);

        Map<String, String> filteredHeaders = new HashMap<String, String>();

        for (String signedHeaderName : signedHeadersString.split(";")) {

            String originalName = originalNames.get(signedHeaderName.trim());

            if (originalName != null) {
                filteredHeaders.put(originalName, headers.get(originalName));
            }
        }

        return filteredHeaders;
    }

    /**
     * Builds the signature version 2 string to sign: the verb {@code POST}, the normalized
     * host, the encoded resource path and the normalized query string, separated by newlines.
     */
    private static String constructV2DataToSign(URL url, Map<String, String> parameters)
            throws UnsupportedEncodingException {

        StringBuilder sb = new StringBuilder();
        sb.append("POST").append("\n");
        sb.append(normalizeURL(url)).append("\n");
        sb.append(normalizeResourcePath(url.getPath())).append("\n");
        sb.append(normalizeQueryString(parameters));

        return sb.toString();
    }

    /**
     * @return the lower-cased host of the URL, followed by ":port" unless the port is
     *         unspecified or one of the standard ports 80 and 443
     */
    private static String normalizeURL(URL url) {

        // fixed locale: lower-casing a host name must not depend on the JVM default locale
        String normalizedUrl = url.getHost().toLowerCase(Locale.ENGLISH);

        // account for apache http client omitting standard ports

        int port = url.getPort();

        if (port > 0 && port != 80 && port != 443) {
            normalizedUrl += ":" + port;
        }

        return normalizedUrl;
    }

    /**
     * @return "/" for a null or empty path, otherwise the URL encoded path with slashes kept
     */
    private static String normalizeResourcePath(String resourcePath) throws UnsupportedEncodingException {

        if (resourcePath == null || resourcePath.length() == 0) {
            return "/";
        }

        return urlEncode(resourcePath, true);
    }

    /**
     * @return "/" for a null or empty path, otherwise the path encoded with
     *         {@code HttpUtils.urlEncode(path, true)} and prefixed with "/" if it does not
     *         already start with one
     */
    protected static String getCanonicalizedResourcePath(String resourcePath) {

        if (resourcePath == null || resourcePath.length() == 0) {
            return "/";
        }

        String value = HttpUtils.urlEncode(resourcePath, true);

        if (value.startsWith("/")) {
            return value;
        }

        return "/".concat(value);
    }

    /**
     * Builds the canonical header block: one {@code name:value\n} line per header, sorted
     * case-insensitively by name, with names lower-cased and runs of whitespace in names and
     * values collapsed to a single space. A null value is written as an empty value.
     *
     * @param headers the headers to include (name to value)
     */
    protected static String getCanonicalizedHeaderString(Map<String, String> headers) {

        List<String> sortedHeaders = new ArrayList<String>(headers.keySet());
        Collections.sort(sortedHeaders, String.CASE_INSENSITIVE_ORDER);

        StringBuilder buffer = new StringBuilder();

        for (String header : sortedHeaders) {

            String value = headers.get(header);

            buffer.append(header.toLowerCase(Locale.ENGLISH).replaceAll("\\s+", " "));
            buffer.append(":");
            buffer.append(value == null ? "" : value.replaceAll("\\s+", " "));
            buffer.append("\n");
        }

        return buffer.toString();
    }

    /**
     * Builds the signed headers list: the lower-cased header names, sorted case-insensitively
     * and joined with ';'.
     *
     * @param headers the headers to include (only the names are used)
     */
    protected static String getSignedHeadersString(Map<String, String> headers) {

        List<String> sortedHeaders = new ArrayList<String>(headers.keySet());
        Collections.sort(sortedHeaders, String.CASE_INSENSITIVE_ORDER);

        StringBuilder buffer = new StringBuilder();

        for (String header : sortedHeaders) {

            if (buffer.length() > 0) {
                buffer.append(";");
            }

            buffer.append(header.toLowerCase(Locale.ENGLISH));
        }

        return buffer.toString();
    }

    /**
     * Builds the signature version 2 query string: parameters in natural (case-sensitive)
     * name order, each as {@code encodedName=encodedValue}, joined with '&amp;'.
     */
    private static String normalizeQueryString(Map<String, String> parameters) throws UnsupportedEncodingException {

        SortedMap<String, String> sorted = new TreeMap<String, String>(parameters);

        StringBuilder builder = new StringBuilder();

        for (Entry<String, String> pair : sorted.entrySet()) {

            if (builder.length() > 0) {
                builder.append("&");
            }

            builder.append(urlEncode(pair.getKey(), false));
            builder.append("=");
            builder.append(urlEncode(pair.getValue(), false));
        }

        return builder.toString();
    }

    /**
     * @return the request payload used for the version 4 content hash, i.e. the encoded POST
     *         parameters of the request
     */
    private static String getPayload(HttpServletRequest request) throws UnsupportedEncodingException {

        return encodeParameters(request);
    }

    /**
     * Creates an encoded query string from all the POST parameters in the specified
     * request.
     *
     * @param request
     *            The request containing the parameters to encode. It is cast to
     *            {@link CQSHttpServletRequest}; any other implementation results in a
     *            ClassCastException.
     *
     * @return The empty string if no parameters were present, otherwise the encoded query
     *         string for the parameters present in the specified request.
     */
    public static String encodeParameters(HttpServletRequest request) {

        CQSHttpServletRequest wrappedRequest = (CQSHttpServletRequest) request;

        // Fetch the enumeration once: the loop must not depend on the accessor handing out the
        // same enumeration instance on every call, otherwise it would never advance.
        Enumeration<String> parameterNames = wrappedRequest.getPostParameterNames();

        if (parameterNames == null || !parameterNames.hasMoreElements()) {
            return "";
        }

        List<NameValuePair> nameValuePairs = new ArrayList<NameValuePair>();

        while (parameterNames.hasMoreElements()) {
            String parameterName = parameterNames.nextElement();
            nameValuePairs.add(new BasicNameValuePair(parameterName, wrappedRequest.getPostParameter(parameterName)));
        }

        return URLEncodedUtils.format(nameValuePairs, DEFAULT_ENCODING);
    }

    /**
     * URL encodes a value as UTF-8 following the AWS rules: space becomes {@code %20} instead
     * of '+', '*' becomes {@code %2A} and '~' stays unencoded.
     *
     * @param value the value to encode; null yields the empty string
     * @param isPath if true, '/' is left unencoded
     */
    private static String urlEncode(String value, boolean isPath) throws UnsupportedEncodingException {

        if (value == null) {
            return "";
        }

        String encoded = URLEncoder.encode(value, DEFAULT_ENCODING).replace("+", "%20").replace("*", "%2A")
                .replace("%7E", "~");

        if (isPath) {
            encoded = encoded.replace("%2F", "/");
        }

        return encoded;
    }
}