org.xlcloud.console.saml2.Saml2ServiceProviderConsumerServlet.java Source code

Java tutorial

Introduction

Here is the source code for org.xlcloud.console.saml2.Saml2ServiceProviderConsumerServlet.java

Source

/*
 * Copyright 2012 AMG.lab, a Bull Group Company
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *    http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.xlcloud.console.saml2;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;

import javax.inject.Inject;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang.StringUtils;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.saml2.core.Attribute;
import org.opensaml.saml2.core.AttributeStatement;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.LogoutResponse;
import org.opensaml.saml2.core.Response;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.io.UnmarshallingException;
import org.opensaml.xml.parse.XMLParserException;
import org.opensaml.xml.util.Base64;
import org.opensaml.xml.util.XMLObjectHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.xlcloud.config.ConfigParam;
import org.xlcloud.console.context.AuthenticationFilter;
import org.xlcloud.console.context.IdentityContext;

import com.iplanet.sso.SSOException;
import com.iplanet.sso.SSOToken;
import com.iplanet.sso.SSOTokenManager;

/**
 * Servlet reads the SAMLResponses from IDP and processes them.
 * 
 * @author Jakub Wachowski, AMG.net
 */
public class Saml2ServiceProviderConsumerServlet extends HttpServlet {

    private static final long serialVersionUID = 3988387294526919009L;

    private static final Logger LOG = LoggerFactory.getLogger(Saml2ServiceProviderConsumerServlet.class);

    @Inject
    @ConfigParam
    private String logoutSuccessfulPageUrl;

    @Inject
    private IdentityContext context;

    /**
     * Handles a SAMLResponse delivered with HTTP GET. The message is handed to
     * {@link #process} flagged as deflated.
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        final boolean deflated = true;
        process(req, resp, deflated);
    }

    /**
     * Handles a SAMLResponse delivered with HTTP POST. The message is handed to
     * {@link #process} flagged as not deflated.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        final boolean deflated = false;
        process(req, resp, deflated);
    }

    /**
     * Reads the <code>SAMLResponse</code> request parameter, turns it into an
     * XML object and dispatches it for processing. A missing parameter and any
     * {@link Saml2ResponseProcessingException} raised on the way are answered
     * with HTTP 400.
     * 
     * @param req
     *            request carrying the SAMLResponse parameter
     * @param resp
     *            response used for the error or the redirect
     * @param deflated
     *            GET uses deflated SAMLResponse, POST does not
     * @throws IOException
     *             if writing to the response fails
     */
    private void process(HttpServletRequest req, HttpServletResponse resp, boolean deflated) throws IOException {
        final String encodedResponse = req.getParameter("SAMLResponse");
        if (encodedResponse == null) {
            // nothing to process without the parameter
            resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "SAMLResponse parameter is required");
            return;
        }

