org.surfnet.oaaas.conext.SAMLAuthenticator.java Source code

Java tutorial

Introduction

Here is the source code for org.surfnet.oaaas.conext.SAMLAuthenticator.java

Source

/*
 * Copyright 2012 SURFnet bv, The Netherlands
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.surfnet.oaaas.conext;

import nl.surfnet.coin.api.client.OpenConextOAuthClient;
import nl.surfnet.coin.api.client.domain.Group20;
import nl.surfnet.spring.security.opensaml.AuthnRequestGenerator;
import nl.surfnet.spring.security.opensaml.Provisioner;
import nl.surfnet.spring.security.opensaml.SAMLMessageHandler;
import nl.surfnet.spring.security.opensaml.ServiceProviderAuthenticationException;
import nl.surfnet.spring.security.opensaml.util.IDService;
import nl.surfnet.spring.security.opensaml.util.TimeService;
import nl.surfnet.spring.security.opensaml.xml.EndpointGenerator;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.lang.StringUtils;
import org.opensaml.common.binding.SAMLMessageContext;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.RequesterID;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.core.Scoping;
import org.opensaml.saml2.core.impl.RequesterIDBuilder;
import org.opensaml.saml2.core.impl.ScopingBuilder;
import org.opensaml.saml2.metadata.Endpoint;
import org.opensaml.saml2.metadata.SingleSignOnService;
import org.opensaml.ws.message.decoder.MessageDecodingException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.xml.security.CriteriaSet;
import org.opensaml.xml.security.credential.Credential;
import org.opensaml.xml.security.credential.UsageType;
import org.opensaml.xml.security.criteria.EntityIDCriteria;
import org.opensaml.xml.security.criteria.UsageCriteria;
import org.opensaml.xml.validation.ValidationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.support.PropertiesLoaderUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.surfnet.oaaas.auth.AbstractAuthenticator;
import org.surfnet.oaaas.auth.principal.AuthenticatedPrincipal;
import org.surfnet.oaaas.model.AuthorizationRequest;
import org.surfnet.oaaas.model.Client;
import org.surfnet.oaaas.repository.AuthorizationRequestRepository;

import javax.inject.Inject;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Properties;

/**
 * Authenticator that logs users in through a SAML identity provider and, when the
 * {@code api-enrich-principal} property is {@code true}, enriches the resulting principal with
 * the groups returned by an {@link OpenConextOAuthClient}.
 *
 * <p>{@link #authenticate} distinguishes three kinds of requests:
 * <ol>
 * <li>a request carrying a {@code SAMLResponse} parameter: the response is decoded, validated and
 * consumed into a principal;</li>
 * <li>a request carrying the OAuth callback flag parameter: the principal parked in the HTTP
 * session is enriched with groups and the filter chain continues;</li>
 * <li>anything else: a signed AuthnRequest is sent to the identity provider.</li>
 * </ol>
 */
@Component
public class SAMLAuthenticator extends AbstractAuthenticator {

    private static final Logger LOG = LoggerFactory.getLogger(SAMLAuthenticator.class);

    /** Session attribute under which the SAML relay state is parked during the OAuth round trip. */
    private static final String RELAY_STATE_FROM_SAML = "RELAY_STATE_FROM_SAML";
    /** Session attribute under which the SAML principal is parked during the OAuth round trip. */
    private static final String PRINCIPAL_FROM_SAML = "PRINCIPAL_FROM_SAML";
    /** Client attribute key holding the SP entity name used for scoping the AuthnRequest. */
    private static final String CLIENT_SAML_ENTITY_NAME = "CLIENT_SAML_ENTITY_NAME";

    private TimeService timeService = new TimeService();
    private IDService idService = new IDService();
    private ScopingBuilder scopingBuilder = new ScopingBuilder();
    private RequesterIDBuilder requesterIDBuilder = new RequesterIDBuilder();

    private OpenSAMLContext openSAMLContext;
    /** Only created in {@link #init} when principal enrichment is enabled; {@code null} otherwise. */
    private OpenConextOAuthClient apiClient;
    private String callbackFlagParameter = "apiOauthCallback";
    private boolean enrichPricipal;
    private String adminGroup;

    @Inject
    private AuthorizationRequestRepository authorizationRequestRepository;

    private final Properties properties;

