org.springframework.security.oauth2.client.OAuth2RestTemplate.java Source code

Java tutorial

Introduction

Here is the source code for org.springframework.security.oauth2.client.OAuth2RestTemplate.java.

Source

package org.springframework.security.oauth2.client;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.util.Arrays;

import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
import org.springframework.security.oauth2.client.http.OAuth2ErrorHandler;
import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
import org.springframework.security.oauth2.client.token.AccessTokenProvider;
import org.springframework.security.oauth2.client.token.AccessTokenProviderChain;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider;
import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider;
import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitAccessTokenProvider;
import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordAccessTokenProvider;
import org.springframework.security.oauth2.common.AuthenticationScheme;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;

/**
 * Rest template that is able to make OAuth2-authenticated REST requests with the credentials of the provided resource.
 * <p>
 * Before each request an access token is taken from (or acquired into) the {@link OAuth2ClientContext}, and it is
 * attached to the request according to the resource's {@link AuthenticationScheme}.
 * 
 * @author Ryan Heaton
 * @author Dave Syer
 */
public class OAuth2RestTemplate extends RestTemplate implements OAuth2RestOperations {

    /** The protected resource whose settings drive token acquisition and request authentication. */
    private final OAuth2ProtectedResourceDetails resource;

    /**
     * Obtains new access tokens. Defaults to an {@link AccessTokenProviderChain} over the authorization-code,
     * implicit, resource-owner-password and client-credentials providers.
     */
    private AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(
            Arrays.<AccessTokenProvider>asList(new AuthorizationCodeAccessTokenProvider(),
                    new ImplicitAccessTokenProvider(), new ResourceOwnerPasswordAccessTokenProvider(),
                    new ClientCredentialsAccessTokenProvider()));

    /** Holds the cached access token, the current access token request and any preserved state. */
    private OAuth2ClientContext context;

    /** Whether a failed request made with an existing token is retried once with a fresh token. */
    private boolean retryBadAccessTokens = true;

    /** Strategy used to add the Authorization header when the resource uses the header scheme. */
    private OAuth2RequestAuthenticator authenticator = new DefaultOAuth2RequestAuthenticator();

