// Java tutorial
/** * Copyright 2012 Comcast Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.comcast.cmb.common.util; import org.apache.commons.codec.binary.Base64; import org.apache.http.NameValuePair; import org.apache.http.client.utils.URLEncodedUtils; import org.apache.http.message.BasicNameValuePair; import org.apache.log4j.Logger; import com.amazonaws.AmazonClientException; import com.amazonaws.auth.SigningAlgorithm; import com.amazonaws.util.BinaryUtils; import com.amazonaws.util.HttpUtils; import com.comcast.cqs.controller.CQSHttpServletRequest; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import javax.servlet.http.HttpServletRequest; import java.io.UnsupportedEncodingException; import java.math.BigInteger; import java.net.URL; import java.net.URLEncoder; import java.security.MessageDigest; import java.security.SecureRandom; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; import java.util.Map.Entry; /** * Utility functions for authentication * @author michael, bwolf * */ public class AuthUtil { private static final Logger logger = Logger.getLogger(AuthUtil.class); private static final int REQUEST_VALIDITY_PERIOD_MS = 900000; //15 mins private static final Random rand = new SecureRandom(); protected static final String DEFAULT_ENCODING = "UTF-8"; public static String hashPassword(String password) throws Exception { MessageDigest digest = MessageDigest.getInstance("MD5"); String salt = getRandomString(4, 
SECRET_CHARS); String toBeHashed = salt + password; byte[] hashed = digest.digest(toBeHashed.getBytes("UTF-8")); StringBuilder sb = new StringBuilder(hashed.length * 2 + 8); byte[] saltBytes = salt.getBytes(); for (int i = 0; i < 4; i++) { String hex = Integer.toHexString(0xFF & saltBytes[i]); if (hex.length() == 1) { sb.append('0'); } sb.append(hex); } for (int i = 0; i < hashed.length; i++) { String hex = Integer.toHexString(0xFF & hashed[i]); if (hex.length() == 1) { sb.append('0'); } sb.append(hex); } return sb.toString(); } public static boolean verifyPassword(String password, String hashedPassword) throws Exception { MessageDigest digest = MessageDigest.getInstance("MD5"); byte[] hashedBytes = new BigInteger(hashedPassword, 16).toByteArray(); String salt = new String(hashedBytes, 0, 4); String toBeHashed = salt + password; byte[] hashed = digest.digest(toBeHashed.getBytes("UTF-8")); for (int i = 0; i < hashed.length; i++) { if (hashed[i] != hashedBytes[i + 4]) { return false; } } return true; } private static final String KEY_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; private static final String SECRET_CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ+/"; public static String generateRandomAccessKey() { return getRandomString(20, KEY_CHARS); } public static String generateRandomAccessSecret() { return getRandomString(40, SECRET_CHARS); } private static String getRandomString(int len, final String validChars) { StringBuilder sb = new StringBuilder(len); for (int i = 0; i < len; i++) { sb.append(validChars.charAt(rand.nextInt(validChars.length()))); } return sb.toString(); } public static void checkTimeStamp(String ts) throws AuthenticationException { checkTimeStampWithFormat(ts, "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); } public static void checkTimeStampV4(String ts) throws AuthenticationException { checkTimeStampWithFormat(ts, "yyyyMMdd'T'HHmmss'Z'"); } public static void checkTimeStampWithFormat(String ts, String format) throws 
AuthenticationException { SimpleDateFormat dateFormat = new SimpleDateFormat(format); dateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); Date timeStamp; try { timeStamp = dateFormat.parse(ts); } catch (ParseException ex) { logger.error("event=checking_timestamp timestamp=" + ts + " error_code=invalid_format", ex); throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue, "Timestamp=" + ts + " is not valid"); } Date now = new Date(); if (now.getTime() - REQUEST_VALIDITY_PERIOD_MS < timeStamp.getTime() && now.getTime() + REQUEST_VALIDITY_PERIOD_MS > timeStamp.getTime()) { return; } logger.error("event=checking_timestamp timestamp=" + ts + " serverTime=" + dateFormat.format(now) + " error_code=timestamp_out_of_range"); throw new AuthenticationException(CMBErrorCodes.RequestExpired, "Request timestamp " + ts + " must be within 900 seconds of the server time"); } public static void checkExpiration(String expiration) throws AuthenticationException { SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"); dateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); Date timeStamp; try { timeStamp = dateFormat.parse(expiration); } catch (ParseException e) { logger.error("event=checking_expiration expiration=" + expiration + " error_code=invalid_format", e); throw new AuthenticationException(CMBErrorCodes.InvalidParameterValue, "Expiration " + expiration + " is not valid"); } Date now = new Date(); if (now.getTime() < timeStamp.getTime()) { return; } logger.error("event=checking_timestamp expiration=" + expiration + " server_time=" + dateFormat.format(now) + " error_code=request_expired"); throw new AuthenticationException(CMBErrorCodes.RequestExpired, "Request with expiration " + expiration + " already expired"); } public static String generateSignature(URL url, Map<String, String> parameters, String version, String algorithm, String accessSecret) throws Exception { String data = null; if (version.equals("1")) { data = 
constructV1DataToSign(parameters); } else if (version.equals("2")) { parameters.put("SignatureMethod", algorithm); data = constructV2DataToSign(url, parameters); } else { return null; } Mac mac = Mac.getInstance(algorithm); mac.init(new SecretKeySpec(accessSecret.getBytes("UTF-8"), algorithm)); byte[] bytes = mac.doFinal(data.getBytes("UTF-8")); String signature = new String(Base64.encodeBase64(bytes)); return signature; } public static String generateSignatureV4(HttpServletRequest request, URL url, Map<String, String> parameters, Map<String, String> headers, String version, String algorithm, String accessSecret) throws Exception { /* Example of authorization header value * AWS4-HMAC-SHA256 Credential=XK1MWJAYYGQ41ECH06WG/20131126/us-east-1/us-east-1/aws4_request, SignedHeaders=host;user-agent;x-amz-date, Signature=18541c4db00d098414c0bae7394450d1deada902699a45de02849dbcb336f9e3 */ String authorizationHeader = request.getHeader("authorization"); String credentialPart = authorizationHeader .substring(authorizationHeader.indexOf("Credential=") + "Credential=".length()); String[] credentialPartArray = credentialPart.split("/"); String regionName = credentialPartArray[2]; String serviceName = credentialPartArray[3]; String dateTime = request.getHeader("X-Amz-Date"); String dateStamp = credentialPartArray[1]; String scope = credentialPart.substring(credentialPart.indexOf("/") + 1, credentialPart.indexOf(",")); String payloadString = getPayload(request); String contentSha256 = BinaryUtils.toHex(hash(payloadString)); Map<String, String> filteredHeaders = filterHeader(headers); String stringToSign = getStringToSign("AWS4-" + algorithm, dateTime, scope, getCanonicalRequest(request, contentSha256, parameters, filteredHeaders)); byte[] secret = ("AWS4" + accessSecret).getBytes(); byte[] date = sign(dateStamp, secret, SigningAlgorithm.HmacSHA256); byte[] region = sign(regionName, date, SigningAlgorithm.HmacSHA256); byte[] service = sign(serviceName, region, 
SigningAlgorithm.HmacSHA256); byte[] signing = sign("aws4_request", service, SigningAlgorithm.HmacSHA256); byte[] signatureBytes = sign(stringToSign.getBytes(), signing, SigningAlgorithm.HmacSHA256); String signature = BinaryUtils.toHex(signatureBytes); return signature; } public static byte[] sign(String stringData, byte[] key, SigningAlgorithm algorithm) throws AmazonClientException { try { byte[] data = stringData.getBytes("UTF-8"); return sign(data, key, algorithm); } catch (Exception e) { throw new AmazonClientException("Unable to calculate a request signature: " + e.getMessage(), e); } } protected static byte[] sign(byte[] data, byte[] key, SigningAlgorithm algorithm) throws AmazonClientException { try { Mac mac = Mac.getInstance(algorithm.toString()); mac.init(new SecretKeySpec(key, algorithm.toString())); return mac.doFinal(data); } catch (Exception e) { throw new AmazonClientException("Unable to calculate a request signature: " + e.getMessage(), e); } } protected static String getStringToSign(String algorithm, String dateTime, String scope, String canonicalRequest) { String stringToSign = algorithm + "\n" + dateTime + "\n" + scope + "\n" + BinaryUtils.toHex(hash(canonicalRequest)); logger.debug("AWS4 String to Sign: '\"" + stringToSign + "\""); return stringToSign; } public static byte[] hash(String text) throws AmazonClientException { try { MessageDigest md = MessageDigest.getInstance("SHA-256"); md.update(text.getBytes("UTF-8")); return md.digest(); } catch (Exception e) { throw new AmazonClientException("Unable to compute hash while signing request: " + e.getMessage(), e); } } protected static String getCanonicalRequest(HttpServletRequest request, String contentSha256, Map<String, String> parameters, Map<String, String> headers) { String canonicalRequest = null; canonicalRequest = request.getMethod() + "\n" + getResourcePath(request) + "\n" + getCanonicalizedQueryString(request, parameters) + "\n" + getCanonicalizedHeaderString(headers) + "\n" + 
getSignedHeadersString(headers) + "\n" + contentSha256; logger.debug("AWS4 Canonical Request: '\"" + canonicalRequest + "\""); return canonicalRequest; } protected static String getCanonicalizedQueryString(HttpServletRequest request, Map<String, String> parameters) { return ""; } private static String getResourcePath(HttpServletRequest request) { String path = request.getRequestURI(); return path; } private static String constructV1DataToSign(Map<String, String> parameters) { StringBuilder data = new StringBuilder(); SortedMap<String, String> sortedParameters = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER); sortedParameters.putAll(parameters); for (String key : sortedParameters.keySet()) { data.append(key); data.append(sortedParameters.get(key)); } return data.toString(); } private static Map<String, String> filterHeader(Map<String, String> headers) { Map<String, String> filteredHeaders = new HashMap<String, String>(); String authorizationString = headers.get("Authorization"); String singnedHeadersString = authorizationString.substring( authorizationString.indexOf("SignedHeaders=") + new String("SignedHeaders=").length(), authorizationString.indexOf(", Signature")); String[] headersArray = singnedHeadersString.split(";"); //dealing with lower case letter Map<String, String> lowerCaseHeaders = new HashMap<String, String>(); for (Entry<String, String> entry : headers.entrySet()) { lowerCaseHeaders.put(entry.getKey().toLowerCase(), entry.getKey()); } for (String currentHeaderName : headersArray) { if (lowerCaseHeaders.containsKey(currentHeaderName.trim())) { filteredHeaders.put(lowerCaseHeaders.get(currentHeaderName.trim()), headers.get(lowerCaseHeaders.get(currentHeaderName.trim()))); } } return filteredHeaders; } private static String constructV2DataToSign(URL url, Map<String, String> parameters) throws UnsupportedEncodingException { StringBuilder sb = new StringBuilder(); sb.append("POST").append("\n"); sb.append(normalizeURL(url)).append("\n"); 
sb.append(normalizeResourcePath(url.getPath())).append("\n"); sb.append(normalizeQueryString(parameters)); return sb.toString(); } private static String normalizeURL(URL url) { String normalizedUrl = url.getHost().toLowerCase(); // account for apache http client omitting standard ports if (url.getPort() > 0 && url.getPort() != 80 && url.getPort() != 443) { normalizedUrl += ":" + url.getPort(); } return normalizedUrl; } private static String normalizeResourcePath(String resourcePath) throws UnsupportedEncodingException { String normalizedResourcePath = null; if (resourcePath == null || resourcePath.length() == 0) { normalizedResourcePath = "/"; } else { normalizedResourcePath = urlEncode(resourcePath, true); } return normalizedResourcePath; } protected static String getCanonicalizedResourcePath(String resourcePath) { if (resourcePath == null || resourcePath.length() == 0) { return "/"; } else { String value = HttpUtils.urlEncode(resourcePath, true); if (value.startsWith("/")) { return value; } else { return "/".concat(value); } } } protected static String getCanonicalizedHeaderString(Map<String, String> headers) { List<String> sortedHeaders = new ArrayList<String>(); sortedHeaders.addAll(headers.keySet()); Collections.sort(sortedHeaders, String.CASE_INSENSITIVE_ORDER); StringBuilder buffer = new StringBuilder(); for (String header : sortedHeaders) { buffer.append(header.toLowerCase().replaceAll("\\s+", " ") + ":" + headers.get(header).replaceAll("\\s+", " ")); buffer.append("\n"); } return buffer.toString(); } protected static String getSignedHeadersString(Map<String, String> headers) { List<String> sortedHeaders = new ArrayList<String>(); sortedHeaders.addAll(headers.keySet()); Collections.sort(sortedHeaders, String.CASE_INSENSITIVE_ORDER); StringBuilder buffer = new StringBuilder(); for (String header : sortedHeaders) { if (buffer.length() > 0) buffer.append(";"); buffer.append(header.toLowerCase()); } return buffer.toString(); } private static String 
normalizeQueryString(Map<String, String> parameters) throws UnsupportedEncodingException { SortedMap<String, String> sorted = new TreeMap<String, String>(); sorted.putAll(parameters); StringBuilder builder = new StringBuilder(); Iterator<Map.Entry<String, String>> pairs = sorted.entrySet().iterator(); while (pairs.hasNext()) { Map.Entry<String, String> pair = pairs.next(); String key = pair.getKey(); String value = pair.getValue(); builder.append(urlEncode(key, false)); builder.append("="); builder.append(urlEncode(value, false)); if (pairs.hasNext()) { builder.append("&"); } } return builder.toString(); } private static String getPayload(HttpServletRequest reqeust) throws UnsupportedEncodingException { return encodeParameters(reqeust); } /** * Creates an encoded query string from all the parameters in the specified * request. * * @param request * The request containing the parameters to encode. * * @return Null if no parameters were present, otherwise the encoded query * string for the parameters present in the specified request. 
*/ public static String encodeParameters(HttpServletRequest request) { CQSHttpServletRequest wrappedRequest = (CQSHttpServletRequest) request; List<NameValuePair> nameValuePairs = null; String parameterName = null; if (wrappedRequest.getPostParameterNames().hasMoreElements()) { nameValuePairs = new ArrayList<NameValuePair>(); while (wrappedRequest.getPostParameterNames().hasMoreElements()) { parameterName = wrappedRequest.getPostParameterNames().nextElement(); nameValuePairs .add(new BasicNameValuePair(parameterName, wrappedRequest.getPostParameter(parameterName))); } } String encodedParams = ""; if (nameValuePairs != null) { encodedParams = URLEncodedUtils.format(nameValuePairs, DEFAULT_ENCODING); } return encodedParams; } private static String urlEncode(String value, boolean isPath) throws UnsupportedEncodingException { if (value == null) { return ""; } String encoded = URLEncoder.encode(value, "UTF-8").replace("+", "%20").replace("*", "%2A").replace("%7E", "~"); if (isPath) { encoded = encoded.replace("%2F", "/"); } return encoded; } }