com.netflix.genie.web.security.oauth2.pingfederate.PingFederateRemoteTokenServices.java Source code

Java tutorial

Introduction

Here is the source code for com.netflix.genie.web.security.oauth2.pingfederate.PingFederateRemoteTokenServices.java.

Source

/*
 *
 *  Copyright 2016 Netflix, Inc.
 *
 *     Licensed under the Apache License, Version 2.0 (the "License");
 *     you may not use this file except in compliance with the License.
 *     You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License.
 *
 */
package com.netflix.genie.web.security.oauth2.pingfederate;

import com.netflix.spectator.api.Id;
import com.netflix.spectator.api.Registry;
import com.netflix.spectator.api.Timer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.boot.autoconfigure.security.oauth2.resource.ResourceServerProperties;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpRequestExecution;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AccessTokenConverter;
import org.springframework.security.oauth2.provider.token.RemoteTokenServices;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate;

import javax.validation.constraints.NotNull;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
 * A remote token services extension for Ping Federate based IDPs.
 * <p>
 * Bearer tokens are validated by POSTing them, together with this resource server's client id and secret and the
 * {@value #GRANT_TYPE} grant type, as a form to the configured token info endpoint. The JSON response is then turned
 * into an {@link OAuth2Authentication} by the supplied {@link AccessTokenConverter}. Latency of both the overall
 * authentication and the raw HTTP call, as well as the count of non-2xx responses by status, are recorded in the
 * metrics registry.
 *
 * @author tgianos
 * @since 3.0.0
 */
@Slf4j
public class PingFederateRemoteTokenServices extends RemoteTokenServices {

    protected static final String TOKEN_NAME_KEY = "token";
    protected static final String CLIENT_ID_KEY = "client_id";
    protected static final String CLIENT_SECRET_KEY = "client_secret";
    protected static final String GRANT_TYPE_KEY = "grant_type";
    protected static final String ERROR_KEY = "error";
    protected static final String SCOPE_KEY = "scope";
    protected static final String GRANT_TYPE = "urn:pingidentity.com:oauth2:grant_type:validate_bearer";
    protected static final String AUTHENTICATION_TIMER_NAME = "genie.security.oauth2.pingFederate.authentication.timer";
    protected static final String API_TIMER_NAME = "genie.security.oauth2.pingFederate.api.timer";

    private final AccessTokenConverter converter;
    // Kept locally (in addition to handing it to the parent) so postForMap can issue exchange() calls with it.
    private RestTemplate localRestTemplate;

    private final String checkTokenEndpointUrl;
    private final String clientId;
    private final String clientSecret;

    // Metrics
    private final Id tokenValidationError;
    private final Timer authenticationTimer;
    private final Timer pingFederateAPITimer;

    /**
     * Constructor.
     * <p>
     * Builds an HTTP client with a 2 second connect timeout and a 10 second read timeout, wires in metrics, and
     * reads the token info URI, client id and client secret from the server properties.
     *
     * @param serverProperties The properties of the resource server (Genie). Token info URI, client id and client
     *                         secret must all be non-blank.
     * @param converter        The access token converter to use
     * @param registry         The metrics registry to use
     * @throws IllegalStateException if the token info URI, client id or client secret is blank
     */
    public PingFederateRemoteTokenServices(@NotNull final ResourceServerProperties serverProperties,
            @NotNull final AccessTokenConverter converter, @NotNull final Registry registry) {
        super();
        this.tokenValidationError = registry
                .createId("genie.security.oauth2.pingFederate.tokenValidation.error.rate");
        this.authenticationTimer = registry.timer(AUTHENTICATION_TIMER_NAME);
        this.pingFederateAPITimer = registry.timer(API_TIMER_NAME);

        // NOTE(review): setRestTemplate is overridable; a subclass override runs before the fields below are set.
        this.setRestTemplate(createRestTemplate(registry, this.pingFederateAPITimer, this.tokenValidationError));

        this.checkTokenEndpointUrl = serverProperties.getTokenInfoUri();
        this.clientId = serverProperties.getClientId();
        this.clientSecret = serverProperties.getClientSecret();

        Assert.state(StringUtils.isNotBlank(this.checkTokenEndpointUrl), "Check Endpoint URL is required");
        Assert.state(StringUtils.isNotBlank(this.clientId), "Client ID is required");
        Assert.state(StringUtils.isNotBlank(this.clientSecret), "Client secret is required");

        log.debug("checkTokenEndpointUrl = {}", this.checkTokenEndpointUrl);
        log.debug("clientId = {}", this.clientId);
        // The client secret is a credential and is intentionally never written to the logs.

        this.converter = converter;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Sends the token, client credentials and validate-bearer grant type to the check token endpoint, rejects
     * responses carrying an {@code error} field, normalizes the space separated {@code scope} string into a list
     * and delegates to the configured converter. Elapsed time is always recorded on the authentication timer,
     * whether validation succeeds or fails.
     *
     * @throws InvalidTokenException if the endpoint returns no body, reports an error, or returns missing,
     *                               non-string or blank scopes
     * @throws IllegalStateException if the response lacks a {@code client_id} or {@code scope} entry
     */
    @Override
    public OAuth2Authentication loadAuthentication(final String accessToken)
            throws AuthenticationException, InvalidTokenException {
        final long start = System.nanoTime();
        try {
            final MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
            formData.add(TOKEN_NAME_KEY, accessToken);
            formData.add(CLIENT_ID_KEY, this.clientId);
            formData.add(CLIENT_SECRET_KEY, this.clientSecret);
            formData.add(GRANT_TYPE_KEY, GRANT_TYPE);

            final Map<String, Object> map = this.postForMap(this.checkTokenEndpointUrl, formData);

            if (map.containsKey(ERROR_KEY)) {
                // String.valueOf tolerates an explicit JSON null for the error value instead of throwing an NPE
                final String error = String.valueOf(map.get(ERROR_KEY));
                log.debug("Validating the token produced an error: {}", error);
                throw new InvalidTokenException(error);
            }

            Assert.state(map.containsKey(CLIENT_ID_KEY), "Client id must be present in response from auth server");
            Assert.state(map.containsKey(SCOPE_KEY), "No scopes included in response from authentication server");
            this.convertScopes(map);
            final OAuth2Authentication authentication = this.converter.extractAuthentication(map);
            log.info("User {} authenticated with authorities {}", authentication.getPrincipal(),
                    authentication.getAuthorities());
            return authentication;
        } finally {
            final long finished = System.nanoTime();
            this.authenticationTimer.record(finished - start, TimeUnit.NANOSECONDS);
        }
    }

    /**
     * Set the rest operations to use.
     * <p>
     * The template is passed to the parent class and also retained locally for this class's own requests.
     *
     * @param restTemplate The rest operations to use. Not null.
     */
    protected void setRestTemplate(@NotNull final RestTemplate restTemplate) {
        super.setRestTemplate(restTemplate);
        this.localRestTemplate = restTemplate;
    }

    /**
     * Build the HTTP client used to call the token validation endpoint.
     *
     * @param registry The metrics registry used to count error responses
     * @param apiTimer Timer recording the latency of every outbound request
     * @param errorId  Base metric id for error responses; a {@code status} tag is added per response
     * @return A configured rest template
     */
    private static RestTemplate createRestTemplate(final Registry registry, final Timer apiTimer, final Id errorId) {
        final HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory();
        factory.setConnectTimeout(2000);
        factory.setReadTimeout(10000);
        final RestTemplate restTemplate = new RestTemplate(factory);

        // Time each raw HTTP exchange, including ones that end in an exception.
        final List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();
        interceptors.add(
                (final HttpRequest request, final byte[] body, final ClientHttpRequestExecution execution) -> {
                    final long start = System.nanoTime();
                    try {
                        return execution.execute(request, body);
                    } finally {
                        apiTimer.record(System.nanoTime() - start, TimeUnit.NANOSECONDS);
                    }
                });
        restTemplate.setInterceptors(interceptors);

        restTemplate.setErrorHandler(new DefaultResponseErrorHandler() {
            /**
             * Count every error response by status, but do not raise on 400 so that its body (presumably
             * carrying an "error" field) reaches the error check in loadAuthentication.
             */
            @Override
            public void handleError(final ClientHttpResponse response) throws IOException {
                final int errorCode = response.getRawStatusCode();
                registry.counter(errorId.withTag("status", Integer.toString(errorCode))).increment();
                if (errorCode != HttpStatus.BAD_REQUEST.value()) {
                    super.handleError(response);
                }
            }
        });
        return restTemplate;
    }

    /**
     * POST the given form data to the given URL and return the response body parsed as a map.
     *
     * @param path     The URL to post to
     * @param formData The form fields to send as application/x-www-form-urlencoded
     * @return The response body; never null
     * @throws InvalidTokenException if the response had no body
     */
    private Map<String, Object> postForMap(final String path, final MultiValueMap<String, String> formData) {
        final HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        @SuppressWarnings("rawtypes")
        final Map map = this.localRestTemplate
                .exchange(path, HttpMethod.POST, new HttpEntity<>(formData, headers), Map.class).getBody();
        if (map == null) {
            // 400s are not raised by the error handler, so an empty error response can surface here as a null body.
            throw new InvalidTokenException("No response body returned from the token validation endpoint");
        }
        @SuppressWarnings("unchecked")
        final Map<String, Object> result = map;
        return result;
    }

    /**
     * Replace the space separated {@code scope} string in the response map with a list of individual scopes.
     *
     * @param oauth2Map The token validation response; modified in place
     * @throws InvalidTokenException if the scope value is null, not a String, or blank
     */
    private void convertScopes(final Map<String, Object> oauth2Map) {
        final Object scopesObject = oauth2Map.get(SCOPE_KEY);
        if (scopesObject == null) {
            throw new InvalidTokenException("Scopes were null");
        }

        if (scopesObject instanceof String) {
            final String scopes = (String) scopesObject;
            if (StringUtils.isBlank(scopes)) {
                throw new InvalidTokenException("No scopes found unable to authenticate");
            }

            oauth2Map.put(SCOPE_KEY, Arrays.asList(StringUtils.split(scopes, ' ')));
        } else {
            throw new InvalidTokenException("Scopes was not a String");
        }
    }
}