com.google.enterprise.adaptor.SamlIdentityProvider.java Source code

Java tutorial

Introduction

Here is the source code for com.google.enterprise.adaptor.SamlIdentityProvider.java

Source

// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.enterprise.adaptor;

import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeAssertion;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeAttribute;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeAttributeStatement;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeAttributeValue;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeAudienceRestriction;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeAuthnStatement;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeConditions;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeResponse;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeStatus;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeStatusCode;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeStatusMessage;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeSubject;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeSubjectConfirmation;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeSubjectConfirmationData;
import static com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil.makeSuccessfulStatus;

import com.google.enterprise.adaptor.secmgr.saml.OpenSamlUtil;

import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;

import org.apache.velocity.app.VelocityEngine;
import org.apache.velocity.runtime.log.JdkLogChute;
import org.apache.velocity.runtime.resource.loader.ClasspathResourceLoader;
import org.joda.time.DateTime;
import org.opensaml.common.binding.SAMLMessageContext;
import org.opensaml.common.xml.SAMLConstants;
import org.opensaml.saml2.binding.AuthnResponseEndpointSelector;
import org.opensaml.saml2.binding.decoding.HTTPRedirectDeflateDecoder;
import org.opensaml.saml2.binding.encoding.HTTPPostEncoder;
import org.opensaml.saml2.core.Attribute;
import org.opensaml.saml2.core.AuthnContext;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.NameID;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.core.StatusCode;
import org.opensaml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml2.metadata.Endpoint;
import org.opensaml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml2.metadata.IDPSSODescriptor;
import org.opensaml.saml2.metadata.RoleDescriptor;
import org.opensaml.saml2.metadata.SPSSODescriptor;
import org.opensaml.ws.message.decoder.MessageDecodingException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.xml.security.SecurityException;
import org.opensaml.xml.security.SecurityHelper;
import org.opensaml.xml.security.credential.Credential;

import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.KeyPair;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Receives SAML authn requests and answers them.
 *
 * <p>This class plays the Identity Provider (IdP) role in SAML: when a Service
 * Provider (SP) asks for a user to be authenticated, the IdP performs the
 * authentication and reports the outcome back to the SP.
 */
class SamlIdentityProvider {
    private static final Logger log = Logger.getLogger(SamlIdentityProvider.class.getName());
    /** Template engine handed to the POST-binding encoder; built once at class load. */
    private static final VelocityEngine velocityEngine = newVelocityEngine();

    /** Authority asked to authenticate the user behind each decoded request. */
    private final AuthnAuthority adaptor;
    /** Credentials to use to sign messages; {@code null} when no key pair was supplied. */
    private final Credential cred;
    /** Source of the local (IdP) and peer (SP) entity descriptors. */
    private final SamlMetadata metadata;
    private final SsoHandler ssoHandler = new SsoHandler();
    /** Offset, in milliseconds, added to the current time to obtain the expiration time. */
    private final int expirationMillis;

    /**
     * Creates an identity provider.
     *
     * @param adaptor authority used to authenticate users; must not be {@code null}
     * @param metadata SAML metadata for the local and peer entities; must not be {@code null}
     * @param key key pair turned into the signing credential, or {@code null} for no credential
     * @param expirationMilliseconds offset from "now" used as the expiration time; must be
     *     positive
     * @throws NullPointerException if {@code adaptor} or {@code metadata} is {@code null}
     * @throws IllegalArgumentException if {@code expirationMilliseconds} is not positive
     */
    public SamlIdentityProvider(AuthnAuthority adaptor, SamlMetadata metadata, KeyPair key,
            int expirationMilliseconds) {
        if (adaptor == null) {
            throw new NullPointerException();
        }
        if (metadata == null) {
            throw new NullPointerException();
        }
        this.adaptor = adaptor;
        this.metadata = metadata;
        // The credential is derived before the expiration is validated, as it always has been.
        if (key == null) {
            this.cred = null;
        } else {
            this.cred = SecurityHelper.getSimpleCredential(key.getPublic(), key.getPrivate());
        }
        if (expirationMilliseconds <= 0) {
            throw new IllegalArgumentException("expiration needs to be positive");
        }
        this.expirationMillis = expirationMilliseconds;
    }