        try {
            final XMLObject parsedResponse = createXmlObject(encodedResponse, deflated);
            if (parsedResponse == null) {
                throw new Saml2ResponseProcessingException("Unable to parse the SAMLResponse");
            }
            processSamlResponse(parsedResponse, req, resp);
        } catch (Saml2ResponseProcessingException failure) {
            // every processing failure is reported to the client as a bad request
            resp.sendError(HttpServletResponse.SC_BAD_REQUEST, failure.getMessage());
        }
    }

    /**
     * Dispatches an unmarshalled SAML message by its type: a {@link Response}
     * is handled as the answer to an AuthnRequest, a {@link LogoutResponse}
     * redirects to the configured logout page, anything else is answered with
     * HTTP 400.
     * 
     * @param xmlobj
     *            the unmarshalled SAML message
     * @param req
     *            current request
     * @param resp
     *            current response
     * @throws IOException
     *             if writing to the response fails
     * @throws Saml2ResponseProcessingException
     *             if the logout page url is not configured or the
     *             authentication response is rejected
     */
    private void processSamlResponse(XMLObject xmlobj, HttpServletRequest req, HttpServletResponse resp)
            throws IOException, Saml2ResponseProcessingException {

        if (xmlobj instanceof Response) {
            // received response to AuthnRequest
            processAuthnResponse(req, resp, (Response) xmlobj);
            return;
        }

        if (!(xmlobj instanceof LogoutResponse)) {
            resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Unsupported SAMLResponse type");
            return;
        }

        // logout response: the only thing left to do is to show the logout page
        if (StringUtils.isEmpty(logoutSuccessfulPageUrl)) {
            String problem = "Missing configuration of successfull logout page ul - logoutSuccessfulPageUrl="
                    + logoutSuccessfulPageUrl;
            throw new Saml2ResponseProcessingException(problem);
        }
        resp.sendRedirect(logoutSuccessfulPageUrl);
    }

    /**
     * Processes the response to an AuthnRequest: looks up the request the
     * response refers to, checks the response destination, reads the
     * <code>userName</code>, <code>accountId</code>, <code>userId</code> and
     * <code>ssoToken</code> attributes of the single assertion, fills the
     * identity context, stores it in the HTTP session and redirects to the
     * initially requested url.
     * <p>
     * NOTE(review): this method does not verify the signature, status or
     * conditions of the response/assertion - confirm that this is enforced
     * elsewhere.
     * 
     * @param req
     *            current request, used to reach the HTTP session
     * @param resp
     *            current response, used for the final redirect
     * @param samlResponse
     *            the unmarshalled SAML response
     * @throws IOException
     *             if the redirect cannot be sent
     * @throws Saml2ResponseProcessingException
     *             if the response does not match an issued request or lacks
     *             any of the required data
     */
    private void processAuthnResponse(HttpServletRequest req, HttpServletResponse resp, Response samlResponse)
            throws IOException, Saml2ResponseProcessingException {
        String inResponseTo = samlResponse.getInResponseTo();
        if (inResponseTo == null) {
            // without a request id there is nothing to look up
            throw new Saml2ResponseProcessingException("The SAMLResponse is for an unknown SAMLRequest");
        }

        // check if such AuthnRequest has been really issued
        AuthnInfoHolder authnInfo = Saml2ServiceProvider.getInstance().getAuthnInfo(inResponseTo);
        if (authnInfo == null) {
            throw new Saml2ResponseProcessingException("The SAMLResponse is for an unknown SAMLRequest");
        }

        checkDestination(authnInfo.getAuthnRequest(), samlResponse);

        Assertion assertion = getAssertion(resp, samlResponse);

        List<AttributeStatement> attrStatements = assertion.getAttributeStatements();
        if (attrStatements == null || attrStatements.isEmpty()) {
            throw new Saml2ResponseProcessingException("The SAMLResponse does not contain attrbiutes");
        }

        Map<String, String> attributesMap = createSingleValuedAttributesMap(attrStatements);

        String userName = attributesMap.get("userName");
        // accountId is optional and stays null when absent or not numeric
        Long accountId = safelyParseLong(attributesMap.get("accountId"));
        Long userId = safelyParseLong(attributesMap.get("userId"));
        String token = attributesMap.get("ssoToken");

        if (StringUtils.isBlank(token)) {
            // reported like every other incomplete response so that process()
            // can answer it; logged because it points at the IDP setup
            String problem = "SSO Token has not been returned from the IDP - check the OpenAM's configuration";
            LOG.error(problem);
            throw new Saml2ResponseProcessingException(problem);
        }

        // userName and userId are required; checked before the token is created
        // so that an incomplete response is rejected first
        if (userId == null || userName == null || userName.isEmpty()) {
            throw new Saml2ResponseProcessingException("The SAMLResponse does not contain all required attributes");
        }

        SSOToken ssoToken = createSsoToken(token);

        String sessionId = getSessionId(assertion);
        String nameIdFormat = getNameIdFormat(assertion);

        context.setName(userName);
        context.setUserId(userId);
        context.setAccountId(accountId);
        context.setIdpSessionId(sessionId);
        context.setNameIdFormat(nameIdFormat);
        context.setSsoToken(ssoToken);
        LOG.debug("Identity Context Created {}", context);

        // NOTE(review): the existing session (and its id) is reused after login
        req.getSession().setAttribute(AuthenticationFilter.IDENTITY_CONTEXT_PARAM, context);

        // redirect the user to the initially requested url
        String url = authnInfo.getInitialRequestUrl();
        resp.sendRedirect(url);
    }

    /**
     * Asks the {@link SSOTokenManager} to create an {@link SSOToken} for the
     * given token string. Failures are logged and rethrown wrapped.
     * 
     * @param token
     *            token string received from the IDP
     * @return the created SSO token
     * @throws Saml2ResponseProcessingException
     *             wrapping the {@link SSOException} or
     *             {@link UnsupportedOperationException} raised by the manager
     */
    private SSOToken createSsoToken(String token) throws Saml2ResponseProcessingException {
        try {
            final SSOTokenManager tokenManager = SSOTokenManager.getInstance();
            return tokenManager.createSSOToken(token);
        } catch (SSOException ssoFailure) {
            LOG.error(ssoFailure.getMessage(), ssoFailure);
            throw new Saml2ResponseProcessingException(ssoFailure);
        } catch (UnsupportedOperationException unsupported) {
            LOG.error(unsupported.getMessage(), unsupported);
            throw new Saml2ResponseProcessingException(unsupported);
        }
    }

    /**
     * Compares the Destination of the response with the assertion consumer
     * service url of the request it answers. A response without Destination is
     * accepted.
     * 
     * @param authnRequest
     *            the request the response refers to
     * @param samlResponse
     *            the received response
     * @throws Saml2ResponseProcessingException
     *             if a Destination is present and differs from the url of the
     *             request
     */
    private void checkDestination(AuthnRequest authnRequest, Response samlResponse)
            throws Saml2ResponseProcessingException {

        final String consumerUrl = authnRequest.getAssertionConsumerServiceURL();
        final String destination = samlResponse.getDestination();

        // Destination element is optional, so only a present value is compared
        final boolean mismatch = destination != null && !destination.equals(consumerUrl);
        if (mismatch) {
            throw new Saml2ResponseProcessingException(
                    "Response destination is invalid: expected: [" + consumerUrl + "], actual: [" + destination + "]");
        }
    }

    /**
     * Flattens the SAML attributes of all statements into a map keyed by
     * attribute name. Only the first value of an attribute is kept, attributes
     * without values are left out, and a name seen again replaces the earlier
     * entry.
     * 
     * @param attrStatements
     *            attribute statements of the assertion
     * @return map of attribute name to the text content of its first value
     */
    private Map<String, String> createSingleValuedAttributesMap(List<AttributeStatement> attrStatements) {

        final Map<String, String> firstValues = new HashMap<String, String>();

        for (AttributeStatement statement : attrStatements) {
            final List<Attribute> attributes = statement.getAttributes();
            if (attributes == null) {
                continue;
            }
            for (Attribute attribute : attributes) {
                final String name = attribute.getName();
                final List<XMLObject> values = attribute.getAttributeValues();
                if (values == null || values.isEmpty()) {
                    continue;
                }
                // the value is read from the DOM element backing the XML object
                firstValues.put(name, values.get(0).getDOM().getTextContent());
            }
        }

        return firstValues;
    }

    /**
     * Returns the Format of the NameID found in the Subject of the assertion.
     * 
     * @param assertion
     *            the assertion of the authentication response
     * @return the NameID format; may be null if the NameID carries no Format
     * @throws Saml2ResponseProcessingException
     *             if the assertion has no Subject or the Subject has no NameID
     */
    private String getNameIdFormat(Assertion assertion) throws Saml2ResponseProcessingException {
        // both elements may be missing in a received assertion; report that as
        // a processing failure instead of failing with a NullPointerException
        if (assertion.getSubject() == null || assertion.getSubject().getNameID() == null) {
            throw new Saml2ResponseProcessingException("The SAMLResponse does not contain a Subject with a NameID");
        }
        return assertion.getSubject().getNameID().getFormat();
    }

    /**
     * Returns the only assertion of the response.
     * 
     * @param resp
     *            not used by this method
     * @param samlResponse
     *            the received response
     * @return the single assertion
     * @throws Saml2ResponseProcessingException
     *             if the response holds no assertion or more than one
     */
    private Assertion getAssertion(HttpServletResponse resp, Response samlResponse)
            throws Saml2ResponseProcessingException {
        final List<Assertion> assertions = samlResponse.getAssertions();
        if (assertions != null && assertions.size() == 1) {
            return assertions.get(0);
        }
        throw new Saml2ResponseProcessingException("The SAMLResponse does not contain exactly one Assertion");
    }

    /**
     * Returns the SessionIndex of the only AuthnStatement of the assertion.
     * 
     * @param assertion
     *            the assertion of the authentication response
     * @return the session index of the single AuthnStatement
     * @throws Saml2ResponseProcessingException
     *             if the assertion holds no AuthnStatement or more than one
     */
    private String getSessionId(Assertion assertion) throws Saml2ResponseProcessingException {

        final boolean exactlyOne = assertion.getAuthnStatements() != null
                && assertion.getAuthnStatements().size() == 1;
        if (!exactlyOne) {
            throw new Saml2ResponseProcessingException(
                    "The SAMLResponse does not contain exactly one AuthnStatement");
        }

        return assertion.getAuthnStatements().get(0).getSessionIndex();
    }

    /**
     * Parses the given text as a decimal long without ever throwing.
     * 
     * @param value
     *            text to parse, may be null
     * @return the parsed number, or null if the text is null or not a valid
     *         long
     */
    private Long safelyParseLong(String value) {
        Long parsed = null;
        try {
            parsed = Long.valueOf(value);
        } catch (NumberFormatException ignored) {
            // not a number (or null): the caller gets null
        }
        return parsed;
    }

    /**
     * Turns the value of the SAMLResponse parameter into an unmarshalled XML
     * object: Base64-decodes it, inflates it when requested, passes the bytes
     * to the message logger and unmarshalls them with the parser pool of the
     * service provider.
     * 
     * @param responseMessage
     *            Base64 encoded SAML message
     * @param deflated
     *            whether the decoded bytes have to be inflated first
     * @return the unmarshalled message
     * @throws Saml2ResponseProcessingException
     *             if the message cannot be decoded, inflated, parsed or
     *             unmarshalled
     */
    private XMLObject createXmlObject(String responseMessage, boolean deflated)
            throws Saml2ResponseProcessingException {
        byte[] responseBytes = Base64.decode(responseMessage);
        if (responseBytes == null) {
            // the decoder can hand back null instead of throwing for malformed
            // input; stop here rather than fail later with a NullPointerException
            throw new Saml2ResponseProcessingException("The SAMLResponse is not a valid Base64 value");
        }

        if (deflated) {
            try {
                responseBytes = inflate(responseBytes);
            } catch (IOException e) {
                LOG.info(e.getMessage(), e);
                throw new Saml2ResponseProcessingException(e);
            }
        }

        Saml2MessageLogger.log(responseBytes);

        InputStream inputStream = new ByteArrayInputStream(responseBytes);
        try {
            return XMLObjectHelper.unmarshallFromInputStream(Saml2ServiceProvider.getInstance().getParserPool(),
                    inputStream);
        } catch (XMLParserException e) {
            LOG.info(e.getMessage(), e);
            throw new Saml2ResponseProcessingException(e);
        } catch (UnmarshallingException e) {
            LOG.info(e.getMessage(), e);
            throw new Saml2ResponseProcessingException(e);
        }
    }

    /**
     * Inflates raw DEFLATE data (no zlib header, see the <code>nowrap</code>
     * flag of {@link Inflater}).
     * <p>
     * NOTE(review): the size of the inflated data is not limited.
     * 
     * @param responseBytes
     *            deflated bytes
     * @return the inflated bytes
     * @throws IOException
     *             if the data cannot be inflated
     */
    private byte[] inflate(byte[] responseBytes) throws IOException {
        Inflater inflater = new Inflater(true);
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            InflaterOutputStream ios = new InflaterOutputStream(out, inflater);

            ios.write(responseBytes);
            ios.close();
            return out.toByteArray();
        } finally {
            // the stream does not release an inflater it was handed, so it has
            // to be ended here, also when inflating fails
            inflater.end();
        }
    }
}