io.nitor.api.backend.msgraph.GraphQueryHandler.java Source code

Java tutorial

Introduction

Here is the source code for io.nitor.api.backend.msgraph.GraphQueryHandler.java

Source

/**
 * Copyright 2018 Nitor Creations Oy
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.nitor.api.backend.msgraph;

import io.nitor.api.backend.msgraph.GraphSessionTokenService.TokenData;
import io.nitor.api.backend.session.CookieSessionHandler;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.json.JsonObject;
import io.vertx.core.streams.Pump;
import io.vertx.ext.web.RoutingContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT_ENCODING;
import static io.netty.handler.codec.http.HttpHeaderNames.AUTHORIZATION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_GATEWAY;
import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.nitor.api.backend.auth.SetupAzureAdConnectAuth.SECRET_DATA_PREFIX;
import static java.util.Optional.ofNullable;
import static java.util.concurrent.TimeUnit.SECONDS;

/**
 * Route handler that forwards an incoming HTTP request to a Graph API endpoint.
 * <p>
 * For each request the handler:
 * <ol>
 * <li>strips the first {@code routeLength} characters from the request path and appends the rest
 * (plus the original query string) to the configured base URL,</li>
 * <li>asks {@link GraphSessionTokenService} for an access token using the refresh token stored in
 * the session data under {@link #GRAPH_ACCESS_TOKEN_KEY},</li>
 * <li>stores a changed refresh token back into the session data,</li>
 * <li>streams the request body upstream and the upstream response back to the caller.</li>
 * </ol>
 */
public class GraphQueryHandler implements Handler<RoutingContext> {
    private static final Logger logger = LogManager.getLogger(GraphQueryHandler.class);

    /** Incoming request headers that are copied as-is onto the upstream request. */
    private static final String[] allowedRequestHeaders = new String[] { "Prefer", CONTENT_LENGTH.toString(),
            CONTENT_TYPE.toString(), ACCEPT_ENCODING.toString() };

    /** Upstream response headers that are not copied to the response sent to the caller. */
    private static final String[] droppedResponseHeaders = new String[] { "server", "accept-ranges",
            "transfer-encoding", "date", "connection" };

    /** Session data key under which the refresh token is stored. */
    public static final String GRAPH_ACCESS_TOKEN_KEY = SECRET_DATA_PREFIX + "GRT";

    private final String baseUrl;
    private final GraphSessionTokenService tokenCache;
    private final int routeLength;
    private final CookieSessionHandler sessionHandler;
    private final HttpClient httpClient;

    /**
     * Creates the handler.
     *
     * @param conf           handler configuration; the optional {@code target} entry overrides the
     *                       default base URL {@code https://graph.microsoft.com/beta}
     * @param routeLength    number of leading characters of the request path to strip before the
     *                       remainder is appended to the base URL
     * @param adAuth         authentication configuration; its {@code scope} entry must contain
     *                       {@code offline_access}
     * @param sessionHandler session handler used to read and update the session data
     * @param httpClient     client used for the token service and for the upstream requests
     * @throws IllegalArgumentException if the {@code scope} entry is missing or lacks
     *                                  {@code offline_access}
     */
    public GraphQueryHandler(JsonObject conf, int routeLength, JsonObject adAuth,
            CookieSessionHandler sessionHandler, HttpClient httpClient) {
        // Validate first so a misconfiguration fails before anything else is constructed,
        // and so a missing scope yields the descriptive exception instead of an NPE.
        String scope = adAuth.getString("scope");
        if (scope == null || !scope.contains("offline_access")) {
            throw new IllegalArgumentException("auth scope must contain \"offline_access\"");
        }
        this.routeLength = routeLength;
        this.sessionHandler = sessionHandler;
        this.httpClient = httpClient;
        String target = conf.getString("target", "https://graph.microsoft.com/beta");
        if (target.endsWith("/")) {
            target = target.substring(0, target.length() - 1);
        }
        this.baseUrl = target;
        this.tokenCache = new GraphSessionTokenService(httpClient, adAuth);
    }

    /**
     * Resolves an access token for the current session and, once available, proxies the request
     * upstream. If the token cannot be obtained the session cookie is removed and a 500 response
     * carrying the failure text is sent.
     *
     * @param ctx routing context of the incoming request
     */
    @Override
    public void handle(RoutingContext ctx) {
        HttpServerRequest sreq = ctx.request();
        HttpServerResponse sres = ctx.response();
        String targetUrl = targetUrlOf(sreq);

        // NOTE(review): assumes getSessionData never returns null for routes using this handler - confirm.
        Map<String, String> data = sessionHandler.getSessionData(ctx);
        String refreshToken = data.get(GRAPH_ACCESS_TOKEN_KEY);
        Future<TokenData> tokenFuture = tokenCache.getAccessToken(refreshToken);

        // The token may be resolved asynchronously; hold back the request body until the
        // upstream request exists so no body data or end event is lost in the meantime.
        sreq.pause();
        tokenFuture.setHandler(tokenResult -> {
            try {
                if (tokenResult.failed()) {
                    sessionHandler.removeCookie(ctx);
                    String err = tokenResult.cause().toString();
                    logger.error(err, tokenResult.cause());
                    sres.setStatusCode(INTERNAL_SERVER_ERROR.code()).end(err);
                    return;
                }
                TokenData token = tokenResult.result();
                if (refreshTokenChanged(refreshToken, token.refreshToken)) {
                    Map<String, String> newData = new HashMap<>(data);
                    newData.put(GRAPH_ACCESS_TOKEN_KEY, token.refreshToken);
                    sessionHandler.setSessionData(ctx, newData);
                }
                forwardRequest(sreq, sres, targetUrl, token);
            } finally {
                // Always resume: either the pump is attached now, or the remaining body is discarded.
                sreq.resume();
            }
        });
    }