    /**
     * Builds the Velocity engine used for encoding: templates are loaded through
     * {@link ClasspathResourceLoader} and logging goes through {@link JdkLogChute}.
     *
     * @throws RuntimeException wrapping any exception raised while initializing the engine
     */
    private static VelocityEngine newVelocityEngine() {
        VelocityEngine engine = new VelocityEngine();
        engine.addProperty("resource.loader", "classloader");
        engine.addProperty("classloader.resource.loader.class", ClasspathResourceLoader.class.getName());
        engine.addProperty("runtime.log.logsystem.class", JdkLogChute.class.getName());
        try {
            engine.init();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return engine;
    }

    /**
     * Builds the SAML response for the request held in {@code context}, encodes it onto the
     * exchange and closes the exchange.
     *
     * @param ex exchange the response is written to
     * @param context message context carrying the decoded request and the selected peer endpoint
     * @param identity the authenticated identity, or {@code null} if authentication failed
     * @throws IllegalStateException if the peer endpoint's binding is not the SAML 2 POST binding
     * @throws IOException if the response cannot be encoded or written
     */
    public void respond(HttpExchange ex, SAMLMessageContext<AuthnRequest, Response, NameID> context,
            AuthnIdentity identity) throws IOException {
        context.setOutboundSAMLMessage(createResponse(context, identity));
        context.setOutboundMessageTransport(new HttpExchangeOutTransportAdapter(ex));

        // Only the POST binding can be encoded here.
        String binding = context.getPeerEntityEndpoint().getBinding();
        boolean postBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(binding);
        if (!postBinding) {
            throw new IllegalStateException("Unknown SAML binding: " + binding);
        }
        try {
            HTTPPostEncoder encoder = new HTTPPostEncoder(velocityEngine, "/templates/saml2-post-binding.vm");
            encoder.encode(context);
        } catch (MessageEncodingException e) {
            throw new IOException("Failed to encode SAML response", e);
        }
        ex.getResponseBody().flush();
        ex.getResponseBody().close();
        ex.close();
    }

    /**
     * Creates the response message for the request in {@code context}.
     *
     * <p>With a non-null {@code identity} the response carries a successful status and one
     * assertion: the user's name as subject (bearer confirmation addressed to the peer
     * endpoint's location), conditions built from the current time, the expiration time and an
     * audience restriction on the request issuer, a password authn statement, and the
     * {@code member-of} attribute. With a {@code null} identity the response carries only a
     * responder-error status.
     */
    private Response createResponse(SAMLMessageContext<AuthnRequest, Response, NameID> context,
            AuthnIdentity identity) {
        String recipient = context.getPeerEntityEndpoint().getLocation();
        String audience = context.getInboundMessageIssuer();
        String inResponseTo = context.getInboundSAMLMessage().getID();
        String issuer = context.getLocalEntityId();
        DateTime issueInstant = new DateTime();
        // Expiration time in the future.
        DateTime expiry = issueInstant.plusMillis(expirationMillis);

        if (identity != null) {
            Attribute memberOf = groupsAttributeFor(identity);
            return makeResponse(
                    issuer,
                    issueInstant,
                    makeSuccessfulStatus(),
                    inResponseTo,
                    makeAssertion(
                            issuer,
                            issueInstant,
                            makeSubject(
                                    identity.getUser().getName(),
                                    makeSubjectConfirmation(
                                            OpenSamlUtil.BEARER_METHOD,
                                            makeSubjectConfirmationData(recipient, expiry, inResponseTo))),
                            makeConditions(issueInstant, expiry, makeAudienceRestriction(audience)),
                            makeAuthnStatement(issueInstant, AuthnContext.IP_PASSWORD_AUTHN_CTX),
                            makeAttributeStatement(memberOf)));
        }

        return makeResponse(
                issuer,
                issueInstant,
                makeStatus(
                        makeStatusCode(StatusCode.RESPONDER_URI),
                        makeStatusMessage("Could not authenticate user")),
                inResponseTo);
    }

    /**
     * Builds the {@code member-of} attribute holding one value per group name of
     * {@code identity}; a {@code null} group collection yields an attribute without values.
     */
    private Attribute groupsAttributeFor(AuthnIdentity identity) {
        Attribute memberOf = makeAttribute("member-of");
        Iterable<GroupPrincipal> reported = identity.getGroups();
        Iterable<GroupPrincipal> memberships =
                (reported != null) ? reported : Collections.<GroupPrincipal>emptySet();
        for (GroupPrincipal membership : memberships) {
            String groupName = membership.getName();
            memberOf.getAttributeValues().add(makeAttributeValue(groupName));
        }
        return memberOf;
    }

    /** Returns the handler that serves incoming single sign-on (authn request) exchanges. */
    public HttpHandler getSingleSignOnHandler() {
        return ssoHandler;
    }

    /**
     * Accepts redirect-binding authn requests, decodes them, picks the endpoint the answer
     * must go to and then hands the exchange to the authentication authority.
     */
    private class SsoHandler implements HttpHandler {
        @Override
        public void handle(HttpExchange ex) throws IOException {
            String method = ex.getRequestMethod();
            if (!"GET".equals(method)) {
                HttpExchanges.cannedRespond(ex, HttpURLConnection.HTTP_BAD_METHOD, Translation.HTTP_BAD_METHOD);
                return;
            }
            // Only the exact context path is served; anything below it is a 404.
            String requestedPath = ex.getRequestURI().getPath();
            String handlerPath = ex.getHttpContext().getPath();
            if (!requestedPath.equals(handlerPath)) {
                HttpExchanges.cannedRespond(ex, HttpURLConnection.HTTP_NOT_FOUND, Translation.HTTP_NOT_FOUND);
                return;
            }

            SAMLMessageContext<AuthnRequest, Response, NameID> context = newRequestContext(ex);
            if (!decodeRequest(ex, context)) {
                // An error response has already been sent.
                return;
            }

            Endpoint peerEndpoint = selectEndpoint(context);
            if (peerEndpoint == null) {
                log.log(Level.INFO, "Error decoding message: could not determine peerEndpoint");
                HttpExchanges.cannedRespond(ex, HttpURLConnection.HTTP_BAD_REQUEST,
                        Translation.HTTP_BAD_REQUEST_ERROR_DECODING);
                return;
            }
            context.setPeerEntityEndpoint(peerEndpoint);

            adaptor.authenticateUser(ex, new AuthnCallback(context));
        }

        /**
         * Creates a message context populated with the local entity's id, metadata, IdP role,
         * outbound issuer and signing credential, reading its inbound message from {@code ex}.
         */
        private SAMLMessageContext<AuthnRequest, Response, NameID> newRequestContext(HttpExchange ex) {
            SAMLMessageContext<AuthnRequest, Response, NameID> context = OpenSamlUtil.makeSamlMessageContext();
            EntityDescriptor localEntity = metadata.getLocalEntity();
            String localEntityId = localEntity.getEntityID();
            context.setLocalEntityId(localEntityId);
            context.setLocalEntityMetadata(localEntity);
            context.setLocalEntityRole(IDPSSODescriptor.DEFAULT_ELEMENT_NAME);
            context.setLocalEntityRoleMetadata(
                    getFirst(localEntity.getRoleDescriptors(IDPSSODescriptor.DEFAULT_ELEMENT_NAME)));
            context.setOutboundMessageIssuer(localEntityId);
            context.setOutboundSAMLMessageSigningCredential(cred);

            context.setInboundMessageTransport(new HttpExchangeInTransportAdapter(ex));
            return context;
        }

        /**
         * Decodes the inbound request into {@code context}.
         *
         * @return {@code true} on success; {@code false} if decoding failed, in which case a
         *     400 response has already been sent on {@code ex}
         */
        private boolean decodeRequest(HttpExchange ex,
                SAMLMessageContext<AuthnRequest, Response, NameID> context) throws IOException {
            try {
                RequestUriRedirectDeflateDecoder decoder =
                        new RequestUriRedirectDeflateDecoder(HttpExchanges.getRequestUri(ex));
                decoder.decode(context);
            } catch (MessageDecodingException e) {
                log.log(Level.INFO, "Error decoding message", e);
                HttpExchanges.cannedRespond(ex, HttpURLConnection.HTTP_BAD_REQUEST,
                        Translation.HTTP_BAD_REQUEST_ERROR_DECODING);
                return false;
            } catch (SecurityException e) {
                log.log(Level.WARNING, "Security error while decoding message", e);
                HttpExchanges.cannedRespond(ex, HttpURLConnection.HTTP_BAD_REQUEST,
                        Translation.HTTP_BAD_REQUEST_SECURITY_ERROR);
                return false;
            }
            return true;
        }

        /**
         * Chooses the assertion consumer endpoint (POST binding) for the request's issuer.
         * An issuer other than the configured peer entity is logged and leaves the selector
         * without entity or role metadata.
         *
         * @return whatever the endpoint selector picks; {@link #handle} treats {@code null} as
         *     a bad request
         */
        private Endpoint selectEndpoint(SAMLMessageContext<AuthnRequest, ?, ?> context) {
            String peerEntityId = context.getInboundMessageIssuer();
            EntityDescriptor peerEntity = null;
            RoleDescriptor peerRole = null;
            // TODO(ejona): Support additional peer entities other than a single GSA.
            if (peerEntityId == null || !peerEntityId.equals(metadata.getPeerEntity().getEntityID())) {
                log.log(Level.INFO, "Unknown Peer Entity Id: {0}", peerEntityId);
            } else {
                peerEntity = metadata.getPeerEntity();
                peerRole = getFirst(peerEntity.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME));
            }

            AuthnResponseEndpointSelector selector = new AuthnResponseEndpointSelector();
            selector.setEndpointType(AssertionConsumerService.DEFAULT_ELEMENT_NAME);
            selector.getSupportedIssuerBindings().add(SAMLConstants.SAML2_POST_BINDING_URI);
            selector.setSamlRequest(context.getInboundSAMLMessage());
            selector.setEntityMetadata(peerEntity);
            selector.setEntityRoleMetadata(peerRole);

            return selector.selectEndpoint();
        }

        /** Returns the first element of {@code list}, or {@code null} when the list is empty. */
        private <V> V getFirst(List<V> list) {
            if (list.isEmpty()) {
                return null;
            }
            return list.get(0);
        }
    }

