org.cloudfoundry.reactor.tokenprovider.AbstractUaaTokenProvider.java Source code

Java tutorial

Introduction

Here is the source code for org.cloudfoundry.reactor.tokenprovider.AbstractUaaTokenProvider.java

Source

/*
 * Copyright 2013-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.cloudfoundry.reactor.tokenprovider;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.netty.util.AsciiString;
import org.cloudfoundry.Nullable;
import org.cloudfoundry.reactor.ConnectionContext;
import org.cloudfoundry.reactor.TokenProvider;
import org.cloudfoundry.reactor.util.ErrorPayloadMapper;
import org.cloudfoundry.reactor.util.JsonCodec;
import org.cloudfoundry.reactor.util.NetworkLogging;
import org.cloudfoundry.reactor.util.UserAgent;
import org.cloudfoundry.uaa.UaaException;
import org.immutables.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.ReplayProcessor;
import reactor.ipc.netty.http.client.HttpClientRequest;
import reactor.ipc.netty.http.client.HttpClientResponse;

import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Base64;
import java.util.Date;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import java.util.function.Function;

import static io.netty.handler.codec.http.HttpHeaderNames.AUTHORIZATION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED;
import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;

/**
 * An abstract base class for all token providers that interact with the UAA.  It encapsulates the logic to refresh the token before expiration.
 */
public abstract class AbstractUaaTokenProvider implements TokenProvider {

    private static final Logger LOGGER = LoggerFactory.getLogger("cloudfoundry-client.token");

    private static final String ACCESS_TOKEN = "access_token";

    private static final String AUTHORIZATION_ENDPOINT = "authorization_endpoint";

    private static final String REFRESH_TOKEN = "refresh_token";

    private static final String TOKEN_TYPE = "token_type";

    private static final ZoneId UTC = ZoneId.of("UTC");

    private final ConcurrentMap<ConnectionContext, Mono<String>> accessTokens = new ConcurrentHashMap<>(1);

    private final ConcurrentMap<ConnectionContext, RefreshToken> refreshTokenStreams = new ConcurrentHashMap<>(1);

    private final ConcurrentMap<ConnectionContext, Mono<String>> refreshTokens = new ConcurrentHashMap<>(1);

    /**
     * The client id. Defaults to {@code cf}.
     */
    @Value.Default
    public String getClientId() {
        return "cf";
    }

    /**
     * The client secret. Defaults to {@code ""}.
     */
    @Value.Default
    public String getClientSecret() {
        return "";
    }

    /**
     * Returns a {@link Flux} of refresh tokens for a connection
     *
     * @param connectionContext A {@link ConnectionContext} to be used to identity which connection the refresh tokens be retrieved for
     * @return a {@link Flux} that emits the last token on subscribe and new refresh tokens as they are negotiated
     */
    public Flux<String> getRefreshTokens(ConnectionContext connectionContext) {
        return getRefreshTokenStream(connectionContext).processor;
    }

    @Override
    public final Mono<String> getToken(ConnectionContext connectionContext) {
        return this.accessTokens.computeIfAbsent(connectionContext, this::token);
    }

    @Override
    public void invalidate(ConnectionContext connectionContext) {
        this.accessTokens.put(connectionContext, token(connectionContext));
    }

    /**
     * The identity zone subdomain
     */
    @Nullable
    abstract String getIdentityZoneSubdomain();

    /**
     * Transforms a {@code Mono} in order to make a request to negotiate an access token
     *
     * @param outbound the {@link Mono} to transform to perform the token request
     */
    abstract Mono<Void> tokenRequestTransformer(Mono<HttpClientRequest> outbound);

    private static HttpClientRequest addContentType(HttpClientRequest request) {
        return request.header(CONTENT_TYPE, APPLICATION_X_WWW_FORM_URLENCODED);
    }

    private static HttpClientRequest disableChunkedTransfer(HttpClientRequest request) {
        return request.chunkedTransfer(false);
    }

    private static HttpClientRequest disableFailOnError(HttpClientRequest request) {
        return request.failOnClientError(false).failOnServerError(false);
    }

    private static String extractAccessToken(Map<String, String> payload) {
        String accessToken = payload.get(ACCESS_TOKEN);

        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Access Token: {}", accessToken);

            parseToken(accessToken).ifPresent(claims -> {
                LOGGER.debug("Access Token Issued At:  {} UTC", toLocalDateTime(claims.getIssuedAt()));
                LOGGER.debug("Access Token Expires At: {} UTC", toLocalDateTime(claims.getExpiration()));
            });
        }

