// Java tutorial
/** * Copyright (c) Codice Foundation * * <p>This is free software: you can redistribute it and/or modify it under the terms of the GNU * Lesser General Public License as published by the Free Software Foundation, either version 3 of * the License, or any later version. * * <p>This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. A copy of the GNU Lesser General Public * License is distributed along with this program and can be found at * <http://www.gnu.org/licenses/lgpl.html>. */ package org.codice.ddf.cxf.paos; import com.google.api.client.http.ByteArrayContent; import com.google.api.client.http.GenericUrl; import com.google.api.client.http.HttpContent; import com.google.api.client.http.HttpRequest; import com.google.api.client.http.HttpResponse; import com.google.api.client.http.HttpStatusCodes; import com.google.api.client.http.HttpTransport; import com.google.api.client.http.HttpUnsuccessfulResponseHandler; import com.google.api.client.http.InputStreamContent; import com.google.api.client.http.javanet.NetHttpTransport; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import com.google.common.io.ByteSource; import ddf.security.liberty.paos.Response; import ddf.security.liberty.paos.impl.ResponseBuilder; import ddf.security.samlp.SamlProtocol; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.stream.Collectors; import javax.xml.soap.SOAPException; import javax.xml.soap.SOAPHeaderElement; import javax.xml.soap.SOAPPart; import 
javax.xml.stream.XMLStreamException; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import org.apache.cxf.helpers.DOMUtils; import org.apache.cxf.interceptor.Fault; import org.apache.cxf.interceptor.security.AccessDeniedException; import org.apache.cxf.message.Message; import org.apache.cxf.phase.AbstractPhaseInterceptor; import org.apache.wss4j.common.ext.WSSecurityException; import org.apache.wss4j.common.saml.OpenSAMLUtil; import org.apache.wss4j.common.util.DOM2Writer; import org.codice.ddf.platform.util.TemporaryFileBackedOutputStream; import org.codice.ddf.security.common.jaxrs.RestSecurity; import org.opensaml.core.xml.XMLObject; import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.IDPEntry; import org.opensaml.saml.saml2.core.IDPList; import org.opensaml.saml.saml2.ecp.Request; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.Node; public class PaosInInterceptor extends AbstractPhaseInterceptor<Message> { public static final Logger LOGGER = LoggerFactory.getLogger(PaosInInterceptor.class); public static final String RELAY_STATE = "RelayState"; public static final String REQUEST = "Request"; public static final String RESPONSE = "Response"; public static final String ASSERTION_CONSUMER_SERVICE_URL = "AssertionConsumerServiceURL"; public static final String RESPONSE_CONSUMER_URL = "responseConsumerURL"; public static final String URN_OASIS_NAMES_TC_SAML_2_0_PROFILES_SSO_ECP = "urn:oasis:names:tc:SAML:2.0:profiles:SSO:ecp"; public static final String MESSAGE_ID = "messageID"; public static final String ECP_RESPONSE = "ecp:Response"; public static final String BASIC = "BASIC"; public static final String SAML = "SAML"; public static final String TEXT_XML = "text/xml"; public static final String SOAP_ACTION = "SOAPAction"; public static final String HTTP_WWW_OASIS_OPEN_ORG_COMMITTEES_SECURITY = 
"http://www.oasis-open.org/committees/security"; public static final String URN_LIBERTY_PAOS_2003_08 = "urn:liberty:paos:2003-08"; public static final String APPLICATION_VND_PAOS_XML = "application/vnd.paos+xml"; private String soapMessage; private String soapfaultMessage; private String securityHeader; private String usernameToken; public PaosInInterceptor(String phase) { super(phase); try (InputStream soapMessageStream = PaosInInterceptor.class .getResourceAsStream("/templates/soap.handlebars"); InputStream soapfaultMessageStream = PaosInInterceptor.class .getResourceAsStream("/templates/soapfault.handlebars"); InputStream securityHeaderStream = PaosInInterceptor.class .getResourceAsStream("/templates/security.handlebars"); InputStream userTokenStream = PaosInInterceptor.class .getResourceAsStream("/templates/username.handlebars")) { soapMessage = IOUtils.toString(soapMessageStream); soapfaultMessage = IOUtils.toString(soapfaultMessageStream); securityHeader = IOUtils.toString(securityHeaderStream); usernameToken = IOUtils.toString(userTokenStream); } catch (IOException e) { LOGGER.info("Unable to load templates for PAOS"); } } @Override public void handleMessage(Message message) throws Fault { List authHeader = (List) ((Map) message.getExchange().getOutMessage().get(Message.PROTOCOL_HEADERS)) .get("Authorization"); String authorization = null; if (authHeader != null && authHeader.size() > 0) { authorization = (String) authHeader.get(0); } InputStream content = message.getContent(InputStream.class); String contentType = (String) message.get(Message.CONTENT_TYPE); if (contentType == null || !contentType.contains(APPLICATION_VND_PAOS_XML)) { return; } try { SOAPPart soapMessage = SamlProtocol .parseSoapMessage(IOUtils.toString(content, Charset.forName("UTF-8"))); Iterator iterator = soapMessage.getEnvelope().getHeader().examineAllHeaderElements(); IDPEntry idpEntry = null; String relayState = ""; String responseConsumerURL = ""; String messageId = ""; while 
(iterator.hasNext()) { Element soapHeaderElement = (SOAPHeaderElement) iterator.next(); if (RELAY_STATE.equals(soapHeaderElement.getLocalName())) { relayState = DOM2Writer.nodeToString(soapHeaderElement); } else if (REQUEST.equals(soapHeaderElement.getLocalName()) && soapHeaderElement.getNamespaceURI() .equals(URN_OASIS_NAMES_TC_SAML_2_0_PROFILES_SSO_ECP)) { try { soapHeaderElement = SamlProtocol.convertDomImplementation(soapHeaderElement); Request ecpRequest = (Request) OpenSAMLUtil.fromDom(soapHeaderElement); IDPList idpList = ecpRequest.getIDPList(); if (idpList == null) { throw new Fault(new AccessDeniedException( "Unable to complete SAML ECP connection. Unable to determine IdP server.")); } List<IDPEntry> idpEntrys = idpList.getIDPEntrys(); if (idpEntrys == null || idpEntrys.size() == 0) { throw new Fault(new AccessDeniedException( "Unable to complete SAML ECP connection. Unable to determine IdP server.")); } // choose the right entry, probably need to do something better than select the first // one // but the spec doesn't specify how this is supposed to be done idpEntry = idpEntrys.get(0); } catch (WSSecurityException e) { // TODO figure out IdP alternatively LOGGER.info( "Unable to determine IdP appropriately. ECP connection will fail. SP may be incorrectly configured. Contact the administrator for the remote system."); } } else if (REQUEST.equals(soapHeaderElement.getLocalName()) && soapHeaderElement.getNamespaceURI().equals(URN_LIBERTY_PAOS_2003_08)) { responseConsumerURL = soapHeaderElement.getAttribute(RESPONSE_CONSUMER_URL); messageId = soapHeaderElement.getAttribute(MESSAGE_ID); } } if (idpEntry == null) { throw new Fault(new AccessDeniedException( "Unable to complete SAML ECP connection. 
Unable to determine IdP server.")); } String token = createToken(authorization); checkAuthnRequest(soapMessage); Element authnRequestElement = SamlProtocol .getDomElement(soapMessage.getEnvelope().getBody().getFirstChild()); String loc = idpEntry.getLoc(); String soapRequest = buildSoapMessage(token, relayState, authnRequestElement, null); HttpResponseWrapper httpResponse = getHttpResponse(loc, soapRequest, null); InputStream httpResponseContent = httpResponse.content; SOAPPart idpSoapResponse = SamlProtocol .parseSoapMessage(IOUtils.toString(httpResponseContent, Charset.forName("UTF-8"))); Iterator responseHeaderElements = idpSoapResponse.getEnvelope().getHeader().examineAllHeaderElements(); String newRelayState = ""; while (responseHeaderElements.hasNext()) { SOAPHeaderElement soapHeaderElement = (SOAPHeaderElement) responseHeaderElements.next(); if (RESPONSE.equals(soapHeaderElement.getLocalName())) { String assertionConsumerServiceURL = soapHeaderElement .getAttribute(ASSERTION_CONSUMER_SERVICE_URL); if (!responseConsumerURL.equals(assertionConsumerServiceURL)) { String soapFault = buildSoapFault(ECP_RESPONSE, "The responseConsumerURL does not match the assertionConsumerServiceURL."); httpResponse = getHttpResponse(responseConsumerURL, soapFault, null); message.setContent(InputStream.class, httpResponse.content); return; } } else if (RELAY_STATE.equals(soapHeaderElement.getLocalName())) { newRelayState = DOM2Writer.nodeToString(soapHeaderElement); if (StringUtils.isNotEmpty(relayState) && !relayState.equals(newRelayState)) { LOGGER.debug("RelayState does not match between ECP request and response"); } if (StringUtils.isNotEmpty(relayState)) { newRelayState = relayState; } } } checkSamlpResponse(idpSoapResponse); Element samlpResponseElement = SamlProtocol .getDomElement(idpSoapResponse.getEnvelope().getBody().getFirstChild()); Response paosResponse = null; if (StringUtils.isNotEmpty(messageId)) { paosResponse = getPaosResponse(messageId); } String soapResponse 
= buildSoapMessage(null, newRelayState, samlpResponseElement, paosResponse); httpResponse = getHttpResponse(responseConsumerURL, soapResponse, message.getExchange().getOutMessage()); if (httpResponse.statusCode < 400) { httpResponseContent = httpResponse.content; message.setContent(InputStream.class, httpResponseContent); Map<String, List<String>> headers = new HashMap<>(); message.put(Message.PROTOCOL_HEADERS, headers); httpResponse.headers.forEach((entry) -> headers.put(entry.getKey(), // CXF Expects pairs of <String, List<String>> entry.getValue() instanceof List ? ((List<Object>) entry.getValue()).stream().map(String::valueOf) .collect(Collectors.toList()) : Lists.newArrayList(String.valueOf(entry.getValue())))); } else { throw new Fault( new AccessDeniedException("Unable to complete SAML ECP connection due to an error.")); } } catch (IOException e) { LOGGER.debug("Error encountered while performing ECP handshake.", e); } catch (XMLStreamException | SOAPException e) { throw new Fault(new AccessDeniedException( "Unable to complete SAML ECP connection. The server's response was not in the correct format.")); } catch (WSSecurityException e) { throw new Fault(new AccessDeniedException( "Unable to complete SAML ECP connection. 
Unable to send SOAP request messages.")); } } private boolean isRedirectable(String method) { return "HEAD".equals(method) || "GET".equals(method) || "CONNECT".equals(method); } private String createToken(String authorization) throws IOException { String token = null; if (authorization != null) { if (StringUtils.startsWithIgnoreCase(authorization, BASIC)) { byte[] decode = Base64.getDecoder().decode(authorization.split("\\s")[1]); if (decode != null) { String userPass = new String(decode, StandardCharsets.UTF_8); String[] authComponents = userPass.split(":"); if (authComponents.length == 2) { token = getUsernameToken(authComponents[0], authComponents[1]); } else if ((authComponents.length == 1) && (userPass.endsWith(":"))) { token = getUsernameToken(authComponents[0], ""); } } } else if (StringUtils.startsWithIgnoreCase(authorization, SAML)) { token = RestSecurity.inflateBase64(authorization.split("\\s")[1]); } } return token; } @VisibleForTesting HttpResponseWrapper getHttpResponse(String responseConsumerURL, String soapResponse, Message message) throws IOException { // This used to use the ApacheHttpTransport which appeared to not work with 2 way TLS auth but // this one does HttpTransport httpTransport = new NetHttpTransport(); HttpContent httpContent = new ByteArrayContent(TEXT_XML, soapResponse.getBytes("UTF-8")); HttpRequest httpRequest = httpTransport.createRequestFactory() .buildPostRequest(new GenericUrl(responseConsumerURL), httpContent); HttpUnsuccessfulResponseHandler httpUnsuccessfulResponseHandler = getHttpUnsuccessfulResponseHandler( message); httpRequest.setUnsuccessfulResponseHandler(httpUnsuccessfulResponseHandler); httpRequest.getHeaders().put(SOAP_ACTION, HTTP_WWW_OASIS_OPEN_ORG_COMMITTEES_SECURITY); // has 20 second timeout by default HttpResponse httpResponse = httpRequest.execute(); HttpResponseWrapper httpResponseWrapper = new HttpResponseWrapper(); httpResponseWrapper.statusCode = httpResponse.getStatusCode(); httpResponseWrapper.content = 
httpResponse.getContent(); httpResponseWrapper.headers = httpResponse.getHeaders().entrySet(); return httpResponseWrapper; } @VisibleForTesting HttpUnsuccessfulResponseHandler getHttpUnsuccessfulResponseHandler(Message message) { return (request, response, supportsRetry) -> { String redirectLocation = response.getHeaders().getLocation(); if (isRedirect(request, response, redirectLocation)) { String method = (String) message.get(Message.HTTP_REQUEST_METHOD); HttpContent content = null; if (!isRedirectable(method)) { try (TemporaryFileBackedOutputStream tfbos = new TemporaryFileBackedOutputStream()) { message.setContent(OutputStream.class, tfbos); BodyWriter bodyWriter = new BodyWriter(); bodyWriter.handleMessage(message); ByteSource byteSource = tfbos.asByteSource(); content = new InputStreamContent((String) message.get(Message.CONTENT_TYPE), byteSource.openStream()).setLength(byteSource.size()); } } // resolve the redirect location relative to the current location request.setUrl(new GenericUrl(request.getUrl().toURL(redirectLocation))); request.setRequestMethod(method); request.setContent(content); // remove Authorization and If-* headers request.getHeaders().setAuthorization((String) null); request.getHeaders().setIfMatch(null); request.getHeaders().setIfNoneMatch(null); request.getHeaders().setIfModifiedSince(null); request.getHeaders().setIfUnmodifiedSince(null); request.getHeaders().setIfRange(null); request.getHeaders().setCookie((String) ((List) response.getHeaders().get("set-cookie")).get(0)); Map<String, List<String>> headers = (Map<String, List<String>>) message .get(Message.PROTOCOL_HEADERS); headers.forEach((key, value) -> request.getHeaders().set(key, value)); return true; } return false; }; } @VisibleForTesting boolean isRedirect(HttpRequest request, HttpResponse response, String redirectLocation) { return request.getFollowRedirects() && HttpStatusCodes.isRedirect(response.getStatusCode()) && redirectLocation != null; } @Override public void 
handleFault(Message message) { LOGGER.debug("PAOS interceptor fault method called."); } private String buildSoapMessage(String token, String relayState, Element body, Response paosResponse) throws WSSecurityException { String updatedMessage = soapMessage.replace("{{XmlBody}}", DOM2Writer.nodeToString(body)); if (token != null) { String updatedSecHdr = securityHeader.replace("{{token}}", token); updatedMessage = updatedMessage.replace("{{WSSecurity}}", updatedSecHdr); } else { updatedMessage = updatedMessage.replace("{{WSSecurity}}", ""); } if (paosResponse != null) { updatedMessage = updatedMessage.replace("{{PAOSResponse}}", convertXmlObjectToString(paosResponse)); } else { updatedMessage = updatedMessage.replace("{{PAOSResponse}}", ""); } updatedMessage = updatedMessage.replace("{{ECPRelayState}}", relayState); return updatedMessage; } private String buildSoapFault(String faultcode, String faultstring) { String updatedFault = soapfaultMessage.replace("{{faultcode}}", faultcode); updatedFault = updatedFault.replace("{{faultstring}}", faultstring); return updatedFault; } private String convertXmlObjectToString(XMLObject xmlObject) throws WSSecurityException { ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); Thread.currentThread().setContextClassLoader(PaosInInterceptor.class.getClassLoader()); try { Document doc = DOMUtils.createDocument(); doc.appendChild(doc.createElement("root")); Element requestElement = OpenSAMLUtil.toDom(xmlObject, null); return DOM2Writer.nodeToString(requestElement); } finally { Thread.currentThread().setContextClassLoader(contextClassLoader); } } private Response getPaosResponse(String messageId) { ResponseBuilder responseBuilder = new ResponseBuilder(); Response response = responseBuilder.buildObject(); response.setRefToMessageID(messageId); return response; } private void checkAuthnRequest(SOAPPart soapRequest) throws IOException { XMLObject authnXmlObj = null; try { Node node = 
soapRequest.getEnvelope().getBody().getFirstChild(); authnXmlObj = SamlProtocol.getXmlObjectFromNode(node); } catch (WSSecurityException | SOAPException | XMLStreamException ex) { throw new IOException("Unable to convert AuthnRequest document to XMLObject."); } if (authnXmlObj == null) { throw new IOException("AuthnRequest object is not Found."); } if (!(authnXmlObj instanceof AuthnRequest)) { throw new IOException("SAMLRequest object is not AuthnRequest."); } } private void checkSamlpResponse(SOAPPart soapRequest) throws IOException { XMLObject responseXmlObj = null; try { Node node = soapRequest.getEnvelope().getBody().getFirstChild(); responseXmlObj = SamlProtocol.getXmlObjectFromNode(node); } catch (WSSecurityException | SOAPException | XMLStreamException ex) { throw new IOException("Unable to convert Response document to XMLObject."); } if (responseXmlObj == null) { throw new IOException("Response object is not Found."); } if (!(responseXmlObj instanceof org.opensaml.saml.saml2.core.Response)) { throw new IOException("SAMLRequest object is not org.opensaml.saml.saml2.core.Response."); } } private String getUsernameToken(String username, String password) { String updatedToken = usernameToken.replace("{{username}}", username); updatedToken = updatedToken.replace("{{password}}", password); return updatedToken; } static class HttpResponseWrapper { int statusCode; InputStream content; Set<Entry<String, Object>> headers; } }