    /**
     * Callback given to the authentication authority; it remembers the message context of the
     * request so the response can be produced once the authentication outcome is known.
     */
    private class AuthnCallback implements AuthnAuthority.Callback {
        private final SAMLMessageContext<AuthnRequest, Response, NameID> samlContext;

        public AuthnCallback(SAMLMessageContext<AuthnRequest, Response, NameID> context) {
            this.samlContext = context;
        }

        @Override
        public void userAuthenticated(HttpExchange ex, AuthnIdentity identity) throws IOException {
            SamlIdentityProvider.this.respond(ex, samlContext, identity);
        }
    }

    /**
     * Redirect/deflate decoder that reports a fixed receiver endpoint URI, derived from the
     * request URI, instead of asking the transport for it.
     */
    private static class RequestUriRedirectDeflateDecoder extends HTTPRedirectDeflateDecoder {
        /** ASCII form of the request URI with query and fragment removed. */
        private final String receiverEndpointUri;

        /**
         * @param requestUri the URI the client used to make the request
         * @throws IllegalStateException if the URI cannot be rebuilt without its query
         */
        public RequestUriRedirectDeflateDecoder(URI requestUri) {
            URI withoutQuery;
            try {
                // Keep scheme, authority and path only; query and fragment are dropped.
                withoutQuery = new URI(requestUri.getScheme(), requestUri.getAuthority(), requestUri.getPath(),
                        null, null);
            } catch (URISyntaxException e) {
                throw new IllegalStateException(e);
            }
            this.receiverEndpointUri = withoutQuery.toASCIIString();
        }

        @Override
        protected String getActualReceiverEndpointURI(SAMLMessageContext messageContext) {
            // This method in HTTPRedirectDeflateDecoder is hard-coded for use with
            // HttpServletRequestAdapter only, which we aren't using.
            return receiverEndpointUri;
        }
    }
}