        return String.format("%s %s", payload.get(TOKEN_TYPE), accessToken);
    }

    private static String getTokenUri(String root, String identityZoneId) {
        UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(root);

        if (identityZoneId != null) {
            builder.host(String.format("%s.%s", identityZoneId, builder.build().getHost()));
        }

        return builder.pathSegment("oauth", "token").build().encode().toUriString();
    }

    private static Optional<Claims> parseToken(String token) {
        try {
            String jws = token.substring(0, token.lastIndexOf('.') + 1);
            return Optional.of(Jwts.parser().parseClaimsJwt(jws).getBody());
        } catch (Exception e) {
            return Optional.empty();
        }
    }

    private static LocalDateTime toLocalDateTime(Date date) {
        return LocalDateTime.from(date.toInstant().atZone(UTC));
    }

    private HttpClientRequest addAuthorization(HttpClientRequest request) {
        String encoded = Base64.getEncoder()
                .encodeToString(new AsciiString(getClientId()).concat(":").concat(getClientSecret()).toByteArray());
        return request.header(AUTHORIZATION, String.format("Basic %s", encoded));
    }

    private Consumer<Map<String, String>> extractRefreshToken(ConnectionContext connectionContext) {
        return payload -> Optional.ofNullable(payload.get(REFRESH_TOKEN)).ifPresent(refreshToken -> {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Refresh Token: {}", refreshToken);

                parseToken(refreshToken).ifPresent(claims -> {
                    LOGGER.debug("Refresh Token Issued At:  {} UTC", toLocalDateTime(claims.getIssuedAt()));
                    LOGGER.debug("Refresh Token Expires At: {} UTC", toLocalDateTime(claims.getExpiration()));
                });
            }

            this.refreshTokens.put(connectionContext, Mono.just(refreshToken));
            getRefreshTokenStream(connectionContext).sink.next(refreshToken);
        });
    }

    @SuppressWarnings("unchecked")
    private Function<Mono<HttpClientResponse>, Mono<String>> extractTokens(ConnectionContext connectionContext) {
        return inbound -> inbound.transform(JsonCodec.decode(connectionContext.getObjectMapper(), Map.class))
                .map(payload -> (Map<String, String>) payload).doOnNext(extractRefreshToken(connectionContext))
                .map(AbstractUaaTokenProvider::extractAccessToken);
    }

    private RefreshToken getRefreshTokenStream(ConnectionContext connectionContext) {
        return this.refreshTokenStreams.computeIfAbsent(connectionContext, c -> new RefreshToken());
    }

    private Mono<HttpClientResponse> primaryToken(ConnectionContext connectionContext) {
        return requestToken(connectionContext, this::tokenRequestTransformer);
    }

    private Mono<HttpClientResponse> refreshToken(ConnectionContext connectionContext, String refreshToken) {
        return requestToken(connectionContext, refreshTokenGrantTokenRequestTransformer(refreshToken))
                .onErrorResume(
                        t -> t instanceof UaaException && ((UaaException) t).getStatusCode() == UNAUTHORIZED.code(),
                        t -> Mono.empty());
    }

    private Function<Mono<HttpClientRequest>, Mono<Void>> refreshTokenGrantTokenRequestTransformer(
            String refreshToken) {
        return outbound -> outbound
                .flatMap(
                        request -> request
                                .sendForm(form -> form.multipart(false).attr("client_id", getClientId())
                                        .attr("client_secret", getClientSecret())
                                        .attr("grant_type", "refresh_token").attr("refresh_token", refreshToken))
                                .then());
    }

    private Mono<HttpClientResponse> requestToken(ConnectionContext connectionContext,
            Function<Mono<HttpClientRequest>, Mono<Void>> tokenRequestTransformer) {
        return connectionContext.getRootProvider().getRoot(AUTHORIZATION_ENDPOINT, connectionContext)
                .map(root -> getTokenUri(root, getIdentityZoneSubdomain()))
                .flatMap(uri -> connectionContext.getHttpClient().post(uri,
                        request -> Mono.just(request).map(AbstractUaaTokenProvider::disableChunkedTransfer)
                                .map(AbstractUaaTokenProvider::disableFailOnError).map(this::addAuthorization)
                                .map(UserAgent::addUserAgent).map(AbstractUaaTokenProvider::addContentType)
                                .map(JsonCodec::addDecodeHeaders).transform(tokenRequestTransformer))
                        .doOnSubscribe(NetworkLogging.post(uri)).transform(NetworkLogging.response(uri)))
                .transform(ErrorPayloadMapper.uaa(connectionContext.getObjectMapper()));
    }

    private Mono<String> token(ConnectionContext connectionContext) {
        Mono<String> cached = this.refreshTokens.getOrDefault(connectionContext, Mono.empty())
                .flatMap(refreshToken -> refreshToken(connectionContext, refreshToken)
                        .doOnSubscribe(s -> LOGGER.debug("Negotiating using refresh token")))
                .switchIfEmpty(primaryToken(connectionContext)
                        .doOnSubscribe(s -> LOGGER.debug("Negotiating using token provider")))
                .transform(ErrorPayloadMapper.fallback()).transform(extractTokens(connectionContext));

        return connectionContext.getCacheDuration().map(cached::cache).orElseGet(cached::cache).checkpoint();
    }

    private static final class RefreshToken {

        private ReplayProcessor<String> processor = ReplayProcessor.cacheLast();

        private FluxSink<String> sink = this.processor.sink();

    }

}