io.nitor.api.backend.NitorBackend.java Source code

Java tutorial

Introduction

Here is the source code for io.nitor.api.backend.NitorBackend.java

Source

/**
 * Copyright 2016-2019 Nitor Creations Oy, Jonas Berlin
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.nitor.api.backend;

import io.nitor.api.backend.auth.SetupAzureAdConnectAuth;
import io.nitor.api.backend.auth.SimpleConfigAuthProvider;
import io.nitor.api.backend.cache.CacheHandler;
import io.nitor.api.backend.js.InlineJS;
import io.nitor.api.backend.lambda.LambdaHandler;
import io.nitor.api.backend.msgraph.GraphQueryHandler;
import io.nitor.api.backend.proxy.Proxy.ProxyException;
import io.nitor.api.backend.proxy.SetupProxy;
import io.nitor.api.backend.routing.ServiceRouterBuilder;
import io.nitor.api.backend.s3.S3Handler;
import io.nitor.api.backend.session.CookieSessionHandler;
import io.nitor.api.backend.tls.SetupHttpServerOptions;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpClientOptions;
import io.vertx.core.http.HttpServerOptions;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.AuthHandler;
import io.vertx.ext.web.handler.BasicAuthHandler;
import io.vertx.ext.web.handler.CookieHandler;
import io.vertx.ext.web.handler.StaticHandler;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import javax.net.ssl.SSLPeerUnverifiedException;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import static com.nitorcreations.core.utils.KillProcess.killProcessUsingPort;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.nitor.api.backend.auth.SetupAzureAdConnectAuth.SECRET_DATA_PREFIX;
import static io.nitor.api.backend.session.CookieConverter.secureCookie;
import static io.nitor.api.backend.util.Helpers.REMOTE_ADDRESS;
import static io.nitor.api.backend.util.Helpers.getUriHostName;
import static io.nitor.api.backend.util.Helpers.parseForwardedHeaders;
import static io.nitor.api.backend.util.Helpers.toBytes;
import static io.vertx.core.buffer.Buffer.buffer;
import static io.vertx.core.http.HttpHeaders.CACHE_CONTROL;
import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
import static io.vertx.core.http.HttpHeaders.EXPIRES;
import static io.vertx.core.http.HttpMethod.GET;
import static java.lang.Integer.getInteger;
import static java.lang.System.exit;
import static java.lang.System.getProperty;
import static java.lang.System.getenv;
import static java.lang.System.setProperty;
import static java.nio.file.Files.exists;
import static java.util.Locale.ROOT;
import static java.util.concurrent.TimeUnit.DAYS;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.concurrent.TimeUnit.SECONDS;

/**
 * Vert.x verticle that builds an HTTP router from the verticle configuration and serves it.
 * <p>
 * {@link #main(String...)} prepares logging and launches the verticle; {@link #start()} registers
 * the built-in routes (health check, certificate check, session/auth helpers) and the services
 * listed in the configuration (proxy, static, s3, lambda, graph, cache, virtualHost).
 */
public class NitorBackend extends AbstractVerticle {
    /** Port to listen on; read from the {@code port} system property, defaults to 8443. */
    private static final int listenPort = getInteger("port", 8443);
    /** Address to bind to; read from the {@code host} system property, defaults to 0.0.0.0. */
    private static final String listenHost = getProperty("host", "0.0.0.0");
    /**
     * Assigned by {@code setupLogging()} after the logging-related system properties are set,
     * which is why it is not a final field initialized at class-load time.
     * NOTE(review): stays null if the verticle is started without going through main() - confirm
     * that no such deployment path exists.
     */
    private static Logger logger;

