org.codice.ddf.security.idp.client.AssertionConsumerService.java Source code

Java tutorial

Introduction

Here is the source code for org.codice.ddf.security.idp.client.AssertionConsumerService.java:

Source

/**
 * Copyright (c) Codice Foundation
 * <p>
 * This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
 * General Public License as published by the Free Software Foundation, either version 3 of the
 * License, or any later version.
 * <p>
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
 * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details. A copy of the GNU Lesser General Public License
 * is distributed along with this program and can be found at
 * <http://www.gnu.org/licenses/lgpl.html>.
 */
package org.codice.ddf.security.idp.client;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Map;

import javax.servlet.Filter;
import javax.servlet.ServletException;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.ws.rs.FormParam;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.xml.stream.XMLStreamException;

import org.apache.commons.lang.StringUtils;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.wss4j.common.crypto.Crypto;
import org.apache.wss4j.common.crypto.CryptoType;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.saml.OpenSAMLUtil;
import org.apache.wss4j.common.util.DOM2Writer;
import org.codice.ddf.configuration.SystemBaseUrl;
import org.codice.ddf.security.common.HttpUtils;
import org.codice.ddf.security.common.jaxrs.RestSecurity;
import org.codice.ddf.security.filter.websso.WebSSOFilter;
import org.codice.ddf.security.handler.api.HandlerResult;
import org.codice.ddf.security.handler.saml.SAMLAssertionHandler;
import org.codice.ddf.security.policy.context.ContextPolicy;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;

import ddf.security.http.SessionFactory;
import ddf.security.samlp.SamlProtocol;
import ddf.security.samlp.SimpleSign;
import ddf.security.samlp.SystemCrypto;
import ddf.security.samlp.ValidationException;
import ddf.security.samlp.impl.RelayStates;

/**
 * JAX-RS resource that consumes SAML AuthN responses returned by the configured IdP.
 * <p>
 * Responses may arrive either as a base64-encoded {@code SAMLResponse} form parameter
 * ({@link #postSamlResponse}) or as a deflated, base64-encoded query parameter with an optional
 * detached signature ({@link #getSamlResponse}). A response that parses and validates is turned into
 * a login through the configured login {@link Filter}, and the caller is redirected (303) to the
 * location that the supplied relay state decodes to. This resource also serves this service
 * provider's SAML metadata ({@link #retrieveMetadata}).
 */
@Path("sso")
public class AssertionConsumerService {

    private static final Logger LOGGER = LoggerFactory.getLogger(AssertionConsumerService.class);

    private static final String SAML_RESPONSE = "SAMLResponse";

    private static final String RELAY_STATE = "RelayState";

    private static final String SIG_ALG = "SigAlg";

    private static final String SIGNATURE = "Signature";

    private static final String UNABLE_TO_LOGIN = "Unable to login with provided AuthN response assertion.";

    private static final String UNABLE_TO_PARSE = "Unable to parse AuthN response.";

    private final SimpleSign simpleSign;

    private final IdpMetadata idpMetadata;

    private final RelayStates<String> relayStates;

    /** Current request; injected by JAX-RS or supplied through {@link #setRequest}. */
    @Context
    private HttpServletRequest request;

    private Filter loginFilter;

    private SystemCrypto systemCrypto;

    private SessionFactory sessionFactory;

    static {
        OpenSAMLUtil.initSamlEngine();
    }

    /**
     * Creates the service.
     *
     * @param simpleSign   used to verify detached signatures and to validate AuthN responses
     * @param metadata     IdP metadata supplying the IdP signing certificate
     * @param crypto       local keystores used to publish this service provider's metadata
     * @param relayStates  maps relay state tokens back to the original redirect location
     */
    public AssertionConsumerService(SimpleSign simpleSign, IdpMetadata metadata, SystemCrypto crypto,
            RelayStates<String> relayStates) {
        this.simpleSign = simpleSign;
        idpMetadata = metadata;
        systemCrypto = crypto;
        this.relayStates = relayStates;
    }

