io.nitor.api.backend.lambda.LambdaHandler.java Source code

Java tutorial

Introduction

Here is the source code for io.nitor.api.backend.lambda.LambdaHandler.java.

Source

/**
 * Copyright 2017-2019 Nitor Creations Oy
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.nitor.api.backend.lambda;

import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent.ProxyRequestContext;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent.RequestIdentity;
import io.undertow.util.PathTemplate;
import io.undertow.util.PathTemplateMatcher;
import io.undertow.util.PathTemplateMatcher.PathMatchResult;
import io.vertx.codegen.annotations.Nullable;
import io.vertx.core.Handler;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvocationType;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CodingErrorAction;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpHeaderNames.USER_AGENT;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_GATEWAY;
import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
import static io.nitor.api.backend.cache.CacheHelpers.tryToCacheContent;
import static io.nitor.api.backend.util.Helpers.getRemoteAddress;
import static io.nitor.api.backend.util.Helpers.resolveCredentialsProvider;
import static io.nitor.api.backend.util.Helpers.resolveRegion;
import static java.lang.Boolean.TRUE;
import static java.lang.String.join;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

/**
 * Vert.x handler that forwards HTTP requests to AWS Lambda functions.
 *
 * <p>Each request is translated into an {@link APIGatewayProxyRequestEvent}. The function's
 * proxy-style response ({@code statusCode}, {@code headers}, {@code body},
 * {@code isBase64Encoded}) is translated back into the HTTP response.
 */
public class LambdaHandler implements Handler<RoutingContext> {
    private static final Logger logger = LogManager.getLogger(LambdaHandler.class);
    /** Captures the (optionally double-quoted) value of a {@code charset=} parameter. */
    private static final Pattern charsetPattern = Pattern.compile("(?i)\\bcharset=\\s*\"?([^\\s;\"]*)");

    private final LambdaAsyncClient lambdaCl;
    private final int routeLength;
    /** Maps path templates to (function name, qualifier) pairs. */
    private final PathTemplateMatcher<Entry<String, String>> pathTemplateMatcher;

    /**
     * Creates the handler and its Lambda client.
     *
     * @param conf        handler configuration. {@code paths} is an array of objects, each with
     *                    {@code template} (an Undertow path template) and {@code function}
     *                    (the Lambda function to call). {@code qualifier} may be set per path
     *                    or at the top level, and defaults to {@code $LATEST}. Region and
     *                    credentials are resolved from {@code conf} by the project helpers.
     * @param routeLength number of leading request-path characters that belong to the mounted
     *                    route; they are stripped before template matching
     */
    public LambdaHandler(JsonObject conf, int routeLength) {
        this.routeLength = routeLength;

        Region region = resolveRegion(conf);
        lambdaCl = LambdaAsyncClient.builder().region(region).credentialsProvider(resolveCredentialsProvider(conf))
                .build();
        pathTemplateMatcher = new PathTemplateMatcher<>();
        String defaultQualifier = conf.getString("qualifier", "$LATEST");
        for (Object next : conf.getJsonArray("paths")) {
            if (next instanceof JsonObject) {
                JsonObject nextObj = (JsonObject) next;
                String lambdaFunction = nextObj.getString("function");
                // A qualifier on the path entry overrides the handler-wide one.
                String qualifier = nextObj.getString("qualifier", defaultQualifier);
                Entry<String, String> value = new SimpleImmutableEntry<>(lambdaFunction, qualifier);
                pathTemplateMatcher.add(PathTemplate.create(nextObj.getString("template")), value);
            }
        }
    }

