Java tutorial
/** * Copyright 2017-2019 Nitor Creations Oy * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.nitor.api.backend.lambda; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent.ProxyRequestContext; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent.RequestIdentity; import io.undertow.util.PathTemplate; import io.undertow.util.PathTemplateMatcher; import io.undertow.util.PathTemplateMatcher.PathMatchResult; import io.vertx.codegen.annotations.Nullable; import io.vertx.core.Handler; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerRequest; import io.vertx.core.http.HttpServerResponse; import io.vertx.core.json.JsonObject; import io.vertx.ext.web.RoutingContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.model.InvocationType; import software.amazon.awssdk.services.lambda.model.InvokeRequest; import software.amazon.awssdk.services.lambda.model.InvokeResponse; import java.io.PrintWriter; import java.io.StringWriter; import java.nio.ByteBuffer; import java.nio.charset.CharacterCodingException; import java.nio.charset.Charset; import java.nio.charset.CodingErrorAction; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Arrays; import java.util.Base64; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.CompletableFuture; import java.util.regex.Matcher; import java.util.regex.Pattern; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; import static io.netty.handler.codec.http.HttpHeaderNames.USER_AGENT; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_GATEWAY; import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND; import static io.nitor.api.backend.cache.CacheHelpers.tryToCacheContent; import static io.nitor.api.backend.util.Helpers.getRemoteAddress; import static io.nitor.api.backend.util.Helpers.resolveCredentialsProvider; import static io.nitor.api.backend.util.Helpers.resolveRegion; import static java.lang.Boolean.TRUE; import static java.lang.String.join; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.mapping; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; public class LambdaHandler implements Handler<RoutingContext> { private static final Logger logger = LogManager.getLogger(LambdaHandler.class); private static final Pattern charsetPattern = Pattern.compile("(?i)\\bcharset=\\s*\"?([^\\s;\"]*)"); private final LambdaAsyncClient lambdaCl; private final int routeLength; private final PathTemplateMatcher<Entry<String, String>> pathTemplateMatcher; public LambdaHandler(JsonObject conf, int routeLength) { this.routeLength = routeLength; Region region = resolveRegion(conf); lambdaCl = LambdaAsyncClient.builder().region(region).credentialsProvider(resolveCredentialsProvider(conf)) .build(); pathTemplateMatcher = new PathTemplateMatcher<>(); for (Object next : conf.getJsonArray("paths")) { if (next instanceof JsonObject) { JsonObject nextObj = (JsonObject) next; String lambdaFunction = nextObj.getString("function"); String qualifier = conf.getString("qualifier", "$LATEST"); Entry<String, String> value = new SimpleImmutableEntry<>(lambdaFunction, qualifier); pathTemplateMatcher.add(PathTemplate.create(nextObj.getString("template")), value); } } } @Override public void handle(RoutingContext ctx) { HttpServerRequest sreq = ctx.request(); final String path = normalizePath(sreq.path(), routeLength); if (path == null) { ctx.response().setStatusCode(NOT_FOUND.code()).end(); return; } HttpServerResponse sres = ctx.response(); PathMatchResult<Entry<String, String>> matchRes = pathTemplateMatcher.match(path); final String lambdaFunction, qualifier; if (matchRes == null) { logger.error("No matching path template"); sres.setStatusCode(BAD_GATEWAY.code()); return; } else { lambdaFunction = matchRes.getValue().getKey(); qualifier = matchRes.getValue().getValue(); } sreq.bodyHandler(new Handler<Buffer>() { @Override public void handle(Buffer event) { byte[] body = event.getBytes(); APIGatewayProxyRequestEvent reqObj = new APIGatewayProxyRequestEvent(); /* * Handle body */ String bodyObjStr = null; boolean isBase64Encoded = true; if (body != null && body.length > 0) { String ct = sreq.getHeader("content-type").toLowerCase(); if (ct.startsWith("text/") || ct.startsWith("application/json") || (ct.indexOf("charset=") > 0)) { String charset = "utf-8"; if (ct.indexOf("charset=") > 0) { charset = getCharsetFromContentType(ct); } try { bodyObjStr = Charset.forName(charset).newDecoder() .onMalformedInput(CodingErrorAction.REPORT) .onUnmappableCharacter(CodingErrorAction.REPORT).decode(ByteBuffer.wrap(body)) .toString(); isBase64Encoded = false; } catch (CharacterCodingException e) { logger.error("Decoding body failed", e); } } if (bodyObjStr == null) { bodyObjStr = Base64.getEncoder().encodeToString(body); } reqObj = reqObj.withBody(bodyObjStr).withIsBase64Encoded(isBase64Encoded); } Map<String, List<String>> headerMultivalue = sreq.headers().entries().stream() .collect(toMap(Entry::getKey, x -> sreq.headers().getAll(x.getKey()))); Map<String, String> headerValue = sreq.headers().entries().stream() .collect(toMap(Entry::getKey, Entry::getValue)); /* * Handle request context */ RequestIdentity reqId = new RequestIdentity().withSourceIp(getRemoteAddress(ctx)) .withUserAgent(sreq.getHeader(USER_AGENT)); if (ctx.user() != null) { reqId.withUser(ctx.user().principal().toString()); } ProxyRequestContext reqCtx = new ProxyRequestContext() .withPath(sreq.path().substring(0, routeLength)).withHttpMethod(sreq.method().toString()) .withIdentity(reqId); reqObj = reqObj.withMultiValueHeaders(headerMultivalue).withHeaders(headerValue) .withHttpMethod(sreq.method().toString()).withPath(sreq.path()).withResource(path) .withQueryStringParameters(splitQuery(sreq.query())) .withMultiValueQueryStringParameters(splitMultiValueQuery(sreq.query())) .withPathParameters(matchRes.getParameters()).withRequestContext(reqCtx); String reqStr = JsonObject.mapFrom(reqObj).toString(); byte[] sendBody = reqStr.getBytes(UTF_8); InvokeRequest req = InvokeRequest.builder().invocationType(InvocationType.REQUEST_RESPONSE) .functionName(lambdaFunction).qualifier(qualifier).payload(SdkBytes.fromByteArray(sendBody)) .build(); logger.info("Calling lambda " + lambdaFunction + ":" + qualifier); logger.debug("Payload: " + reqStr); CompletableFuture<InvokeResponse> respFuture = lambdaCl.invoke(req); respFuture.whenComplete((iresp, err) -> { if (iresp != null) { try { String payload = iresp.payload().asString(UTF_8); JsonObject resp = new JsonObject(payload); int statusCode = resp.getInteger("statusCode"); sres.setStatusCode(statusCode); for (Entry<String, Object> next : resp.getJsonObject("headers").getMap().entrySet()) { sres.putHeader(next.getKey(), next.getValue().toString()); } String respBody = resp.getString("body"); byte[] bodyArr = new byte[0]; if (body != null && !respBody.isEmpty()) { if (TRUE.equals(resp.getBoolean("isBase64Encoded"))) { bodyArr = Base64.getDecoder().decode(body); } else { bodyArr = respBody.getBytes(UTF_8); } } sres.putHeader(CONTENT_LENGTH, String.valueOf(bodyArr.length)); Buffer buffer = Buffer.buffer(bodyArr); tryToCacheContent(ctx, buffer); sres.write(buffer); } catch (Throwable t) { logger.error("Error processing lambda request", t); if (!sres.headWritten()) { sres.setStatusCode(BAD_GATEWAY.code()); sres.putHeader(CONTENT_TYPE, "application/json"); Buffer response = Buffer.buffer(new LambdaErrorResponse(t).toString()); sres.putHeader(CONTENT_LENGTH, String.valueOf(response.length())); sres.write(response); } } finally { sres.end(); } } else { logger.error("Error processing lambda request", err); sres.setStatusCode(BAD_GATEWAY.code()); sres.putHeader(CONTENT_TYPE, "application/json"); Buffer response = Buffer.buffer(new LambdaErrorResponse(err).toString()); sres.putHeader(CONTENT_LENGTH, String.valueOf(response.length())); sres.end(response); } }); } }); } private String normalizePath(@Nullable String path, int routeLength) { if (path == null) { return null; } else if (path.contains("../")) { return null; } else { path = path.substring(routeLength); if (!path.startsWith("/")) { return "/" + path; } else { return path; } } } public static String getCharsetFromContentType(String contentType) { if (contentType == null) { return null; } Matcher m = charsetPattern.matcher(contentType); if (m.find()) { return m.group(1).trim(); } return null; } public static Map<String, String> splitQuery(String queryString) { if (queryString == null || queryString.isEmpty()) { return null; } return splitMultiValueQuery(queryString).entrySet().stream() .collect(toMap(Entry::getKey, x -> join(",", x.getValue()))); } public static Map<String, List<String>> splitMultiValueQuery(String queryString) { if (queryString == null || queryString.isEmpty()) { return null; } return Arrays.stream(queryString.split("&")).map(LambdaHandler::splitQueryParameter) .collect(groupingBy(Entry::getKey, LinkedHashMap::new, mapping(Entry::getValue, toList()))); } public static SimpleImmutableEntry<String, String> splitQueryParameter(String it) { final int idx = it.indexOf("="); final String key = idx > 0 ? it.substring(0, idx) : it; final String value = idx > 0 && it.length() > idx + 1 ? it.substring(idx + 1) : ""; return new SimpleImmutableEntry<>(key, value); } static class LambdaErrorResponse { public final String errorMessage; public final String errorType; public final List<String> stackTrace; public LambdaErrorResponse(Throwable t) { this.errorMessage = t.getMessage(); this.errorType = t.getClass().getName(); StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); t.printStackTrace(pw); String[] stack = sw.getBuffer().toString().split("\n"); stackTrace = Arrays.asList(stack); } public String toString() { return JsonObject.mapFrom(this).toString(); } } }