    /**
     * Handles an AuthN response posted as a form.
     *
     * @param encodedSamlResponse base64-encoded SAML {@code Response} XML
     * @param relayState          relay state token to be decoded into the redirect location
     * @return a 303 redirect on success; otherwise a server error whose entity describes the failure
     */
    @POST
    @Produces(MediaType.APPLICATION_FORM_URLENCODED)
    public Response postSamlResponse(@FormParam(SAML_RESPONSE) String encodedSamlResponse,
            @FormParam(RELAY_STATE) String relayState) {
        if (StringUtils.isBlank(encodedSamlResponse)) {
            LOGGER.debug("Received AuthN response without a {} parameter.", SAML_RESPONSE);
            return Response.serverError().entity(UNABLE_TO_PARSE).build();
        }

        String decodedSamlResponse;
        try {
            decodedSamlResponse = decodeBase64(encodedSamlResponse);
        } catch (IllegalArgumentException e) {
            // Malformed base64 must not escape this endpoint as an unhandled exception.
            LOGGER.debug("Unable to base64 decode AuthN response.", e);
            return Response.serverError().entity(UNABLE_TO_PARSE).build();
        }

        return processSamlResponse(decodedSamlResponse, relayState);
    }

    /**
     * Handles an AuthN response delivered as query parameters.
     *
     * @param deflatedSamlResponse deflated, base64-encoded SAML {@code Response} XML
     * @param relayState           relay state token to be decoded into the redirect location
     * @param signatureAlgorithm   algorithm identifier of the detached signature
     * @param signature            detached signature; when {@code null} the response is accepted unsigned
     * @return a 303 redirect on success; otherwise a server error whose entity describes the failure
     */
    @GET
    public Response getSamlResponse(@QueryParam(SAML_RESPONSE) String deflatedSamlResponse,
            @QueryParam(RELAY_STATE) String relayState, @QueryParam(SIG_ALG) String signatureAlgorithm,
            @QueryParam(SIGNATURE) String signature) {

        if (!validateSignature(deflatedSamlResponse, relayState, signatureAlgorithm, signature)) {
            return Response.serverError().entity("Invalid AuthN response signature.").build();
        }

        // An unsigned request passes validateSignature without any payload check, so guard here.
        if (StringUtils.isBlank(deflatedSamlResponse)) {
            LOGGER.debug("Received AuthN response without a {} parameter.", SAML_RESPONSE);
            return Response.serverError().entity(UNABLE_TO_PARSE).build();
        }

        String inflatedSamlResponse;
        try {
            inflatedSamlResponse = RestSecurity.inflateBase64(deflatedSamlResponse);
        } catch (IOException e) {
            String msg = "Unable to decode and inflate AuthN response.";
            LOGGER.warn(msg, e);
            return Response.serverError().entity(msg).build();
        }

        return processSamlResponse(inflatedSamlResponse, relayState);
    }

    /**
     * Verifies the detached signature of a query-parameter AuthN response against the IdP signing
     * certificate.
     * <p>
     * A {@code null} signature is accepted (with a warning). A present signature requires non-blank
     * response, relay state and algorithm values.
     *
     * @return {@code true} when the response is unsigned or the signature verifies
     */
    private boolean validateSignature(String deflatedSamlResponse, String relayState, String signatureAlgorithm,
            String signature) {
        if (signature == null) {
            LOGGER.warn("Received unsigned AuthN response.  Could not verify IDP identity or response integrity.");
            return true;
        }

        if (StringUtils.isBlank(deflatedSamlResponse) || StringUtils.isBlank(relayState)
                || StringUtils.isBlank(signatureAlgorithm)) {
            LOGGER.debug("Signed AuthN response is missing the {}, {} or {} parameter.", SAML_RESPONSE,
                    RELAY_STATE, SIG_ALG);
            return false;
        }

        try {
            // NOTE(review): the signed octet string is rebuilt by re-encoding the decoded parameters;
            // this assumes the IdP URL-encoded them the same way URLEncoder does — confirm.
            String encoding = StandardCharsets.UTF_8.name();
            String signedMessage = String.format("%s=%s&%s=%s&%s=%s", SAML_RESPONSE,
                    URLEncoder.encode(deflatedSamlResponse, encoding), RELAY_STATE,
                    URLEncoder.encode(relayState, encoding), SIG_ALG,
                    URLEncoder.encode(signatureAlgorithm, encoding));
            return simpleSign.validateSignature(signedMessage, signature, idpMetadata.getSigningCertificate());
        } catch (SimpleSign.SignatureException | UnsupportedEncodingException e) {
            LOGGER.debug("Failed to validate AuthN response signature.", e);
            return false;
        }
    }