    /**
     * Invokes the Lambda function whose template matches the request path and relays its
     * response.
     *
     * <p>Responds with 404 for paths containing parent-directory segments. Responds with 502
     * when no template matches or the invocation fails.
     */
    @Override
    public void handle(RoutingContext ctx) {
        HttpServerRequest sreq = ctx.request();
        HttpServerResponse sres = ctx.response();
        final String path = normalizePath(sreq.path(), routeLength);
        if (path == null) {
            sres.setStatusCode(NOT_FOUND.code()).end();
            return;
        }
        final PathMatchResult<Entry<String, String>> matchRes = pathTemplateMatcher.match(path);
        if (matchRes == null) {
            logger.error("No matching path template");
            // The response must be ended here, otherwise the client is left waiting.
            sres.setStatusCode(BAD_GATEWAY.code()).end();
            return;
        }
        final String lambdaFunction = matchRes.getValue().getKey();
        final String qualifier = matchRes.getValue().getValue();
        sreq.bodyHandler(event -> {
            APIGatewayProxyRequestEvent reqObj = buildRequestEvent(ctx, path, matchRes, event.getBytes());
            String reqStr = JsonObject.mapFrom(reqObj).toString();
            InvokeRequest req = InvokeRequest.builder().invocationType(InvocationType.REQUEST_RESPONSE)
                    .functionName(lambdaFunction).qualifier(qualifier)
                    .payload(SdkBytes.fromByteArray(reqStr.getBytes(UTF_8)))
                    .build();
            logger.info("Calling lambda {}:{}", lambdaFunction, qualifier);
            logger.debug("Payload: {}", reqStr);
            CompletableFuture<InvokeResponse> respFuture = lambdaCl.invoke(req);
            // NOTE(review): this callback may run on an AWS SDK thread rather than the Vert.x
            // event loop; confirm writing the response from here is safe for the Vert.x version used.
            respFuture.whenComplete((iresp, err) -> {
                if (iresp != null) {
                    writeLambdaResponse(ctx, lambdaFunction, qualifier, iresp);
                } else {
                    logger.error("Error processing lambda request", err);
                    writeBadGateway(sres, Buffer.buffer(new LambdaErrorResponse(err).toString()));
                    sres.end();
                }
            });
        });
    }

    /** Builds the API Gateway proxy event describing the current request. */
    private APIGatewayProxyRequestEvent buildRequestEvent(RoutingContext ctx, String path,
            PathMatchResult<Entry<String, String>> matchRes, byte[] body) {
        HttpServerRequest sreq = ctx.request();
        APIGatewayProxyRequestEvent reqObj = new APIGatewayProxyRequestEvent();
        if (body != null && body.length > 0) {
            String text = decodeTextBody(sreq.getHeader("content-type"), body);
            if (text != null) {
                reqObj = reqObj.withBody(text).withIsBase64Encoded(false);
            } else {
                reqObj = reqObj.withBody(Base64.getEncoder().encodeToString(body)).withIsBase64Encoded(true);
            }
        }
        // A header name can occur in several entries; plain toMap would throw on the duplicate key.
        Map<String, List<String>> headerMultivalue = sreq.headers().entries().stream()
                .collect(toMap(Entry::getKey, x -> sreq.headers().getAll(x.getKey()), (first, second) -> first));
        // For a repeated header the last value wins in the single-value map.
        Map<String, String> headerValue = sreq.headers().entries().stream()
                .collect(toMap(Entry::getKey, Entry::getValue, (first, second) -> second));

        RequestIdentity reqId = new RequestIdentity().withSourceIp(getRemoteAddress(ctx))
                .withUserAgent(sreq.getHeader(USER_AGENT));
        if (ctx.user() != null) {
            reqId.withUser(ctx.user().principal().toString());
        }
        ProxyRequestContext reqCtx = new ProxyRequestContext()
                .withPath(sreq.path().substring(0, routeLength)).withHttpMethod(sreq.method().toString())
                .withIdentity(reqId);
        return reqObj.withMultiValueHeaders(headerMultivalue).withHeaders(headerValue)
                .withHttpMethod(sreq.method().toString()).withPath(sreq.path()).withResource(path)
                .withQueryStringParameters(splitQuery(sreq.query()))
                .withMultiValueQueryStringParameters(splitMultiValueQuery(sreq.query()))
                .withPathParameters(matchRes.getParameters()).withRequestContext(reqCtx);
    }