    /** Null-safe check whether the token service handed back a different refresh token. */
    private static boolean refreshTokenChanged(String previous, String current) {
        return previous == null ? current != null : !previous.equals(current);
    }

    /** Builds the absolute upstream URL: base URL + path after the route prefix + original query string. */
    private String targetUrlOf(HttpServerRequest sreq) {
        String path = sreq.path().substring(routeLength);
        if (!path.startsWith("/")) {
            path = '/' + path;
        }
        return baseUrl + path + paramsOf(sreq.absoluteURI());
    }

    /** Creates the upstream request, wires up error/close handling and streams the request body to it. */
    private void forwardRequest(HttpServerRequest sreq, HttpServerResponse sres, String targetUrl,
            TokenData token) {
        String clientRequestId = UUID.randomUUID().toString();
        logger.info("Querying " + sreq.method() + " " + targetUrl + " [" + clientRequestId + "]");
        HttpClientRequest creq = httpClient.requestAbs(sreq.method(), targetUrl)
                .putHeader(AUTHORIZATION, "Bearer " + token.accessToken).putHeader(ACCEPT, APPLICATION_JSON)
                .putHeader("client-request-id", clientRequestId).setTimeout(SECONDS.toMillis(20))
                .exceptionHandler(err -> {
                    logger.error("Graph query failed [" + clientRequestId + "]", err);
                    if (!sres.ended()) {
                        // end(String) sets Content-Length itself; a bare write() would be rejected
                        // because neither chunked mode nor a length has been set.
                        sres.setStatusCode(INTERNAL_SERVER_ERROR.code()).end("Graph query failed: " + err);
                    }
                });

        for (String header : allowedRequestHeaders) {
            ofNullable(sreq.getHeader(header)).ifPresent(value -> creq.putHeader(header, value));
        }
        // Mirror the transfer encoding of the INCOMING request so a chunked upload can be
        // streamed upstream without a Content-Length header.
        if (sreq.headers().getAll("transfer-encoding").stream().anyMatch("chunked"::equalsIgnoreCase)) {
            creq.setChunked(true);
        }

        sres.closeHandler(close -> {
            // The upstream connection may not be established yet when the caller disconnects.
            if (creq.connection() != null) {
                creq.connection().close();
            }
        });
        creq.handler(cres -> mapResponse(cres, sres, clientRequestId));

        if (sreq.isEnded()) {
            creq.end();
        } else {
            sreq.endHandler(v -> {
                try {
                    creq.end();
                } catch (IllegalStateException ignored) {
                    // Nothing can be done - the upstream request is already complete/closed.
                    logger.debug("Graph request already completed [" + clientRequestId + "]", ignored);
                }
            });
            Pump.pump(sreq, creq).start();
        }
    }

    /**
     * Returns the query part of the given URI including the leading {@code ?}, or an empty string
     * when the URI has none.
     */
    private String paramsOf(String absoluteURI) {
        int idx = absoluteURI.indexOf('?');
        return idx > 0 ? absoluteURI.substring(idx) : "";
    }

    /**
     * Copies status, headers and body of the upstream response to the response sent to the caller.
     *
     * @param cres            upstream response
     * @param sres            response to the original caller
     * @param clientRequestId id sent upstream as {@code client-request-id}, used in log messages
     */
    private void mapResponse(HttpClientResponse cres, HttpServerResponse sres, String clientRequestId) {
        cres.exceptionHandler(t -> {
            logger.error("Error processing graph request [" + clientRequestId + "]", t);
            if (!sres.ended()) {
                sres.setStatusCode(BAD_GATEWAY.code());
                sres.end();
            }
        });

        // TODO Together with the client-request-id always log the request-id, timestamp and x-ms-ags-diagnostic from the HTTP response headers

        sres.setStatusCode(cres.statusCode());
        sres.setStatusMessage(cres.statusMessage());

        MultiMap headers = sres.headers();
        cres.headers().forEach(entry -> {
            if (!isDroppedResponseHeader(entry.getKey())) {
                headers.add(entry.getKey(), entry.getValue());
            }
        });

        // Without a known length the body has to be streamed in chunked mode.
        if (!headers.contains("content-length")) {
            sres.setChunked(true);
        }

        Pump resPump = Pump.pump(cres, sres);
        cres.endHandler(v -> {
            if (!sres.ended()) {
                sres.end();
            }
        });
        resPump.start();
    }

    /** Case-insensitive (and locale-independent) membership test against {@link #droppedResponseHeaders}. */
    private static boolean isDroppedResponseHeader(String name) {
        for (String dropped : droppedResponseHeaders) {
            if (dropped.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

}