    /**
     * Parses, validates and logs in with a decoded AuthN response, then redirects to the location
     * stored under {@code relayState}.
     *
     * @param authnResponse SAML {@code Response} XML
     * @param relayState    relay state token to be decoded into the redirect location
     * @return a 303 redirect on success; otherwise a server error whose entity describes the failure
     */
    public Response processSamlResponse(String authnResponse, String relayState) {
        LOGGER.trace("AuthN response: {}", authnResponse);

        org.opensaml.saml.saml2.core.Response samlResponse = extractSamlResponse(authnResponse);
        if (samlResponse == null) {
            return Response.serverError().entity(UNABLE_TO_PARSE).build();
        }

        if (!validateResponse(samlResponse)) {
            return Response.serverError().entity("AuthN response failed validation.").build();
        }

        String redirectLocation = relayStates.decode(relayState);
        if (StringUtils.isBlank(redirectLocation)) {
            return Response.serverError().entity("AuthN response returned unknown or expired relay state.").build();
        }

        if (!login(samlResponse)) {
            return Response.serverError().entity(UNABLE_TO_LOGIN).build();
        }

        URI relayUri;
        try {
            relayUri = new URI(redirectLocation);
        } catch (URISyntaxException e) {
            LOGGER.warn("Unable to parse relay state.", e);
            return Response.serverError().entity("Unable to redirect back to original location.").build();
        }

        LOGGER.trace("Successfully logged in.  Redirecting to {}", relayUri);
        return Response.seeOther(relayUri).build();
    }

    /**
     * Runs {@link AuthnResponseValidator} over the response.
     *
     * @return {@code true} if validation succeeds; failures are logged and reported as {@code false}
     */
    private boolean validateResponse(org.opensaml.saml.saml2.core.Response samlResponse) {
        try {
            new AuthnResponseValidator(simpleSign).validate(samlResponse);
            return true;
        } catch (ValidationException e) {
            // Log the issuer's value; the Issuer XMLObject itself has no useful toString().
            String issuer = samlResponse.getIssuer() == null ? null : samlResponse.getIssuer().getValue();
            LOGGER.warn("Invalid AuthN response received from {}", issuer, e);
            return false;
        }
    }