    /**
     * Decodes a request body as text when its content type looks textual.
     *
     * <p>Textual means {@code text/*}, {@code application/json}, or any type declaring a charset.
     * A missing or empty charset defaults to UTF-8.
     *
     * @return the decoded text, or {@code null} when the body should be sent base64-encoded
     *         (no or non-textual content type, unknown charset, or bytes invalid in that charset)
     */
    private static String decodeTextBody(@Nullable String contentType, byte[] body) {
        if (contentType == null) {
            return null;
        }
        String ct = contentType.toLowerCase(Locale.ROOT);
        boolean hasCharset = ct.indexOf("charset=") > 0;
        if (!(ct.startsWith("text/") || ct.startsWith("application/json") || hasCharset)) {
            return null;
        }
        String declared = hasCharset ? getCharsetFromContentType(ct) : null;
        String charset = declared == null || declared.isEmpty() ? "utf-8" : declared;
        try {
            return Charset.forName(charset).newDecoder()
                    .onMalformedInput(CodingErrorAction.REPORT)
                    .onUnmappableCharacter(CodingErrorAction.REPORT).decode(ByteBuffer.wrap(body))
                    .toString();
        } catch (CharacterCodingException | IllegalArgumentException e) {
            // IllegalArgumentException covers illegal and unsupported charset names.
            logger.error("Decoding body failed", e);
            return null;
        }
    }

    /**
     * Copies a Lambda proxy response onto the HTTP response and ends it.
     *
     * <p>Falls back to a 502 JSON error when the function reported an error or its payload
     * cannot be interpreted.
     */
    private static void writeLambdaResponse(RoutingContext ctx, String lambdaFunction, String qualifier,
            InvokeResponse iresp) {
        HttpServerResponse sres = ctx.response();
        try {
            String payload = iresp.payload().asString(UTF_8);
            if (iresp.functionError() != null) {
                // The function itself failed; the payload is Lambda's error document, not a proxy response.
                logger.error("Lambda {}:{} failed with {}: {}", lambdaFunction, qualifier,
                        iresp.functionError(), payload);
                writeBadGateway(sres, Buffer.buffer(payload));
                return;
            }
            JsonObject resp = new JsonObject(payload);
            // Parse everything before touching the response, so a malformed payload cannot leave
            // lambda-supplied headers behind on the 502 error response.
            int statusCode = resp.getInteger("statusCode");
            JsonObject headers = resp.getJsonObject("headers");
            byte[] bodyArr = decodeResponseBody(resp);
            sres.setStatusCode(statusCode);
            if (headers != null) {
                for (Entry<String, Object> next : headers.getMap().entrySet()) {
                    if (next.getValue() != null) {
                        sres.putHeader(next.getKey(), next.getValue().toString());
                    }
                }
            }
            sres.putHeader(CONTENT_LENGTH, String.valueOf(bodyArr.length));
            Buffer buffer = Buffer.buffer(bodyArr);
            tryToCacheContent(ctx, buffer);
            sres.write(buffer);
        } catch (Throwable t) {
            logger.error("Error processing lambda request", t);
            if (!sres.headWritten()) {
                writeBadGateway(sres, Buffer.buffer(new LambdaErrorResponse(t).toString()));
            }
        } finally {
            sres.end();
        }
    }

    /**
     * Returns the bytes of a proxy response's {@code body}.
     *
     * <p>The body is base64-decoded when {@code isBase64Encoded} is true. A missing or empty
     * body yields an empty array.
     */
    private static byte[] decodeResponseBody(JsonObject resp) {
        String respBody = resp.getString("body");
        if (respBody == null || respBody.isEmpty()) {
            return new byte[0];
        }
        if (TRUE.equals(resp.getBoolean("isBase64Encoded"))) {
            return Base64.getDecoder().decode(respBody);
        }
        return respBody.getBytes(UTF_8);
    }