    /**
     * Process entry point: installs the inherited-channel selector provider, configures logging,
     * frees the listen port when no channel was inherited, and then hands over to
     * {@code PropertiesLauncher} with {@code run <this class>} prepended to the given arguments.
     * Any exception thrown while launching is logged as fatal and the JVM exits with status 3.
     *
     * @param args extra command line arguments passed through to the launcher
     */
    public static void main(String... args) {
        String selectorProviderClass = InheritedChannelSelectorProvider.class.getName();
        setProperty("java.nio.channels.spi.SelectorProvider", selectorProviderClass);
        setupLogging();

        boolean channelInherited = InheritedChannelSelectorProvider.hasInheritedChannel();
        if (!channelInherited) {
            killProcessUsingPort(listenPort);
        }

        try {
            // launcher arguments: "run", the verticle class name, then the caller's arguments
            String[] launcherArgs = new String[args.length + 2];
            launcherArgs[0] = "run";
            launcherArgs[1] = NitorBackend.class.getName();
            System.arraycopy(args, 0, launcherArgs, 2, args.length);
            PropertiesLauncher.main(launcherArgs);
        } catch (Exception ex) {
            logger.fatal("Startup failure", ex);
            exit(3);
        }
    }

    /**
     * Points log4j at a local {@code log4j2.xml} when no configuration file was given through the
     * {@code LOG4J_CONFIGURATION_FILE} environment variable or the {@code log4j.configurationFile}
     * system property, routes java.util.logging and Vert.x logging to log4j2, and only then
     * creates this class's logger.
     */
    private static void setupLogging() {
        boolean configuredExternally = getenv("LOG4J_CONFIGURATION_FILE") != null
                || getProperty("log4j.configurationFile") != null;
        if (!configuredExternally && exists(Paths.get("log4j2.xml"))) {
            setProperty("log4j.configurationFile", "log4j2.xml");
        }

        setProperty("java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager");
        setProperty("vertx.logger-delegate-factory-class-name", "io.vertx.core.logging.Log4j2LogDelegateFactory");

        // must happen after the properties above have been set
        logger = LogManager.getLogger(NitorBackend.class);
    }

