io.nitor.api.backend.auth.SetupAzureAdConnectAuth.java Source code

Java tutorial

Introduction

Here is the source code for io.nitor.api.backend.auth.SetupAzureAdConnectAuth.java

Source

/**
 * Copyright 2017-2019 Nitor Creations Oy
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.nitor.api.backend.auth;

import io.netty.handler.codec.http.HttpResponseStatus;
import io.nitor.api.backend.session.CookieSessionHandler;
import io.nitor.api.backend.util.JsonPointer;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpClientResponse;
import io.vertx.core.json.DecodeException;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.Cookie;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT;
import static io.netty.handler.codec.http.HttpHeaderNames.AUTHORIZATION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpHeaderNames.WWW_AUTHENTICATE;
import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON;
import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpResponseStatus.SEE_OTHER;
import static io.netty.handler.codec.http.HttpResponseStatus.TEMPORARY_REDIRECT;
import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
import static io.nitor.api.backend.msgraph.GraphQueryHandler.GRAPH_ACCESS_TOKEN_KEY;
import static io.nitor.api.backend.util.Helpers.forceHttps;
import static io.nitor.api.backend.util.Helpers.getUriHostName;
import static io.nitor.api.backend.util.Helpers.getUriHostNamePort;
import static io.nitor.api.backend.util.Helpers.replaceHost;
import static io.nitor.api.backend.util.Helpers.replaceHostAndPort;
import static io.nitor.api.backend.util.Helpers.urlEncode;
import static io.vertx.core.http.HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS;
import static io.vertx.core.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
import static io.vertx.core.http.HttpHeaders.ALLOW;
import static io.vertx.core.http.HttpHeaders.CACHE_CONTROL;
import static io.vertx.core.http.HttpHeaders.EXPIRES;
import static io.vertx.core.http.HttpHeaders.LOCATION;
import static io.vertx.core.http.HttpHeaders.SET_COOKIE;
import static io.vertx.core.http.HttpMethod.GET;
import static java.util.Collections.singletonMap;
import static java.util.Optional.ofNullable;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.toSet;

public class SetupAzureAdConnectAuth {
    private static final Logger logger = LogManager.getLogger(SetupAzureAdConnectAuth.class);
    public static final String SECRET_DATA_PREFIX = "s-";
    static final String UNAUTHORIZED_PATH = "/auth-failed";
    static final String FORBIDDEN_PATH = "/not-authorized";
    static final String PROXY_AUTH_REDIRECT_BEFORE = "/proxyAuthBefore";
    static final String PROXY_AUTH_REDIRECT_AFTER = "/proxyAuthAfter";

    public static void setupAzureAd(JsonObject adAuth, Router router, String publicURI, boolean virtualHost,
            CookieSessionHandler sessionHandler, HttpClient httpClient) {
        final String callbackPath = adAuth.getString("callbackPath", "/oidc/callback");
        String redirectUri = publicURI + callbackPath;
        adAuth.put("redirectUri", redirectUri);

        String path = adAuth.getString("route", "/*");

        List<GraphQuery> graphQueries = new ArrayList<>();
        JsonArray queryNodes = adAuth.getJsonArray("graphQueries");
        if (queryNodes == null) {
            graphQueries.add(new GraphQuery(adAuth, "https://graph.microsoft.com/beta/me?$expand=memberOf"));
        } else {
            queryNodes.stream().map(JsonObject.class::cast).map(GraphQuery::new).forEach(graphQueries::add);
        }
        Set<String> forbiddenHeaders = graphQueries.stream().flatMap(gq -> gq.headerMappings.keySet().stream())
                .collect(toSet());
        logger.info("Graph queries: "
                + graphQueries.stream().map(gq -> gq.graphQueryURI).collect(Collectors.joining(", ")));
        logger.info("Headers: " + forbiddenHeaders);

        HashMap<String, Pattern> requiredHeaderMatchers = new HashMap<>();
        adAuth.getJsonObject("requiredHeaders", new JsonObject()).forEach(mapping -> requiredHeaderMatchers
                .put(mapping.getKey(), Pattern.compile(mapping.getValue().toString())));

        RedirectTokenService redirectTokenService = new RedirectTokenService(sessionHandler.getCookieConverter());

        Handler<RoutingContext> securityHandler = authHandler(adAuth, forbiddenHeaders, requiredHeaderMatchers,
                publicURI, virtualHost, sessionHandler, redirectUri, redirectTokenService);

        router.get(FORBIDDEN_PATH).handler(ctx -> errorWithLogoutLink(ctx, FORBIDDEN));
        router.get(UNAUTHORIZED_PATH).handler(ctx -> errorWithLogoutLink(ctx, UNAUTHORIZED));

        router.get(callbackPath).handler(validateAuthCallback(adAuth, httpClient, sessionHandler, graphQueries,
                redirectUri, redirectTokenService));

        if (virtualHost) {
            router.options(PROXY_AUTH_REDIRECT_AFTER).handler(SetupAzureAdConnectAuth::optionsHandler);
            router.route(PROXY_AUTH_REDIRECT_AFTER).handler(ctx -> {
                // phase 3: executed when returning to virtual domain with cookie and original url inside token
                // -> jump to original locatin and set the cookie
                String token = ctx.request().getParam("t");
                Map<String, String> params = redirectTokenService.getParameters(ctx, token);
                if (params == null) {
                    ctx.reroute(GET, UNAUTHORIZED_PATH);
                    logger.warn("phase3: Could not decrypt parameters from 't'");
                    return;
                }
                String originalUrl = params.get("u");
                String originalHost = getUriHostName(originalUrl);
                String host = getUriHostName(ctx.request().host());
                if (originalHost != null && originalHost.equals(host)) {
                    ctx.response().setStatusCode(TEMPORARY_REDIRECT.code()).putHeader(LOCATION, originalUrl)
                            .putHeader(SET_COOKIE, params.get("c")).end();
                } else {
                    logger.warn("phase3: original host from cookie " + originalHost
                            + " does not match request host " + host);
                    ctx.reroute(GET, FORBIDDEN_PATH);
                }
            });
        }

        router.route(path).handler(securityHandler);

        if (virtualHost) {
            router.options(PROXY_AUTH_REDIRECT_BEFORE).handler(SetupAzureAdConnectAuth::optionsHandler);
            router.route(PROXY_AUTH_REDIRECT_BEFORE).handler(securityHandler);
            router.route(PROXY_AUTH_REDIRECT_BEFORE).handler(ctx -> {
                // phase 2: executed when returning from authentication server with valid cookie
                // -> jump to original virtual host domain and pass the original url and auth cookie inside token
                String token = ctx.request().getParam("t");
                Map<String, String> params = redirectTokenService.getParameters(ctx, token);
                if (params == null) {
                    ctx.reroute(GET, UNAUTHORIZED_PATH);
                    logger.warn("phase2: Could not decrypt parameters from 't'");
                    return;
                }
                String originalUrl = params.get("u");
                String originalHost = getUriHostNamePort(originalUrl);
                if (originalUrl == null || !originalUrl.startsWith("https://")) {
                    ctx.reroute(GET, FORBIDDEN_PATH);
                    logger.warn(
                            "phase2: original url from cookie " + originalUrl + " does not start with https://");
                    return;
                }
                Cookie cookie = sessionHandler.getAuthCookie(ctx.cookies());
                params.put("c", cookie.encode());
                String newToken = redirectTokenService.createToken(ctx, params);
                StringBuilder sb = new StringBuilder();
                sb.append("https://").append(originalHost).append(PROXY_AUTH_REDIRECT_AFTER).append("?t=")
                        .append(urlEncode(newToken));
                ctx.response().setStatusCode(TEMPORARY_REDIRECT.code()).putHeader(LOCATION, sb).end();
            });
        }
    }

    private static void errorWithLogoutLink(RoutingContext ctx, HttpResponseStatus status) {
        String html = "<html><head><title>" + status.reasonPhrase() + "</title></head>" + "<body><h1>"
                + status.reasonPhrase() + "</h1><p>Try <a href=\"/proxyLogout\">clearing cookies</a></p></body>";
        ctx.response().setStatusCode(status.code()).putHeader(CONTENT_TYPE, "text/html; charset=utf-8").end(html);
    }

    private static void optionsHandler(RoutingContext ctx) {
        ctx.response().putHeader(ALLOW, "GET, OPTIONS").putHeader(ACCESS_CONTROL_ALLOW_METHODS, "GET, OPTIONS")
                .putHeader(ACCESS_CONTROL_ALLOW_ORIGIN, "*").setStatusCode(200).end();
    }

    private static Handler<RoutingContext> authHandler(JsonObject adAuth, Set<String> forbiddenHeaders,
            HashMap<String, Pattern> requiredHeaderMatchers, String publicURI, boolean virtualHosting,
            CookieSessionHandler sessionHandler, String redirectUri, RedirectTokenService redirectTokenService) {
        String publicHost = getUriHostName(publicURI);
        return ctx -> {
            Optional<Map<String, String>> headers = ofNullable(sessionHandler.getSessionData(ctx));
            if (headers.isPresent()) {
                MultiMap h = ctx.request().headers();
                forbiddenHeaders.forEach(h::remove);
                headers.get().entrySet().stream().filter(e -> !e.getKey().startsWith(SECRET_DATA_PREFIX))
                        .forEach(e -> h.set(e.getKey(), e.getValue()));

                if (!requiredHeaderMatchers.entrySet().stream()
                        .allMatch(e -> headerMatches(h.get(e.getKey()), e.getValue()))) {
                    logger.info("Not authorised to view resource '" + ctx.request().path() + "' with session data: "
                            + headers.get());

                    ctx.reroute(GET, FORBIDDEN_PATH);
                    return;
                }

                ctx.next();
                return;
            }

            String publicURIWithoutProtocol = getUriHostName(publicURI);

            String host = getUriHostName(ctx.request().host());
            if (virtualHosting && !publicURIWithoutProtocol.equals(host)) {
                // phase 1: executed iff authentication cookie is missing && the browser is not on the auth domain but on a virtual domain
                // -> jump to auth domain and pass the current url inside token
                String currentUri = forceHttps(replaceHostAndPort(ctx.request().absoluteURI(), host));
                String token = redirectTokenService.createToken(ctx, singletonMap("u", currentUri));
                ctx.response()
                        .setStatusCode((ctx.request().method() == GET ? TEMPORARY_REDIRECT : SEE_OTHER).code()) // ask browser to turn POST etc into GET when redirecting
                        .putHeader(CACHE_CONTROL, "no-cache, no-store, must-revalidate").putHeader(EXPIRES, "0")
                        .putHeader(LOCATION, publicURI + PROXY_AUTH_REDIRECT_BEFORE + "?t=" + urlEncode(token))
                        .end();
                return;
            }

            StringBuilder sb = new StringBuilder();
            String currentUri = forceHttps(replaceHost(ctx.request().absoluteURI(), publicHost));
            sb.append(adAuth.getJsonObject("openIdConfig").getString("authorization_endpoint"))
                    .append("?domain_hint=organizations&response_type=code&response_mode=query")
                    .append("&client_id=").append(urlEncode(adAuth.getString("clientId"))).append("&redirect_uri=")
                    .append(urlEncode(redirectUri)).append("&scope=").append(urlEncode(adAuth.getString("scope")))
                    //.append("&login_hint=").append(urlEncode(previousKnownUserName)) -- could try to fetch it from expired session cookie?
                    //.append("&prompt=").append("login") -- force login - maybe do if IP is from different country?
                    .append("&state=")
                    .append(urlEncode(redirectTokenService.createToken(ctx, singletonMap("a", currentUri))));
            ctx.response().setStatusCode(TEMPORARY_REDIRECT.code()).putHeader(LOCATION, sb)
                    .putHeader(CACHE_CONTROL, "no-cache, no-store, must-revalidate").putHeader(EXPIRES, "0").end();
        };
    }

    private static Handler<RoutingContext> validateAuthCallback(JsonObject adAuth, HttpClient httpClient,
            CookieSessionHandler sessionHandler, List<GraphQuery> graphQueries, String redirectUri,
            RedirectTokenService redirectTokenService) {
        return ctx -> finalizeAuthentication(ctx, adAuth, httpClient, sessionHandler, graphQueries, redirectUri,
                redirectTokenService);
    }

    private static void finalizeAuthentication(RoutingContext ctx, JsonObject adAuth, HttpClient httpClient,
            CookieSessionHandler sessionHandler, List<GraphQuery> graphQueries, String redirectUri,
            RedirectTokenService redirectTokenService) {
        Map<String, String> params = redirectTokenService.getParameters(ctx, ctx.request().getParam("state"));
        if (params == null || params.get("a") == null) {
            logger.error("Missing state parameter");
            ctx.reroute(GET, UNAUTHORIZED_PATH);
            return;
        }
        String originalUrl = params.get("a");
        String code = ctx.request().getParam("code");
        String graphScopes = adAuth.getString("scope");
        Buffer form = Buffer
                .buffer("code=" + urlEncode(code) + "&client_id=" + urlEncode(adAuth.getString("clientId"))
                        + "&scope=" + urlEncode(graphScopes) + "&grant_type=authorization_code" + "&client_secret="
                        + urlEncode(adAuth.getString("clientSecret")) + "&redirect_uri=" + urlEncode(redirectUri));
        String tokenUrl = adAuth.getJsonObject("openIdConfig").getString("token_endpoint");
        logger.debug("Requesting graph access token from " + tokenUrl + " with [ " + form + "]");
        httpClient.postAbs(tokenUrl).putHeader(ACCEPT, APPLICATION_JSON)
                .putHeader(CONTENT_TYPE, APPLICATION_X_WWW_FORM_URLENCODED)
                .putHeader(CONTENT_LENGTH, String.valueOf(form.length())).setTimeout(SECONDS.toMillis(10))
                .exceptionHandler(err -> {
                    logger.error("Failed to fetch graph access token", err);
                    ctx.reroute(GET, UNAUTHORIZED_PATH);
                }).handler(resp -> processGraphTokenResponse(resp, ctx, httpClient, sessionHandler, graphQueries,
                        originalUrl))
                .end(form);
    }

    static void processGraphTokenResponse(HttpClientResponse resp, RoutingContext ctx, HttpClient httpClient,
            CookieSessionHandler sessionHandler, List<GraphQuery> graphQueries, String originalUrl) {
        if (resp.statusCode() != OK.code()) {
            resp.bodyHandler(body -> {
                logger.warn("Failed to fetch graph access token: " + resp.statusMessage() + " - "
                        + resp.getHeader(WWW_AUTHENTICATE) + " ::: " + body);
                ctx.reroute(GET, UNAUTHORIZED_PATH);
            });
            return;
        }
        resp.bodyHandler(body -> {
            JsonObject json = body.toJsonObject();
            String token = json.getString("access_token");
            String refreshToken = json.getString("refresh_token");
            // clean out sensitive stuff
            json.put("access_token", "<censored>");
            json.put("refresh_token", "<censored>");

            logger.debug("Got graph access response: {}", json);
            final AtomicInteger pendingRequests = new AtomicInteger(graphQueries.size());
            final Map<String, String> sessionData = new HashMap<>();
            ofNullable(refreshToken).ifPresent(t -> sessionData.put(GRAPH_ACCESS_TOKEN_KEY, t));
            for (GraphQuery query : graphQueries) {
                String clientRequestId = UUID.randomUUID().toString();
                logger.debug("Requesting " + query.graphQueryURI + "[" + clientRequestId + "]");
                httpClient.getAbs(query.graphQueryURI).putHeader(AUTHORIZATION, "Bearer " + token)
                        .putHeader(ACCEPT, APPLICATION_JSON).putHeader("client-request-id", clientRequestId)
                        .setTimeout(SECONDS.toMillis(10)).exceptionHandler(err -> {
                            if (pendingRequests.getAndSet(-1) != -1) {
                                logger.error("Failed to fetch user information [" + clientRequestId + "]", err);
                                ctx.reroute(GET, UNAUTHORIZED_PATH);
                            }
                        }).handler(r -> processMicrosoftUserInformation(r, ctx, sessionHandler,
                                query.headerMappings, originalUrl, pendingRequests, sessionData, clientRequestId))
                        .end();
            }
        });
    }

    static void processMicrosoftUserInformation(HttpClientResponse resp, RoutingContext ctx,
            CookieSessionHandler sessionHandler, Map<String, String> headerMappings, String originalUrl,
            AtomicInteger pendingRequests, Map<String, String> sessionData, String clientRequestId) {
        if (resp.statusCode() != OK.code()) {
            if (pendingRequests.getAndSet(-1) != -1) {
                logger.warn("Failed to fetch graph information [" + clientRequestId + "]: " + resp.statusMessage()
                        + " - " + resp.getHeader(WWW_AUTHENTICATE));
                ctx.reroute(GET, UNAUTHORIZED_PATH);
                return;
            }
        }

        resp.bodyHandler(body -> {
            JsonObject response = body.toJsonObject();
            logger.debug("Got graph response [" + clientRequestId + "]: {}", response);
            synchronized (sessionData) {
                headerMappings.forEach((header, pointer) -> ofNullable(JsonPointer.fetch(response, pointer))
                        .ifPresent(val -> sessionData.put(header, val)));
            }
            if (pendingRequests.decrementAndGet() == 0) {
                sessionHandler.setSessionData(ctx, sessionData);
                ctx.response().setStatusCode(TEMPORARY_REDIRECT.code())
                        .putHeader(CACHE_CONTROL, "no-cache, no-store, must-revalidate").putHeader(EXPIRES, "0")
                        .putHeader(LOCATION, originalUrl).end();
            }
        });
    }

    static boolean headerMatches(String header, Pattern pattern) {
        return header != null && pattern.matcher(header).matches();
    }

    public static class GraphQuery {
        public String graphQueryURI;
        public final Map<String, String> headerMappings = new HashMap<>();

        public GraphQuery(JsonObject object) {
            this(object, null);
        }

        public GraphQuery(JsonObject object, String defaultQuery) {
            graphQueryURI = object.getString("graphQueryURI", defaultQuery);
            object.getJsonObject("headerMappings", new JsonObject())
                    .forEach(mapping -> headerMappings.put(mapping.getKey(), mapping.getValue().toString()));
        }
    }
}