    public void setSessionFactory(SessionFactory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Logs the current request in using the first assertion of the response.
     * <p>
     * The assertion is deflated and base64-encoded into a SAML auth header on a wrapped request, with
     * the active realm forced to {@code "idp"}. That request is normalized by
     * {@link SAMLAssertionHandler}, and the result is stored on the request and passed through the
     * login filter.
     *
     * @return {@code true} if the login filter ran without error
     */
    private boolean login(org.opensaml.saml.saml2.core.Response samlResponse) {
        if (!request.isSecure()) {
            LOGGER.debug("Refusing to process AuthN response received over an insecure connection.");
            return false;
        }

        // Check before touching the session so an unusable response does not destroy it.
        if (samlResponse.getAssertions().isEmpty()) {
            LOGGER.debug("AuthN response does not contain any assertions.");
            return false;
        }

        Map<String, Cookie> cookieMap = HttpUtils.getCookieMap(request);
        if (cookieMap.containsKey("JSESSIONID")) {
            // Drop any pre-existing session before logging in (presumably to avoid reusing an
            // earlier session for the new login — confirm intent).
            sessionFactory.getOrCreateSession(request).invalidate();
        }
        String assertionValue = DOM2Writer.nodeToString(samlResponse.getAssertions().get(0).getDOM());

        String encodedAssertion;
        try {
            encodedAssertion = RestSecurity.deflateAndBase64Encode(assertionValue);
        } catch (IOException e) {
            LOGGER.warn("Unable to deflate and encode assertion.", e);
            return false;
        }

        final String authHeader = RestSecurity.SAML_HEADER_PREFIX + encodedAssertion;

        HttpServletRequestWrapper wrappedRequest = new HttpServletRequestWrapper(request) {
            @Override
            public String getHeader(String name) {
                if (RestSecurity.AUTH_HEADER.equals(name)) {
                    return authHeader;
                }
                return super.getHeader(name);
            }

            @Override
            public Object getAttribute(String name) {
                if (ContextPolicy.ACTIVE_REALM.equals(name)) {
                    return "idp";
                }
                return super.getAttribute(name);
            }
        };

        SAMLAssertionHandler samlAssertionHandler = new SAMLAssertionHandler();

        LOGGER.trace("Processing SAML assertion with SAML Handler.");
        HandlerResult samlResult = samlAssertionHandler.getNormalizedToken(wrappedRequest, null, null, false);

        if (samlResult.getStatus() != HandlerResult.Status.COMPLETED) {
            LOGGER.debug("Failed to handle SAML assertion.");
            return false;
        }

        request.setAttribute(WebSSOFilter.DDF_AUTHENTICATION_TOKEN, samlResult);
        request.removeAttribute(ContextPolicy.NO_AUTH_POLICY);

        try {
            LOGGER.trace("Trying to login with provided SAML assertion.");
            // No response object and a no-op chain: only the filter's login side effects are wanted.
            loginFilter.doFilter(wrappedRequest, null, (servletRequest, servletResponse) -> {
            });
        } catch (IOException | ServletException e) {
            LOGGER.debug("Failed to apply login filter to SAML assertion", e);
            return false;
        }

        return true;
    }

    /**
     * Returns this service provider's SAML metadata XML.
     * <p>
     * The metadata is built from the system signature and encryption certificates and from SAML
     * endpoints under {@code https://<host>:<port><rootContext>/saml}.
     *
     * @throws WSSecurityException         if a certificate cannot be found in its keystore
     * @throws CertificateEncodingException if a certificate cannot be encoded
     */
    @GET
    @Path("/metadata")
    @Produces("application/xml")
    public Response retrieveMetadata() throws WSSecurityException, CertificateEncodingException {
        X509Certificate issuerCert = findCertificate(systemCrypto.getSignatureAlias(),
                systemCrypto.getSignatureCrypto());
        X509Certificate encryptionCert = findCertificate(systemCrypto.getEncryptionAlias(),
                systemCrypto.getEncryptionCrypto());

        String samlBaseUrl = String.format("https://%s:%s%s/saml", SystemBaseUrl.getHost(),
                SystemBaseUrl.getPort(), SystemBaseUrl.getRootContext());

        String entityId = samlBaseUrl;
        String logoutLocation = samlBaseUrl + "/logout";
        String assertionConsumerServiceLocation = samlBaseUrl + "/sso";

        // The ACS location is passed twice (presumably once per supported binding — confirm against
        // SamlProtocol.createSpMetadata).
        EntityDescriptor entityDescriptor = SamlProtocol.createSpMetadata(entityId,
                Base64.getEncoder().encodeToString(issuerCert.getEncoded()),
                Base64.getEncoder().encodeToString(encryptionCert.getEncoded()), logoutLocation,
                assertionConsumerServiceLocation, assertionConsumerServiceLocation);

        Document doc = DOMUtils.createDocument();
        doc.appendChild(doc.createElement("root"));
        return Response.ok(DOM2Writer.nodeToString(OpenSAMLUtil.toDom(entityDescriptor, doc, false))).build();
    }

    /**
     * Looks up the first certificate stored under {@code alias}.
     *
     * @throws WSSecurityException if the keystore has no certificate for the alias
     */
    private X509Certificate findCertificate(String alias, Crypto crypto) throws WSSecurityException {
        CryptoType cryptoType = new CryptoType(CryptoType.TYPE.ALIAS);
        cryptoType.setAlias(alias);
        X509Certificate[] certs = crypto.getX509Certificates(cryptoType);
        if (certs == null || certs.length == 0) {
            throw new WSSecurityException(WSSecurityException.ErrorCode.SECURITY_ERROR,
                    "Unable to retrieve certificate");
        }
        return certs[0];
    }

    /**
     * Parses XML into an OpenSAML {@code Response}.
     *
     * @return the parsed response, or {@code null} if the input is blank, unparseable, or not a
     *         SAML 2.0 {@code Response}
     */
    private org.opensaml.saml.saml2.core.Response extractSamlResponse(String samlResponse) {
        if (StringUtils.isBlank(samlResponse)) {
            return null;
        }

        org.opensaml.saml.saml2.core.Response response = null;
        try {
            Document responseDoc = StaxUtils
                    .read(new ByteArrayInputStream(samlResponse.getBytes(StandardCharsets.UTF_8)));
            XMLObject responseXmlObject = OpenSAMLUtil.fromDom(responseDoc.getDocumentElement());

            if (responseXmlObject instanceof org.opensaml.saml.saml2.core.Response) {
                response = (org.opensaml.saml.saml2.core.Response) responseXmlObject;
            }
        } catch (XMLStreamException | WSSecurityException e) {
            LOGGER.debug("Failed to convert AuthN response string to object.", e);
        }

        return response;
    }

    /**
     * Decodes MIME-style base64 (line breaks tolerated) into a UTF-8 string.
     *
     * @throws IllegalArgumentException if {@code encoded} is not valid base64
     */
    private String decodeBase64(String encoded) {
        return new String(Base64.getMimeDecoder().decode(encoded.getBytes(StandardCharsets.UTF_8)),
                StandardCharsets.UTF_8);
    }

    public Filter getLoginFilter() {
        return loginFilter;
    }

    public void setLoginFilter(Filter loginFilter) {
        this.loginFilter = loginFilter;
    }

    public void setRequest(HttpServletRequest request) {
        this.request = request;
    }
}