    /**
     * Builds the router from the verticle configuration and starts the HTTP server on
     * {@code listenHost:listenPort}.
     * <p>
     * Handlers are registered in this order, which is also the order in which they see a request:
     * access log, common response headers / forwarded-header handling, {@code /healthCheck},
     * {@code /certCheck}, optional client certificate enforcement ({@code clientAuth}), optional
     * virtual host detection ({@code virtualHost}), optional cookie sessions ({@code session},
     * adds {@code /proxyLogout}), optional Azure AD auth ({@code adAuth}), optional basic auth
     * ({@code basicAuth}), {@code /cookieCheck}, optional inline JS hooks ({@code customize}),
     * the configured {@code services}, and finally the failure handler.
     *
     * @throws RuntimeException if the OpenID configuration has to be fetched and that fails, or if
     *                          an unsupported service type is configured
     */
    @Override
    public void start() {
        vertx.exceptionHandler(e -> logger.error("Fallback exception handler got", e));

        HttpServerOptions httpServerOptions = SetupHttpServerOptions.createHttpServerOptions(config());

        Router router = Router.router(vertx);

        HttpClientOptions clientOptions = new HttpClientOptions();
        // connect timeout is expressed in milliseconds
        clientOptions.setConnectTimeout((int) SECONDS.toMillis(5));
        // idle timeout is expressed in seconds by default, so pass 15 rather than 15000
        clientOptions.setIdleTimeout(15);
        clientOptions.setSsl(true);
        HttpClient httpClient = vertx.createHttpClient(clientOptions);

        Map<String, String> injectedResponseHeaders = readInjectedResponseHeaders();

        String publicURI = config().getString("publicURI",
                "http" + (httpServerOptions.isSsl() ? "s" : "") + "://localhost:" + listenPort);
        if (publicURI.endsWith("/")) {
            publicURI = publicURI.substring(0, publicURI.length() - 1);
        }
        publicURI = publicURI.toLowerCase(ROOT);

        // the original request counts as https if we terminate TLS ourselves or the public URI says so
        boolean isOrigReqHttps = httpServerOptions.isSsl() || publicURI.startsWith("https:");
        // by default trust forwarded headers only when something in front of us terminates TLS
        boolean trustPreviousProxy = config().getBoolean("trustPreviousProxy",
                publicURI.startsWith("https:") && !httpServerOptions.isSsl());

        router.route().handler(new AccessLogHandler()::handle);
        router.route().handler(routingContext -> {
            HttpServerResponse resp = routingContext.response();
            if (isOrigReqHttps) {
                resp.putHeader("strict-transport-security", "max-age=31536000; includeSubDomains");
            }
            if (trustPreviousProxy) {
                String origHost = parseForwardedHeaders(routingContext.request().headers());
                if (origHost != null) {
                    routingContext.put(REMOTE_ADDRESS, origHost);
                }
            }
            if (!injectedResponseHeaders.isEmpty()) {
                // applied just before headers are written so that handlers can override the defaults
                routingContext.addHeadersEndHandler(v -> {
                    for (Entry<String, String> header : injectedResponseHeaders.entrySet()) {
                        if (!resp.headers().contains(header.getKey())) {
                            resp.putHeader(header.getKey(), header.getValue());
                        }
                    }
                });
            }
            routingContext.next();
        });

        router.get("/healthCheck").handler(routingContext -> routingContext.response().setStatusCode(200).end());

        router.get("/certCheck").handler(routingContext -> {
            String resp;
            try {
                resp = "Certs: " + Arrays.toString(routingContext.request().peerCertificateChain());
            } catch (SSLPeerUnverifiedException e) {
                resp = "No client certs available:" + e.getMessage();
            }
            routingContext.response().setChunked(true).putHeader(CONTENT_TYPE, "text/plain; charset=utf-8")
                    .write(resp).end();
        });

        JsonObject clientAuth = config().getJsonObject("clientAuth");
        if (clientAuth != null && clientAuth.getString("clientChain") != null) {
            router.route(clientAuth.getString("route", "/*")).handler(routingContext -> {
                try {
                    // throws when the peer presented no verified certificate
                    routingContext.request().peerCertificateChain();
                    routingContext.next();
                } catch (SSLPeerUnverifiedException e) {
                    routingContext.response().setStatusCode(FORBIDDEN.code());
                    routingContext.response().end();
                    logger.info("Rejected request that was missing valid client certificate from ip {}: {}",
                            routingContext.request().remoteAddress(), e.getMessage());
                }
            });
        }

        boolean virtualHost = config().getBoolean("virtualHost", false);
        if (virtualHost) {
            router.route().handler(ctx -> {
                ctx.put("host", getUriHostName(ctx.request().host()));
                ctx.next();
            });
        }

        JsonObject sessionConf = config().getJsonObject("session");
        CookieSessionHandler sessionHandler = sessionConf != null ? new CookieSessionHandler(sessionConf) : null;
        if (sessionHandler != null) {
            router.route().handler(CookieHandler.create());

            router.get("/proxyLogout").handler(routingContext -> {
                routingContext.cookies()
                        .forEach(cookie -> secureCookie(cookie, (int) DAYS.toSeconds(30)).setValue(""));
                routingContext.response().putHeader(CACHE_CONTROL, "no-cache, no-store, must-revalidate")
                        .putHeader(EXPIRES, "0").putHeader(CONTENT_TYPE, "text/plain; charset=utf-8")
                        .end("Logged out", "UTF-8");
            });
        }

        JsonObject adAuth = config().getJsonObject("adAuth");
        if (adAuth != null) {
            JsonObject openIdConfig = adAuth.getJsonObject("openIdConfig");
            if (openIdConfig == null || !openIdConfig.containsKey("authorization_endpoint")
                    || !openIdConfig.containsKey("token_endpoint")) {
                openIdConfig = fetchOpenIdConfig(adAuth.getString("configurationURI"));
                logger.info(
                        "To speed up startup please define \"adAuth\": {\"openIdConfig\": {\"authorization_endpoint\": \""
                                + openIdConfig.getString("authorization_endpoint") + "\", \"token_endpoint\": \""
                                + openIdConfig.getString("token_endpoint") + "\" } }");
            }
            adAuth.put("openIdConfig", openIdConfig);
            SetupAzureAdConnectAuth.setupAzureAd(adAuth, router, publicURI, virtualHost, sessionHandler,
                    httpClient);
        }

        JsonObject basicAuth = config().getJsonObject("basicAuth");
        if (basicAuth != null) {
            AuthHandler basicAuthHandler = BasicAuthHandler.create(
                    new SimpleConfigAuthProvider(basicAuth.getJsonObject("users")),
                    basicAuth.getString("realm", "nitor"));
            router.route(basicAuth.getString("route", "/*")).handler(basicAuthHandler);
        }

        if (sessionHandler != null) {
            router.get("/cookieCheck").handler(routingContext -> {
                Map<String, String> headers = sessionHandler.getSessionData(routingContext);
                StringBuilder sb = new StringBuilder(2048);
                if (headers == null) {
                    sb.append("No valid session");
                } else {
                    headers.forEach((key, value) -> {
                        sb.append(key).append('=');
                        // never echo values stored under the secret prefix
                        if (key.startsWith(SECRET_DATA_PREFIX)) {
                            sb.append("<secret>");
                        } else {
                            sb.append(value);
                        }
                        sb.append('\n');
                    });
                }
                routingContext.response().putHeader(CONTENT_TYPE, "text/plain; charset=utf-8").end(sb.toString());
            });
        }

        JsonArray customizeConf = config().getJsonArray("customize");
        if (customizeConf != null) {
            customizeConf.forEach(c -> {
                JsonObject conf = (JsonObject) c;
                InlineJS inlineJs = new InlineJS(vertx, conf.getString("jsFile", "custom.js"));
                router.route(conf.getString("route")).handler(ctx -> {
                    inlineJs.call("handleRequest", ctx.request(), ctx);
                    ctx.addHeadersEndHandler((v) -> inlineJs.call("handleResponse", ctx.response(), ctx));
                    ctx.next();
                });
            });
        }

        setupServices(config(), httpServerOptions, router, new ServiceRouterBuilder(), httpClient, sessionHandler,
                adAuth, isOrigReqHttps);

        router.route().failureHandler(routingContext -> {
            String error = "ERROR";
            int statusCode = routingContext.statusCode();
            Throwable t = routingContext.failure();
            logger.info("Handling failure statusCode=" + statusCode, t);
            HttpServerResponse resp = routingContext.response();
            if (resp.ended()) {
                return;
            }
            if (resp.headWritten()) {
                // too late to send an error status; finish the response and drop the connection
                resp.end();
                routingContext.request().connection().close();
                return;
            }
            if (t != null) {
                if (t instanceof ProxyException) {
                    statusCode = ((ProxyException) t).statusCode;
                }
                error = "ERROR: " + t.toString();
            }
            resp.setStatusCode(statusCode != -1 ? statusCode : INTERNAL_SERVER_ERROR.code());
            // Content-Length must count encoded bytes, not chars: the message may contain non-ASCII text
            byte[] body = error.getBytes(java.nio.charset.StandardCharsets.UTF_8);
            resp.headers().set("Content-Type", "text/plain; charset=UTF-8");
            resp.headers().set("Content-Length", Integer.toString(body.length));
            resp.end(buffer(body));
        });

        vertx.createHttpServer(httpServerOptions).requestHandler(router).listen(listenPort, listenHost);
    }