    // Instance initializer: the configuration must be available before init(FilterConfig) runs.
    {
        try {
            properties = PropertiesLoaderUtils.loadAllProperties("surfconext.authn.properties");
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Initializes the parent filter, builds the {@link OpenSAMLContext} and, when
     * {@code api-enrich-principal} is {@code true}, the API client and admin group name.
     *
     * @param filterConfig the filter configuration handed to the parent class
     * @throws ServletException wrapping any exception raised during initialization
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        try {
            super.init(filterConfig);
            openSAMLContext = createOpenSAMLContext(properties);
            // parseBoolean yields the primitive directly; a missing property reads as false
            enrichPricipal = Boolean.parseBoolean(properties.getProperty("api-enrich-principal"));
            if (enrichPricipal) {
                apiClient = createOpenConextOAuthClient(properties);
                adminGroup = properties.getProperty("admin.client.apis.teamname");
            }
        } catch (Exception e) {
            throw new ServletException(e);
        }
    }

    /**
     * Instantiates the API client class named by the {@code openConextApiClient} property and
     * configures it from the {@code api-*} properties.
     *
     * @param properties the configuration to read the class name and client settings from
     * @return the configured client
     * @throws IllegalStateException if the {@code openConextApiClient} property is missing or blank
     * @throws ClassNotFoundException if the configured class cannot be loaded
     * @throws IllegalAccessException if the class or its no-arg constructor is not accessible
     * @throws InstantiationException if the class cannot be instantiated
     * @throws InvocationTargetException if setting one of the bean properties fails
     */
    protected OpenConextOAuthClient createOpenConextOAuthClient(Properties properties)
            throws ClassNotFoundException, IllegalAccessException, InstantiationException,
            InvocationTargetException {
        String clientClassName = properties.getProperty("openConextApiClient");
        if (StringUtils.isBlank(clientClassName)) {
            // Fail with a readable message instead of an NPE from loadClass(null)
            throw new IllegalStateException(
                    "Property 'openConextApiClient' is required to create the OpenConext API client");
        }
        OpenConextOAuthClient client = (OpenConextOAuthClient) getClass().getClassLoader()
                .loadClass(clientClassName).newInstance();
        BeanUtils.setProperty(client, "callbackUrl", properties.getProperty("api-callbackuri"));
        BeanUtils.setProperty(client, "consumerSecret", properties.getProperty("api-consumersecret"));
        BeanUtils.setProperty(client, "consumerKey", properties.getProperty("api-consumerkey"));
        BeanUtils.setProperty(client, "endpointBaseUrl", properties.getProperty("api-baseurl"));
        return client;
    }

    /**
     * Default Context factory method.
     *
     * @param properties the configuration handed to the context
     * @return a new context using the provisioner from {@link #createProvisioner()}
     */
    protected OpenSAMLContext createOpenSAMLContext(Properties properties) {
        return new OpenSAMLContext(properties, createProvisioner());
    }

    /**
     * Default Provisioner factory method.
     *
     * @return a {@link SAMLProvisioner} configured with the {@code samlUuidAttribute} property
     */
    protected Provisioner createProvisioner() {
        SAMLProvisioner samlProvisioner = new SAMLProvisioner();
        samlProvisioner.setUuidAttribute(properties.getProperty("samlUuidAttribute"));
        return samlProvisioner;
    }

    /**
     * @param request the incoming request
     * @return {@code true} when the request carries a SAML response or is an OAuth callback this
     *         instance can handle
     */
    @Override
    public boolean canCommence(HttpServletRequest request) {
        return isSAMLResponse(request) || isOAuthCallback(request);
    }

    /**
     * Handles a SAML response, an OAuth callback, or starts a new SAML login, depending on the
     * request parameters (see the class description).
     *
     * @param request the incoming request
     * @param response the response, used for redirects and for sending the AuthnRequest
     * @param chain the chain to continue once a principal is established
     * @param authStateValue the auth state, used as relay state when starting a SAML login
     * @param returnUri not read by this implementation
     * @throws IOException if a redirect or the filter chain fails with an I/O error
     * @throws ServletException if the filter chain fails
     * @throws ServiceProviderAuthenticationException if the SAML response is unusable or the
     *         session no longer holds a usable principal during the OAuth callback
     */
    @Override
    public void authenticate(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
            String authStateValue, String returnUri) throws IOException, ServletException {
        LOG.debug("Hitting SAML Authenticator filter");
        if (isSAMLResponse(request)) {
            Response samlResponse = extractSamlResponse(request);
            SAMLAuthenticatedPrincipal principal = (SAMLAuthenticatedPrincipal) openSAMLContext.assertionConsumer()
                    .consume(samlResponse);
            if (enrichPricipal) {
                // Park the principal and relay state in the session; they are picked up again
                // when the browser returns on the OAuth callback.
                request.getSession().setAttribute(PRINCIPAL_FROM_SAML, principal);
                request.getSession().setAttribute(RELAY_STATE_FROM_SAML, getSAMLRelayState(request));
                response.sendRedirect(apiClient.getAuthorizationUrl());
            } else {
                proceedWithChain(request, response, chain, principal, getSAMLRelayState(request));
            }
        } else if (isOAuthCallback(request)) {
            SAMLAuthenticatedPrincipal principal = (SAMLAuthenticatedPrincipal) request.getSession()
                    .getAttribute(PRINCIPAL_FROM_SAML);
            String authState = (String) request.getSession().getAttribute(RELAY_STATE_FROM_SAML);
            if (principal == null) {
                throw new ServiceProviderAuthenticationException("No principal anymore in the session");
            }
            String userId = principal.getName();
            if (StringUtils.isEmpty(userId)) {
                throw new ServiceProviderAuthenticationException("No userId in SAML assertion!");
            }
            apiClient.oauthCallback(request, userId);
            addGroups(principal, apiClient.getGroups20(userId, userId));
            proceedWithChain(request, response, chain, principal, authState);
        } else {
            sendAuthnRequest(response, authStateValue, getReturnUri(request));
        }
    }

    /**
     * Adds every group id to the principal and marks the principal as admin when one of the ids
     * equals the configured admin group (case-insensitive).
     */
    private void addGroups(SAMLAuthenticatedPrincipal principal, List<Group20> groups) {
        if (CollectionUtils.isEmpty(groups)) {
            return;
        }
        // The admin group does not change while iterating, so check it once up front
        final boolean adminGroupConfigured = StringUtils.isNotBlank(adminGroup);
        for (Group20 group : groups) {
            principal.addGroup(group.getId());
            if (adminGroupConfigured && adminGroup.equalsIgnoreCase(group.getId())) {
                principal.setAdminPrincipal(true);
            }
        }
    }

    /** Publishes the principal and auth state on the request and continues the filter chain. */
    private void proceedWithChain(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
            AuthenticatedPrincipal principal, String authStateValue) throws IOException, ServletException {
        super.setPrincipal(request, principal);
        super.setAuthStateValue(request, authStateValue);
        chain.doFilter(request, response);
    }

    /**
     * A request only counts as an OAuth callback when an API client exists to process it. The
     * client is created only when enrichment is enabled, so without this guard a stray callback
     * parameter would end in a NullPointerException.
     */
    private boolean isOAuthCallback(HttpServletRequest request) {
        return apiClient != null && request.getParameter(callbackFlagParameter) != null;
    }

    /**
     * @param request the incoming request
     * @return the value of the {@code RelayState} request parameter, possibly {@code null}
     */
    protected String getSAMLRelayState(HttpServletRequest request) {
        return request.getParameter("RelayState");
    }

    /**
     * @param request the incoming request
     * @return {@code true} when the request has a {@code SAMLResponse} parameter
     */
    protected boolean isSAMLResponse(HttpServletRequest request) {
        return request.getParameter("SAMLResponse") != null;
    }

    /**
     * Decodes the SAML message from the request, checks that it is a {@link Response} and runs it
     * through the validator suite.
     *
     * @throws ServiceProviderAuthenticationException if the message cannot be decoded, is not a
     *         Response, or fails validation
     */
    private Response extractSamlResponse(HttpServletRequest request) {
        final SAMLMessageHandler samlMessageHandler = openSAMLContext.samlMessageHandler();

        SAMLMessageContext messageContext;
        try {
            messageContext = samlMessageHandler.extractSAMLMessageContext(request);
        } catch (MessageDecodingException me) {
            throw new ServiceProviderAuthenticationException("Could not decode SAML Response", me);
        } catch (org.opensaml.xml.security.SecurityException se) {
            throw new ServiceProviderAuthenticationException("Could not decode SAML Response", se);
        }

        LOG.debug("Message received from issuer: {}", messageContext.getInboundMessageIssuer());

        final Object inboundMessage = messageContext.getInboundSAMLMessage();
        if (!(inboundMessage instanceof Response)) {
            throw new ServiceProviderAuthenticationException("SAML Message was not a Response.");
        }
        final Response samlResponse = (Response) inboundMessage;

        try {
            openSAMLContext.validatorSuite().validate(samlResponse);
            return samlResponse;
        } catch (ValidationException ve) {
            LOG.warn("Response Message failed Validation", ve);
            // Same exception type as the other rejection paths of this method
            throw new ServiceProviderAuthenticationException("Invalid SAML Response Message", ve);
        }
    }

    /**
     * Builds an AuthnRequest for the configured identity provider, scopes it with the client's
     * SP entity name when one is configured, signs it and sends it with {@code authState} as the
     * relay state.
     *
     * @param response the response the SAML message is written to
     * @param authState the auth state; identifies the client and doubles as relay state
     * @param returnUri currently not used by this method
     * @throws IOException if sending the error status fails
     */
    private void sendAuthnRequest(HttpServletResponse response, String authState, String returnUri)
            throws IOException {
        AuthnRequestGenerator authnRequestGenerator = new AuthnRequestGenerator(openSAMLContext.entityId(),
                timeService, idService);
        EndpointGenerator endpointGenerator = new EndpointGenerator();

        final String target = openSAMLContext.getIdpUrl();

        Endpoint endpoint = endpointGenerator.generateEndpoint(SingleSignOnService.DEFAULT_ELEMENT_NAME, target,
                openSAMLContext.assertionConsumerUri());

        AuthnRequest authnRequest = authnRequestGenerator.generateAuthnRequest(target,
                openSAMLContext.assertionConsumerUri());

        Client client = getClientByRequest(authState);
        String spEntityIdBy = client.getAttributes().get(CLIENT_SAML_ENTITY_NAME);

        if (StringUtils.isNotEmpty(spEntityIdBy)) {
            Scoping scoping = scopingBuilder.buildObject();
            scoping.getRequesterIDs().add(createRequesterID(spEntityIdBy));
            authnRequest.setScoping(scoping);
        } else {
            LOG.warn(
                    "For Client {} there is no key CLIENT_SAML_ENTITY_NAME configured to identify the SP entity name. NO SCOPING IS APPLIED",
                    client.getClientId());
        }

        CriteriaSet criteriaSet = new CriteriaSet();
        criteriaSet.add(new EntityIDCriteria(openSAMLContext.entityId()));
        criteriaSet.add(new UsageCriteria(UsageType.SIGNING));
        try {
            Credential signingCredential = openSAMLContext.keyStoreCredentialResolver().resolveSingle(criteriaSet);
            LOG.debug("Sending authnRequest to {}", target);
            openSAMLContext.samlMessageHandler().sendSAMLMessage(authnRequest, endpoint, response, authState,
                    signingCredential);
        } catch (MessageEncodingException mee) {
            LOG.error("Could not send authnRequest to Identity Provider.", mee);
            response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        } catch (org.opensaml.xml.security.SecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /** Wraps the given entity id in a SAML {@link RequesterID} element. */
    private RequesterID createRequesterID(String id) {
        RequesterID requesterID = requesterIDBuilder.buildObject();
        requesterID.setRequesterID(id);
        return requesterID;
    }

    /**
     * Get the Client that belongs to the authorization request stored for the given auth state.
     *
     * @param authState the auth state the authorization request is looked up by
     * @return the client of the matching authorization request
     * @throws ServiceProviderAuthenticationException if no authorization request is found
     */
    protected Client getClientByRequest(String authState) {
        AuthorizationRequest authorizationRequest = authorizationRequestRepository.findByAuthState(authState);
        if (authorizationRequest == null) {
            // Presumably an unknown or no longer stored state; report it explicitly rather than
            // failing with a NullPointerException on the dereference below.
            throw new ServiceProviderAuthenticationException(
                    "No authorization request found for the given auth state");
        }
        return authorizationRequest.getClient();
    }

}