    /** Sets a 502 status and writes {@code payload} as a JSON body; the caller must end the response. */
    private static void writeBadGateway(HttpServerResponse sres, Buffer payload) {
        sres.setStatusCode(BAD_GATEWAY.code());
        sres.putHeader(CONTENT_TYPE, "application/json");
        sres.putHeader(CONTENT_LENGTH, String.valueOf(payload.length()));
        sres.write(payload);
    }

    /**
     * Strips the route prefix from {@code path} and guarantees a leading slash.
     *
     * @return the route-relative path, or {@code null} when {@code path} is null or contains a
     *         parent-directory segment
     */
    private String normalizePath(@Nullable String path, int routeLength) {
        if (path == null || path.contains("../") || path.endsWith("/..")) {
            return null;
        }
        String rest = path.substring(routeLength);
        return rest.startsWith("/") ? rest : "/" + rest;
    }

    /**
     * Extracts the charset parameter value from a Content-Type header value.
     *
     * @return the trimmed charset value (possibly empty), or {@code null} when
     *         {@code contentType} is null or has no {@code charset=} parameter
     */
    public static String getCharsetFromContentType(String contentType) {
        if (contentType == null) {
            return null;
        }
        Matcher m = charsetPattern.matcher(contentType);
        if (m.find()) {
            return m.group(1).trim();
        }
        return null;
    }

    /**
     * Splits a raw query string into a single-value map.
     *
     * <p>Repeated keys have their values joined with commas. Values are not URL-decoded.
     *
     * @return the parameter map, or {@code null} for a null or empty query string
     */
    public static Map<String, String> splitQuery(String queryString) {
        if (queryString == null || queryString.isEmpty()) {
            return null;
        }
        return splitMultiValueQuery(queryString).entrySet().stream()
                .collect(toMap(Entry::getKey, x -> join(",", x.getValue())));
    }

    /**
     * Splits a raw query string into a map from key to all its values, in first-seen key order.
     *
     * <p>Empty segments (for example from {@code a=1&&b=2}) are skipped. Values are not
     * URL-decoded.
     *
     * @return the parameter map, or {@code null} for a null or empty query string
     */
    public static Map<String, List<String>> splitMultiValueQuery(String queryString) {
        if (queryString == null || queryString.isEmpty()) {
            return null;
        }
        return Arrays.stream(queryString.split("&")).filter(s -> !s.isEmpty())
                .map(LambdaHandler::splitQueryParameter)
                .collect(groupingBy(Entry::getKey, LinkedHashMap::new, mapping(Entry::getValue, toList())));
    }

    /**
     * Splits one {@code key=value} query segment.
     *
     * <p>A segment without {@code =} (or starting with it) is treated entirely as the key,
     * with an empty value.
     */
    public static SimpleImmutableEntry<String, String> splitQueryParameter(String it) {
        final int idx = it.indexOf("=");
        final String key = idx > 0 ? it.substring(0, idx) : it;
        final String value = idx > 0 && it.length() > idx + 1 ? it.substring(idx + 1) : "";
        return new SimpleImmutableEntry<>(key, value);
    }

    /** JSON error body sent with 502 responses, with one element per stack-trace line. */
    static class LambdaErrorResponse {
        public final String errorMessage;
        public final String errorType;
        public final List<String> stackTrace;

        public LambdaErrorResponse(Throwable t) {
            this.errorMessage = t.getMessage();
            this.errorType = t.getClass().getName();
            StringWriter sw = new StringWriter();
            PrintWriter pw = new PrintWriter(sw);
            t.printStackTrace(pw);
            // \R matches any line break, so no stray '\r' survives on CRLF platforms.
            stackTrace = Arrays.asList(sw.getBuffer().toString().split("\\R"));
        }

        @Override
        public String toString() {
            return JsonObject.mapFrom(this).toString();
        }
    }
}