    /**
     * Reads the optional {@code defaultHeaders} configuration object into a map of lower-cased
     * header name to header value. A missing object yields an empty map; entries whose value is
     * JSON null are skipped.
     */
    private Map<String, String> readInjectedResponseHeaders() {
        Map<String, String> headers = new HashMap<>();
        JsonObject defaultHeaders = config().getJsonObject("defaultHeaders");
        if (defaultHeaders == null) {
            return headers;
        }
        for (Entry<String, Object> defaultHeader : defaultHeaders) {
            Object value = defaultHeader.getValue();
            if (value != null) {
                // header names are case-insensitive; lower-case them independently of the default locale
                headers.put(defaultHeader.getKey().toLowerCase(ROOT), value.toString());
            }
        }
        return headers;
    }

    /**
     * Downloads and parses the OpenID configuration document at {@code configURI}. This is a
     * blocking call. The stream is closed after reading.
     *
     * @throws RuntimeException wrapping the original cause if the document cannot be fetched or parsed
     */
    private static JsonObject fetchOpenIdConfig(String configURI) {
        try {
            logger.info("Fetching configuration from " + configURI);
            URL url = URI.create(configURI).toURL();
            try (java.io.InputStream in = url.openStream()) {
                return new JsonObject(buffer(toBytes(in)));
            }
        } catch (Exception e) {
            RuntimeException ex = new RuntimeException("Failed to fetch open id config from " + configURI,
                    e);
            logger.fatal("adAuth config failure", ex);
            throw ex;
        }
    }

