Java tutorial
/* * Copyright 2014-2015 Amazon Technologies, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at: * * http://aws.amazon.com/apache2.0 * * This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and * limitations under the License. */ package com.amazonaws.http.conn.ssl; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.Socket; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; import java.util.Enumeration; import java.util.List; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSessionContext; import javax.net.ssl.SSLSocket; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.http.annotation.ThreadSafe; import org.apache.http.conn.ConnectTimeoutException; import org.apache.http.conn.ssl.SSLSocketFactory; import org.apache.http.conn.ssl.X509HostnameVerifier; import org.apache.http.params.HttpParams; import com.amazonaws.internal.SdkSSLSocket; import com.amazonaws.internal.SdkSocket; /** * Used to enforce the preferred TLS protocol during SSL handshake. */ @ThreadSafe public class SdkTLSSocketFactory extends SSLSocketFactory { private static final Log log = LogFactory.getLog(SdkTLSSocketFactory.class); private final SSLContext sslContext; public SdkTLSSocketFactory(final SSLContext sslContext, final X509HostnameVerifier hostnameVerifier) { super(sslContext, hostnameVerifier); if (sslContext == null) { throw new NullPointerException( "sslContext must not be null. " + "Use SSLContext.getDefault() if you are unsure."); } this.sslContext = sslContext; } /** * {@inheritDoc} * * Used to enforce the preferred TLS protocol during SSL handshake. */ @Override protected final void prepareSocket(final SSLSocket socket) { String[] supported = socket.getSupportedProtocols(); String[] enabled = socket.getEnabledProtocols(); if (log.isDebugEnabled()) { log.debug("socket.getSupportedProtocols(): " + Arrays.toString(supported) + ", socket.getEnabledProtocols(): " + Arrays.toString(enabled)); } List<String> target = new ArrayList<String>(); if (supported != null) { // Append the preferred protocols in descending order of preference // but only do so if the protocols are supported TLSProtocol[] values = TLSProtocol.values(); for (int i = 0; i < values.length; i++) { final String pname = values[i].getProtocolName(); if (existsIn(pname, supported)) target.add(pname); } } if (enabled != null) { // Append the rest of the already enabled protocols to the end // if not already included in the list for (String pname : enabled) { if (!target.contains(pname)) target.add(pname); } } if (target.size() > 0) { String[] enabling = target.toArray(new String[target.size()]); socket.setEnabledProtocols(enabling); if (log.isDebugEnabled()) { log.debug("TLS protocol enabled for SSL handshake: " + Arrays.toString(enabling)); } } } /** * Returns true if the given element exists in the given array; * false otherwise. */ private boolean existsIn(String element, String[] a) { for (String s : a) { if (element.equals(s)) return true; } return false; } @Override public Socket connectSocket(final Socket socket, final InetSocketAddress remoteAddress, final InetSocketAddress localAddress, final HttpParams params) throws IOException, UnknownHostException, ConnectTimeoutException { if (log.isDebugEnabled()) log.debug("connecting to " + remoteAddress.getAddress() + ":" + remoteAddress.getPort()); try { verifyMasterSecret(super.connectSocket(socket, remoteAddress, localAddress, params)); } catch (final SSLException sslEx) { // clear any related sessions from our cache if (log.isDebugEnabled()) { log.debug("connection failed due to SSL error, clearing TLS session cache", sslEx); } clearSessionCache(sslContext.getClientSessionContext(), remoteAddress); throw sslEx; } if (socket instanceof SSLSocket) return new SdkSSLSocket((SSLSocket) socket); return new SdkSocket(socket); } /** * Invalidates all SSL/TLS sessions in {@code sessionContext} associated with {@code remoteAddress}. * @param sessionContext collection of SSL/TLS sessions to be (potentially) invalidated * @param remoteAddress associated with sessions to invalidate */ private void clearSessionCache(final SSLSessionContext sessionContext, final InetSocketAddress remoteAddress) { final String hostName = remoteAddress.getHostName(); final int port = remoteAddress.getPort(); final Enumeration<byte[]> ids = sessionContext.getIds(); if (ids == null) { return; } while (ids.hasMoreElements()) { final byte[] id = ids.nextElement(); final SSLSession session = sessionContext.getSession(id); if (session != null && session.getPeerHost() != null && session.getPeerHost().equalsIgnoreCase(hostName) && session.getPeerPort() == port) { session.invalidate(); if (log.isDebugEnabled()) { log.debug("Invalidated session " + session); } } } } /** * Double check the master secret of an SSL session must not be null, or * else a {@link SecurityException} will be thrown. * @param sock connected socket */ private void verifyMasterSecret(final Socket sock) { if (sock instanceof SSLSocket) { SSLSocket ssl = (SSLSocket) sock; SSLSession session = ssl.getSession(); if (session != null) { String className = session.getClass().getName(); if ("sun.security.ssl.SSLSessionImpl".equals(className)) { try { Class<?> clazz = Class.forName(className); Method method = clazz.getDeclaredMethod("getMasterSecret"); method.setAccessible(true); Object masterSecret = method.invoke(session); if (masterSecret == null) { session.invalidate(); if (log.isDebugEnabled()) { log.debug("Invalidated session " + session); } throw log(new SecurityException("Invalid SSL master secret")); } } catch (ClassNotFoundException e) { failedToVerifyMasterSecret(e); } catch (NoSuchMethodException e) { failedToVerifyMasterSecret(e); } catch (IllegalAccessException e) { failedToVerifyMasterSecret(e); } catch (InvocationTargetException e) { failedToVerifyMasterSecret(e.getCause()); } } } } return; } private void failedToVerifyMasterSecret(Throwable t) { if (log.isDebugEnabled()) log.debug("Failed to verify the SSL master secret", t); } private <T extends Throwable> T log(T t) { if (log.isDebugEnabled()) log.debug("", t); return t; } }