    /**
     * Create a template for the given resource, backed by a new {@link DefaultOAuth2ClientContext}.
     * 
     * @param resource the protected resource to access (must not be null)
     * @throws IllegalArgumentException if {@code resource} is null
     */
    public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource) {
        this(resource, new DefaultOAuth2ClientContext());
    }

    /**
     * Create a template for the given resource and client context.
     * 
     * @param resource the protected resource to access (must not be null)
     * @param context the context in which access tokens and preserved state are stored
     * @throws IllegalArgumentException if {@code resource} is null
     */
    public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource, OAuth2ClientContext context) {
        super();
        if (resource == null) {
            throw new IllegalArgumentException("An OAuth2 resource must be supplied.");
        }

        // Assigned before setErrorHandler, whose override below can read this.resource.
        this.resource = resource;
        this.context = context;
        setErrorHandler(new OAuth2ErrorHandler(resource));
    }

    /**
     * Strategy for extracting an Authorization header from an access token and the request details. Defaults to the
     * simple form "TOKEN_TYPE TOKEN_VALUE".
     * 
     * @param authenticator the authenticator to use
     */
    public void setAuthenticator(OAuth2RequestAuthenticator authenticator) {
        this.authenticator = authenticator;
    }

    /**
     * Flag to determine whether a request that has an existing access token, and which then leads to an
     * AccessTokenRequiredException should be retried (immediately, once). Useful if the remote server doesn't recognize
     * an old token which is stored in the client, but is happy to re-grant it.
     * 
     * @param retryBadAccessTokens the flag to set (default true)
     */
    public void setRetryBadAccessTokens(boolean retryBadAccessTokens) {
        this.retryBadAccessTokens = retryBadAccessTokens;
    }

    /**
     * Install the given error handler.
     * <p>
     * Any handler that is not already an {@link OAuth2ErrorHandler} is wrapped in one for this template's resource.
     * 
     * @param errorHandler the handler to install
     */
    @Override
    public void setErrorHandler(ResponseErrorHandler errorHandler) {
        if (!(errorHandler instanceof OAuth2ErrorHandler)) {
            errorHandler = new OAuth2ErrorHandler(errorHandler, resource);
        }
        super.setErrorHandler(errorHandler);
    }

    @Override
    public OAuth2ProtectedResourceDetails getResource() {
        return resource;
    }

    /**
     * Create the HTTP request and attach the access token to it.
     * <p>
     * The token is always acquired first (and so cached in the context). For the query and form authentication schemes
     * it is appended to the URI as a query parameter. For the header scheme the configured authenticator is applied to
     * the created request. Other schemes add nothing to the request.
     */
    @Override
    protected ClientHttpRequest createRequest(URI uri, HttpMethod method) throws IOException {

        OAuth2AccessToken accessToken = getAccessToken();

        AuthenticationScheme authenticationScheme = resource.getAuthenticationScheme();
        if (AuthenticationScheme.query.equals(authenticationScheme)
                || AuthenticationScheme.form.equals(authenticationScheme)) {
            uri = appendQueryParameter(uri, accessToken);
        }

        ClientHttpRequest req = super.createRequest(uri, method);

        if (AuthenticationScheme.header.equals(authenticationScheme)) {
            authenticator.authenticate(resource, getOAuth2ClientContext(), req);
        }
        return req;

    }

    /**
     * Execute the request, optionally retrying once with a fresh access token.
     * <p>
     * A retry happens only when all of the following hold:
     * <ul>
     * <li>The first attempt fails with an {@link AccessTokenRequiredException}, an
     * {@link OAuth2AccessDeniedException} or an {@link InvalidTokenException}.</li>
     * <li>The context held a token when this method was entered.</li>
     * <li>{@link #setRetryBadAccessTokens(boolean) Retrying} is enabled.</li>
     * </ul>
     * In that case the stored token is cleared and the request is executed again. The retry acquires a new token in
     * {@link #createRequest(URI, HttpMethod)}. Otherwise the failure is rethrown.
     * <p>
     * An {@link InvalidTokenException} is replaced by an {@link OAuth2AccessDeniedException} that does not contain the
     * token value.
     */
    @Override
    protected <T> T doExecute(URI url, HttpMethod method, RequestCallback requestCallback,
            ResponseExtractor<T> responseExtractor) throws RestClientException {
        OAuth2AccessToken accessToken = context.getAccessToken();
        RuntimeException rethrow = null;
        try {
            return super.doExecute(url, method, requestCallback, responseExtractor);
        } catch (AccessTokenRequiredException e) {
            rethrow = e;
        } catch (OAuth2AccessDeniedException e) {
            rethrow = e;
        } catch (InvalidTokenException e) {
            // Don't reveal the token value in case it is logged
            rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId());
        }
        if (accessToken != null && retryBadAccessTokens) {
            // Drop the rejected token so the retry acquires a new one.
            context.setAccessToken(null);
            try {
                return super.doExecute(url, method, requestCallback, responseExtractor);
            } catch (InvalidTokenException e) {
                // Don't reveal the token value in case it is logged
                rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId());
            }
        }
        throw rethrow;
    }

    /**
     * @return the client id for this resource.
     */
    private String getClientId() {
        return resource.getClientId();
    }

    /**
     * Acquire or renew an access token for the current context if necessary. This method will be called automatically
     * when a request is executed (and the result is cached), but can also be called as a standalone method to
     * pre-populate the token.
     * 
     * @return an access token
     * @throws UserRedirectRequiredException propagated from token acquisition. Before rethrowing, the stored token is
     *             cleared. If the exception carries a state key, its state (or "NONE") is preserved in the context.
     */
    public OAuth2AccessToken getAccessToken() throws UserRedirectRequiredException {

        OAuth2AccessToken accessToken = context.getAccessToken();

        if (accessToken == null || accessToken.isExpired()) {
            try {
                accessToken = acquireAccessToken(context);
            } catch (UserRedirectRequiredException e) {
                context.setAccessToken(null); // No point hanging onto it now
                String stateKey = e.getStateKey();
                if (stateKey != null) {
                    Object stateToPreserve = e.getStateToPreserve();
                    if (stateToPreserve == null) {
                        stateToPreserve = "NONE";
                    }
                    context.setPreservedState(stateKey, stateToPreserve);
                }
                throw e;
            }
        }
        return accessToken;
    }

    /**
     * @return the context for this template
     */
    public OAuth2ClientContext getOAuth2ClientContext() {
        return context;
    }

    /**
     * Obtain a new access token from the {@link AccessTokenProvider} and store it in the given context.
     * <p>
     * Any state preserved under the request's state key is moved from the context onto the request. Any token
     * currently in the context is passed to the provider as the existing token.
     * 
     * @param oauth2Context the context supplying the access token request and receiving the new token
     * @return the newly obtained access token (never null, with a non-null value)
     * @throws UserRedirectRequiredException propagated from the access token provider
     * @throws AccessTokenRequiredException if the context has no access token request
     * @throws IllegalStateException if the provider returns a null token or a token with a null value
     */
    protected OAuth2AccessToken acquireAccessToken(OAuth2ClientContext oauth2Context)
            throws UserRedirectRequiredException {

        AccessTokenRequest accessTokenRequest = oauth2Context.getAccessTokenRequest();
        if (accessTokenRequest == null) {
            throw new AccessTokenRequiredException(
                    "No OAuth 2 security context has been established. Unable to access resource '"
                            + this.resource.getId() + "'.",
                    resource);
        }

        // Transfer the preserved state from the (longer lived) context to the current request.
        String stateKey = accessTokenRequest.getStateKey();
        if (stateKey != null) {
            accessTokenRequest.setPreservedState(oauth2Context.removePreservedState(stateKey));
        }

        OAuth2AccessToken existingToken = oauth2Context.getAccessToken();
        if (existingToken != null) {
            accessTokenRequest.setExistingToken(existingToken);
        }

        OAuth2AccessToken accessToken = accessTokenProvider.obtainAccessToken(resource, accessTokenRequest);
        if (accessToken == null || accessToken.getValue() == null) {
            throw new IllegalStateException(
                    "Access token provider returned a null access token, which is illegal according to the contract.");
        }
        oauth2Context.setAccessToken(accessToken);
        return accessToken;
    }

    /**
     * Return a copy of the given URI with the access token appended as a query parameter.
     * <p>
     * The parameter is named after {@link OAuth2ProtectedResourceDetails#getTokenName()}. The token value is
     * URL-encoded (UTF-8). Every other part of the URI is kept in its original raw, still-encoded form.
     * 
     * @param uri the request URI
     * @param accessToken the token to append
     * @return the URI with the token parameter added to its query
     * @throws IllegalArgumentException if the resulting URI cannot be parsed or the token cannot be encoded
     */
    protected URI appendQueryParameter(URI uri, OAuth2AccessToken accessToken) {

        try {

            // TODO: there is some duplication with UriUtils here. Probably unavoidable as long as this
            // method signature uses URI not String.
            String query = uri.getRawQuery(); // Don't decode anything here
            String queryFragment = resource.getTokenName() + "="
                    + URLEncoder.encode(accessToken.getValue(), "UTF-8");
            if (query == null || query.isEmpty()) {
                // An empty query ("http://host/path?") must not produce "?&token=..."
                query = queryFragment;
            } else {
                query = query + "&" + queryFragment;
            }

            // Rebuild the URI from its raw (still-encoded) components. Decoding and re-encoding them (as the
            // multi-argument URI constructors do) would alter the request:
            // - an encoded "%2F" in the path would turn into a literal "/";
            // - a decoded fragment such as "a b" would make the final URI unparseable;
            // - a registry-based authority (e.g. a host name with an underscore, for which getHost() returns
            //   null) would be dropped.
            // Appending the raw query also avoids re-encoding query string characters (SECOAUTH-90).
            StringBuilder sb = new StringBuilder();
            if (uri.getScheme() != null) {
                sb.append(uri.getScheme()).append(':');
            }
            if (uri.getRawAuthority() != null) {
                sb.append("//").append(uri.getRawAuthority());
            }
            if (uri.getRawPath() != null) {
                sb.append(uri.getRawPath());
            }
            sb.append('?').append(query);
            if (uri.getRawFragment() != null) {
                sb.append('#').append(uri.getRawFragment());
            }

            return new URI(sb.toString());

        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Could not parse URI", e);
        } catch (UnsupportedEncodingException e) {
            throw new IllegalArgumentException("Could not encode URI", e);
        }

    }

    /**
     * Replace the strategy used to obtain new access tokens.
     * 
     * @param accessTokenProvider the provider to use
     */
    public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) {
        this.accessTokenProvider = accessTokenProvider;
    }

}