    /**
     * Registers every entry of the {@code services} array found in {@code configRoot} on
     * {@code routeBuilder}, dispatching on the entry's {@code type}. A {@code virtualHost} entry
     * recurses into its own {@code services} array using a builder scoped to that host and no
     * router. When {@code router} is non-null the collected handlers are registered on it at the end.
     *
     * @param configRoot configuration object holding the optional {@code services} array
     * @param router     router to register the collected handlers on, or null for nested calls
     * @throws RuntimeException if an entry has an unsupported or missing {@code type}
     */
    private void setupServices(JsonObject configRoot, HttpServerOptions httpServerOptions, Router router,
            ServiceRouterBuilder routeBuilder, HttpClient httpClient, CookieSessionHandler sessionHandler,
            JsonObject adAuth, boolean isOrigReqHttps) {
        for (Object entry : configRoot.getJsonArray("services", new JsonArray())) {
            JsonObject service = (JsonObject) entry;
            String type = service.getString("type", "<missing>");
            StringBuilder logMsg = new StringBuilder("Setting up service '").append(type)
                    .append("' on route '").append(service.getString("route")).append("'");

            switch (type) {
                case "virtualHost": {
                    String host = service.getString("host");
                    logMsg.append(" for virtual host '").append(host).append("'");
                    // nested services attach to the host-specific builder; only the outermost call owns the router
                    setupServices(service, httpServerOptions, null, routeBuilder.virtualHostHandler(host),
                            httpClient, sessionHandler, adAuth, isOrigReqHttps);
                    break;
                }
                case "proxy":
                    setupProxy(service, routeBuilder, httpServerOptions, isOrigReqHttps);
                    break;
                case "static":
                    setupStaticFiles(service, routeBuilder);
                    break;
                case "cache":
                    setupCache(service, routeBuilder);
                    break;
                case "s3":
                    setupS3(service, routeBuilder);
                    break;
                case "lambda":
                    setupLambda(service, routeBuilder);
                    break;
                case "graph":
                    setupGraph(service, routeBuilder, adAuth, sessionHandler, httpClient);
                    break;
                default: {
                    RuntimeException unsupported = new RuntimeException("No support for service '" + type + "'");
                    logger.fatal("service config failure", unsupported);
                    throw unsupported;
                }
            }
            // logged only once the service was set up without throwing
            logger.info(logMsg.toString());
        }

        if (router != null) {
            routeBuilder.registerHandlers(router);
        }
    }

    /**
     * Routes the service's {@code route} to a {@code GraphQueryHandler}, which is given the length
     * of the route with its trailing wildcard/slash removed.
     */
    private void setupGraph(JsonObject service, ServiceRouterBuilder routerBuilder, JsonObject adAuth,
            CookieSessionHandler sessionHandler, HttpClient httpClient) {
        final String route = service.getString("route");
        final int prefixLength = cleanRoute(route).length();
        GraphQueryHandler handler = new GraphQueryHandler(service, prefixLength, adAuth, sessionHandler,
                httpClient);
        routerBuilder.route(route, handler, null);
    }

    /**
     * Routes the service's {@code route} to an {@code S3Handler}, which is given the length of the
     * route with its trailing wildcard/slash removed.
     */
    private void setupS3(JsonObject service, ServiceRouterBuilder routerBuilder) {
        final String route = service.getString("route");
        final int prefixLength = cleanRoute(route).length();
        S3Handler handler = new S3Handler(vertx, service, prefixLength);
        routerBuilder.route(route, handler, null);
    }

    /**
     * Routes the service's {@code route} to a {@code LambdaHandler}, which is given the length of
     * the route with its trailing wildcard/slash removed.
     */
    private void setupLambda(JsonObject service, ServiceRouterBuilder routerBuilder) {
        final String route = service.getString("route");
        final int prefixLength = cleanRoute(route).length();
        LambdaHandler handler = new LambdaHandler(service, prefixLength);
        routerBuilder.route(route, handler, null);
    }

    /**
     * Routes the service's {@code route} to the handler built by a {@code CacheHandler} that is
     * created from the service configuration and the Vert.x file system.
     */
    private void setupCache(JsonObject service, ServiceRouterBuilder routerBuilder) {
        routerBuilder.route(
                service.getString("route"),
                new CacheHandler(service, vertx.fileSystem()).build(),
                null);
    }

    /**
     * Strips at most one trailing {@code *} and then at most one trailing {@code /} from a route
     * pattern, e.g. {@code /api/*} becomes {@code /api} and {@code /*} becomes the empty string.
     *
     * @param routePrefix route pattern as configured; must not be null
     * @return the route without its trailing wildcard and slash
     */
    private String cleanRoute(String routePrefix) {
        int end = routePrefix.length();
        if (end > 0 && routePrefix.charAt(end - 1) == '*') {
            end--;
        }
        if (end > 0 && routePrefix.charAt(end - 1) == '/') {
            end--;
        }
        return routePrefix.substring(0, end);
    }

    /**
     * Serves files from the service's {@code dir} (default {@code .}) on GET requests to its
     * {@code route}.
     * <p>
     * When {@code staticPaths} (a regular expression) is configured, a second GET handler is added
     * on the same route: requests whose normalised path, minus the leading slash, does not match
     * the expression are rerouted to {@code <route>/index.html}; all others are passed on.
     *
     * @param service       service configuration; reads {@code route}, {@code dir},
     *                      {@code readOnly}, {@code cacheTimeout} and {@code staticPaths}
     * @param routerBuilder builder the handlers are registered on
     */
    private void setupStaticFiles(JsonObject service, ServiceRouterBuilder routerBuilder) {
        // NOTE(review): the default is computed in seconds, but Vert.x documents
        // setCacheEntryTimeout as taking milliseconds - confirm the intended unit of "cacheTimeout".
        int cacheTimeout = service.getInteger("cacheTimeout", (int) MINUTES.toSeconds(30));
        String routePrefix = service.getString("route");
        routerBuilder.route(GET, routePrefix,
                StaticHandler.create().setFilesReadOnly(service.getBoolean("readOnly", true))
                        .setAllowRootFileSystemAccess(true).setWebRoot(service.getString("dir", "."))
                        .setCachingEnabled(cacheTimeout > 0).setCacheEntryTimeout(cacheTimeout),
                null);

        String staticPathConfig = service.getString("staticPaths");
        if (staticPathConfig == null) {
            return;
        }
        Pattern staticPaths = Pattern.compile(staticPathConfig);
        String fallbackPath = cleanRoute(routePrefix) + "/index.html";
        routerBuilder.route(GET, routePrefix, ctx -> {
            String path = ctx.normalisedPath();
            boolean isStaticPath = staticPaths.matcher(path.substring(1)).matches();
            // Never reroute a request that is already for the fallback page: if that path does not
            // match staticPaths, rerouting it to itself would repeat forever.
            if (isStaticPath || path.equals(fallbackPath)) {
                ctx.next();
            } else {
                ctx.reroute(fallbackPath);
            }
        }, null);
    }

    /**
     * Registers a proxy service by delegating to {@code SetupProxy.setupProxy}, passing along this
     * verticle's Vert.x instance together with the given arguments.
     *
     * @param service           configuration of the proxy service
     * @param routerBuilder     builder the proxy route is registered on
     * @param httpServerOptions options of the listening HTTP server
     * @param isOrigReqHttps    whether the original client request is considered to be https
     */
    private void setupProxy(JsonObject service, ServiceRouterBuilder routerBuilder,
            HttpServerOptions httpServerOptions, boolean isOrigReqHttps) {
        SetupProxy.setupProxy(vertx, routerBuilder, service, httpServerOptions, isOrigReqHttps);
    }
}