com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletRequest.java Source code

Java tutorial

Introduction

Here is the source code for com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletRequest.java

Source

/*
 * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.serverless.proxy.internal.servlet;

import com.amazonaws.serverless.proxy.internal.model.AwsProxyRequest;
import com.amazonaws.services.lambda.runtime.Context;

import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;

import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType;
import javax.servlet.ReadListener;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpUpgradeHandler;
import javax.servlet.http.Part;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.SecurityContext;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringReader;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.nio.charset.Charset;
import java.security.Principal;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/**
 * Implementation of the <code>HttpServletRequest</code> interface that supports <code>AwsProxyRequest</code> object.
 * This object is initialized with an <code>AwsProxyRequest</code> event and a <code>SecurityContext</code> generated
 * by an implementation of the <code>SecurityContextWriter</code>.
 *
 * <p>URL-encoded form bodies (POST requests with an <code>application/x-www-form-urlencoded</code> content type) and
 * multipart bodies are parsed once, when the object is constructed. Header lookups are case-insensitive. Bodies the
 * event flags as base64-encoded are decoded before they are exposed through {@link #getInputStream()},
 * {@link #getReader()} and the form parameter accessors.</p>
 */
public class AwsProxyHttpServletRequest implements HttpServletRequest {

    //-------------------------------------------------------------
    // Constants
    //-------------------------------------------------------------

    private static final String HEADER_KEY_VALUE_SEPARATOR = "=";
    private static final String HEADER_VALUE_SEPARATOR = ";";
    private static final String HEADER_LIST_SEPARATOR = ",";
    private static final String FORM_DATA_SEPARATOR = "&";
    private static final String DEFAULT_CHARACTER_ENCODING = "UTF-8";
    private static final Charset DEFAULT_CHARSET = Charset.forName(DEFAULT_CHARACTER_ENCODING);
    private static final String HEADER_DATE_FORMAT = "EEE, d MMM yyyy HH:mm:ss z";
    private static final String ENCODING_VALUE_KEY = "charset";
    private static final String QUALITY_VALUE_KEY = "q";
    private static final String ANY_LOCALE = "*";

    // We need this to pickup the protocol from the CloudFront header since Lambda doesn't receive this
    // information from anywhere else
    static final String CF_PROTOCOL_HEADER_NAME = "CloudFront-Forwarded-Proto";

    //-------------------------------------------------------------
    // Variables - Private
    //-------------------------------------------------------------

    private final AwsProxyRequest request;
    private final Context lambdaContext;
    private final SecurityContext securityContext;
    private final Map<String, Object> attributes;
    private final Map<String, List<String>> urlEncodedFormParameters;
    private final Map<String, Part> multipartFormParameters;
    // Encoding passed to setCharacterEncoding() while the request has no Content-Type header to carry it.
    private String characterEncodingOverride;

    //-------------------------------------------------------------
    // Constructors
    //-------------------------------------------------------------

    /**
     * Wraps a proxy event as a servlet request and eagerly parses any form or multipart body.
     *
     * @param awsProxyRequest    the API Gateway proxy event
     * @param lambdaContext      the Lambda context, handed to the servlet context
     * @param awsSecurityContext security information used for the auth-related accessors
     */
    public AwsProxyHttpServletRequest(AwsProxyRequest awsProxyRequest, Context lambdaContext,
            SecurityContext awsSecurityContext) {
        this.request = awsProxyRequest;
        this.lambdaContext = lambdaContext;
        this.securityContext = awsSecurityContext;

        this.attributes = new HashMap<>();
        this.urlEncodedFormParameters = getFormUrlEncodedParametersMap();
        this.multipartFormParameters = getMultipartFormParametersMap();
    }

    //-------------------------------------------------------------
    // Implementation - HttpServletRequest
    //-------------------------------------------------------------

    @Override
    public String getAuthType() {
        return securityContext.getAuthenticationScheme();
    }

    /**
     * Parses the Cookie header into name/value pairs. Pieces without '=' and cookies whose names the
     * {@link Cookie} constructor rejects are skipped.
     */
    @Override
    public Cookie[] getCookies() {
        String cookieHeader = getHeaderCaseInsensitive(HttpHeaders.COOKIE);
        if (cookieHeader == null) {
            return new Cookie[0];
        }
        List<Cookie> output = new ArrayList<>();

        for (String curCookie : cookieHeader.split(HEADER_VALUE_SEPARATOR)) {
            // limit 2: a cookie value (e.g. base64 padding) may itself contain '='
            String[] cookieKeyValue = curCookie.split(HEADER_KEY_VALUE_SEPARATOR, 2);
            if (cookieKeyValue.length < 2) {
                continue;
            }
            try {
                output.add(new Cookie(cookieKeyValue[0].trim(), cookieKeyValue[1].trim()));
            } catch (IllegalArgumentException ignored) {
                // Cookie rejects empty/reserved/illegal names; drop that cookie instead of failing the request.
            }
            // TODO: Parse the full cookie
        }
        return output.toArray(new Cookie[output.size()]);
    }

    /**
     * Returns the named header parsed as an HTTP date, in epoch milliseconds.
     *
     * @param s header name (case-insensitive)
     * @return the date in milliseconds, or -1 when the header is absent
     * @throws IllegalArgumentException when the header value is not a parseable date
     */
    @Override
    public long getDateHeader(String s) {
        String dateString = getHeaderCaseInsensitive(s);
        if (dateString == null) {
            return -1L;
        }
        // HTTP dates use English day/month names regardless of the JVM's default locale.
        SimpleDateFormat dateFormatter = new SimpleDateFormat(HEADER_DATE_FORMAT, Locale.US);
        try {
            return dateFormatter.parse(dateString.trim()).getTime();
        } catch (ParseException e) {
            throw new IllegalArgumentException("Header " + s + " is not a valid date: " + dateString, e);
        }
    }

    @Override
    public String getHeader(String s) {
        return getHeaderCaseInsensitive(s);
    }

    /**
     * Returns the values of the named header. The event holds a single value per header name, so the
     * enumeration contains at most one element.
     */
    @Override
    public Enumeration<String> getHeaders(String s) {
        String headerValue = getHeaderCaseInsensitive(s);
        if (headerValue == null) {
            return Collections.emptyEnumeration();
        }
        return Collections.enumeration(Collections.singletonList(headerValue));
    }

    @Override
    public Enumeration<String> getHeaderNames() {
        if (request.getHeaders() == null) {
            return Collections.emptyEnumeration();
        }
        return Collections.enumeration(request.getHeaders().keySet());
    }

    /**
     * Returns the named header as an int, or -1 when absent.
     *
     * @throws NumberFormatException when the header is not an integer
     */
    @Override
    public int getIntHeader(String s) {
        String headerValue = getHeaderCaseInsensitive(s);
        if (headerValue == null) {
            return -1;
        }

        return Integer.parseInt(headerValue.trim());
    }

    @Override
    public String getMethod() {
        return request.getHttpMethod();
    }

    /**
     * Returns the servlet path with the leading context path removed, always starting with "/".
     */
    @Override
    public String getPathInfo() {
        String pathInfo = getServletPath();
        if (pathInfo == null) {
            pathInfo = "";
        }
        String contextPath = getContextPath();
        // Strip only the leading context path. String.replace removed every occurrence, which with a
        // context path of "/" deleted all slashes inside the path ("/a/b" became "/ab").
        if (contextPath != null && !contextPath.isEmpty() && pathInfo.startsWith(contextPath)) {
            pathInfo = pathInfo.substring(contextPath.length());
        }
        if (!pathInfo.startsWith("/")) {
            pathInfo = "/" + pathInfo;
        }
        return pathInfo;
    }

    @Override
    public String getPathTranslated() {
        // Return null because it is an archive on a remote system
        return null;
    }

    @Override
    public String getContextPath() {
        // NOTE(review): the Servlet spec returns "" for the root context; "/" is kept for compatibility.
        return "/";
    }

    /** Returns the raw query string, or null when the event has none. */
    @Override
    public String getQueryString() {
        String queryString = request.getQueryString();
        return (queryString == null || queryString.isEmpty()) ? null : queryString;
    }

    /** Returns the authenticated user's name, or null when there is no user principal. */
    @Override
    public String getRemoteUser() {
        Principal principal = getUserPrincipal();
        return principal == null ? null : principal.getName();
    }

    @Override
    public boolean isUserInRole(String s) {
        // TODO: Not supported?
        return false;
    }

    @Override
    public Principal getUserPrincipal() {
        return securityContext.getUserPrincipal();
    }

    @Override
    public String getRequestedSessionId() {
        return null;
    }

    @Override
    public String getRequestURI() {
        return request.getPath();
    }

    /**
     * Reconstructs the URL the client used: scheme, Host header (falling back to {@link #getServerName()})
     * and path, without the query string.
     */
    @Override
    public StringBuffer getRequestURL() {
        String host = getHeaderCaseInsensitive(HttpHeaders.HOST);
        if (host == null) {
            host = getServerName();
        }
        String path = getRequestURI();

        StringBuffer url = new StringBuffer();
        url.append(getScheme()).append("://").append(host);
        if (path == null || !path.startsWith("/")) {
            url.append('/');
        }
        if (path != null) {
            url.append(path);
        }
        return url;
    }

    @Override
    public String getServletPath() {
        return request.getPath();
    }

    @Override
    public HttpSession getSession(boolean b) {
        return null;
    }

    @Override
    public HttpSession getSession() {
        return null;
    }

    @Override
    public String changeSessionId() {
        return null;
    }

    @Override
    public boolean isRequestedSessionIdValid() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromCookie() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromURL() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromUrl() {
        return false;
    }

    @Override
    public boolean authenticate(HttpServletResponse httpServletResponse) throws IOException, ServletException {
        return false;
    }

    @Override
    public void login(String s, String s1) throws ServletException {

    }

    @Override
    public void logout() throws ServletException {

    }

    @Override
    public Collection<Part> getParts() throws IOException, ServletException {
        return multipartFormParameters.values();
    }

    @Override
    public Part getPart(String s) throws IOException, ServletException {
        return multipartFormParameters.get(s);
    }

    @Override
    public <T extends HttpUpgradeHandler> T upgrade(Class<T> aClass) throws IOException, ServletException {
        return null;
    }

    //-------------------------------------------------------------
    // Implementation - ServletRequest
    //-------------------------------------------------------------

    @Override
    public Object getAttribute(String s) {
        return attributes.get(s);
    }

    @Override
    public Enumeration<String> getAttributeNames() {
        return Collections.enumeration(attributes.keySet());
    }

    /**
     * Returns the charset parameter of the Content-Type header. When there is none, it returns the value
     * passed to {@link #setCharacterEncoding(String)} while no Content-Type header existed (null by default).
     */
    @Override
    public String getCharacterEncoding() {
        // we only look at content-type because content-encoding should only be used for
        // "binary" requests such as gzip/deflate.
        String contentTypeHeader = getHeaderCaseInsensitive(HttpHeaders.CONTENT_TYPE);
        if (contentTypeHeader != null) {
            String[] contentTypeValues = contentTypeHeader.split(HEADER_VALUE_SEPARATOR);
            // index 0 is the media type itself; parameters follow it
            for (int i = 1; i < contentTypeValues.length; i++) {
                String[] parameter = contentTypeValues[i].split(HEADER_KEY_VALUE_SEPARATOR, 2);
                if (parameter.length < 2 || !ENCODING_VALUE_KEY.equalsIgnoreCase(parameter[0].trim())) {
                    continue;
                }
                String value = parameter[1].trim();
                if (value.length() >= 2 && value.startsWith("\"") && value.endsWith("\"")) {
                    value = value.substring(1, value.length() - 1);
                }
                if (!value.isEmpty()) {
                    return value;
                }
            }
        }
        return characterEncodingOverride;
    }

    /**
     * Sets the charset parameter of the Content-Type header, replacing any existing one and keeping the
     * media type and other parameters. Without a Content-Type header the encoding is remembered locally.
     * Passing null removes the charset parameter.
     */
    @Override
    public void setCharacterEncoding(String s) throws UnsupportedEncodingException {
        String contentTypeKey = findHeaderKey(HttpHeaders.CONTENT_TYPE);
        String currentContentType = contentTypeKey == null ? null : request.getHeaders().get(contentTypeKey);
        if (currentContentType == null) {
            // Do not fabricate a Content-Type with an empty media type just to carry the charset.
            characterEncodingOverride = s;
            return;
        }

        String[] contentTypeValues = currentContentType.split(HEADER_VALUE_SEPARATOR);
        StringBuilder contentType = new StringBuilder(contentTypeValues.length > 0 ? contentTypeValues[0].trim() : "");
        // Start at 1 so the media type is not appended a second time; drop any existing charset.
        for (int i = 1; i < contentTypeValues.length; i++) {
            String parameter = contentTypeValues[i].trim();
            if (!parameter.isEmpty() && !isCharsetParameter(parameter)) {
                contentType.append(HEADER_VALUE_SEPARATOR).append(' ').append(parameter);
            }
        }
        if (s != null) {
            contentType.append(HEADER_VALUE_SEPARATOR).append(' ')
                    .append(ENCODING_VALUE_KEY).append(HEADER_KEY_VALUE_SEPARATOR).append(s);
        }
        // Write back under the existing key so case-insensitive lookups see the updated value.
        request.getHeaders().put(contentTypeKey, contentType.toString());
    }

    /** Returns Content-Length as an int, or -1 when unknown, malformed or larger than Integer.MAX_VALUE. */
    @Override
    public int getContentLength() {
        long length = getContentLengthLong();
        return length > Integer.MAX_VALUE ? -1 : (int) length;
    }

    /** Returns Content-Length as a long, or -1 when the header is absent, negative or not a number. */
    @Override
    public long getContentLengthLong() {
        String headerValue = getHeaderCaseInsensitive(HttpHeaders.CONTENT_LENGTH);
        if (headerValue == null) {
            return -1;
        }
        try {
            long length = Long.parseLong(headerValue.trim());
            return length < 0 ? -1 : length;
        } catch (NumberFormatException ignored) {
            return -1;
        }
    }

    @Override
    public String getContentType() {
        return getHeaderCaseInsensitive(HttpHeaders.CONTENT_TYPE);
    }

    /**
     * Returns a fresh stream over the request body bytes (empty when the event has no body). Every call
     * starts from the beginning of the body.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream requestBodyStream = new ByteArrayInputStream(getBodyBytes());
        return new ServletInputStream() {

            private ReadListener listener;
            private boolean allDataReadNotified;

            @Override
            public boolean isFinished() {
                return requestBodyStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                listener = readListener;
                try {
                    if (isFinished()) {
                        notifyAllDataRead();
                    } else {
                        listener.onDataAvailable();
                    }
                } catch (IOException e) {
                    listener.onError(e);
                }
            }

            @Override
            public int read() throws IOException {
                int readByte = requestBodyStream.read();
                if (isFinished()) {
                    notifyAllDataRead();
                }
                return readByte;
            }

            @Override
            public int read(byte[] buffer, int offset, int length) throws IOException {
                int count = requestBodyStream.read(buffer, offset, length);
                if (isFinished()) {
                    notifyAllDataRead();
                }
                return count;
            }

            // Invoke onAllDataRead exactly once, and only if a listener has been registered.
            private void notifyAllDataRead() throws IOException {
                if (listener != null && !allDataReadNotified) {
                    allDataReadNotified = true;
                    listener.onAllDataRead();
                }
            }
        };
    }

    /**
     * Returns the first value of a parameter, looking at the query string first and the URL-encoded form
     * body second.
     */
    @Override
    public String getParameter(String s) {
        String queryStringParameter = getQueryStringParameterCaseInsensitive(s);
        if (queryStringParameter != null) {
            return queryStringParameter;
        }

        String[] bodyParams = getFormBodyParameterCaseInsensitive(s);
        if (bodyParams == null || bodyParams.length == 0) {
            return null;
        } else {
            return bodyParams[0];
        }
    }

    /** Returns the distinct parameter names from the query string and the URL-encoded form body. */
    @Override
    public Enumeration<String> getParameterNames() {
        Set<String> paramNames = new LinkedHashSet<>();
        Map<String, String> queryParameters = request.getQueryStringParameters();
        if (queryParameters != null) {
            paramNames.addAll(queryParameters.keySet());
        }
        paramNames.addAll(urlEncodedFormParameters.keySet());
        return Collections.enumeration(paramNames);
    }

    /** Returns query-string values followed by form-body values for a parameter, or null when it is absent. */
    @Override
    public String[] getParameterValues(String s) {
        List<String> values = new ArrayList<>();
        String queryStringValue = getQueryStringParameterCaseInsensitive(s);
        if (queryStringValue != null) {
            values.add(queryStringValue);
        }

        String[] formBodyValues = getFormBodyParameterCaseInsensitive(s);
        if (formBodyValues != null) {
            values.addAll(Arrays.asList(formBodyValues));
        }

        return values.isEmpty() ? null : values.toArray(new String[values.size()]);
    }

    /**
     * Returns all parameters: query-string values first, then URL-encoded form values. The parsed form data
     * is copied, never modified, so repeated calls return the same content.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, List<String>> merged = new HashMap<>();

        Map<String, String> queryParameters = request.getQueryStringParameters();
        if (queryParameters != null) {
            for (Map.Entry<String, String> entry : queryParameters.entrySet()) {
                merged.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(entry.getValue());
            }
        }
        for (Map.Entry<String, List<String>> entry : urlEncodedFormParameters.entrySet()) {
            merged.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).addAll(entry.getValue());
        }

        Map<String, String[]> output = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : merged.entrySet()) {
            output.put(entry.getKey(), entry.getValue().toArray(new String[entry.getValue().size()]));
        }
        return output;
    }

    @Override
    public String getProtocol() {
        // NOTE(review): the Servlet API expects a value such as "HTTP/1.1"; none is available from the event here.
        return null;
    }

    @Override
    public String getScheme() {
        String headerValue = getHeaderCaseInsensitive(CF_PROTOCOL_HEADER_NAME);
        if (headerValue == null) {
            return "https";
        }
        return headerValue;
    }

    @Override
    public String getServerName() {
        return "lambda.amazonaws.com";
    }

    @Override
    public int getServerPort() {
        return 0;
    }

    /** Returns a reader over the body text (base64-decoded when flagged so; empty when there is no body). */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new StringReader(getBodyText()));
    }

    /** Returns the caller's source IP from the request context, or null when the event does not carry it. */
    @Override
    public String getRemoteAddr() {
        if (request.getRequestContext() == null || request.getRequestContext().getIdentity() == null) {
            return null;
        }
        return request.getRequestContext().getIdentity().getSourceIp();
    }

    /**
     * Returns the client address. Host names are not resolved, so this is the same as {@link #getRemoteAddr()}.
     * The Host header names the server, not the client.
     */
    @Override
    public String getRemoteHost() {
        return getRemoteAddr();
    }

    /** Stores an attribute; a null value removes the attribute, as the Servlet API specifies. */
    @Override
    public void setAttribute(String s, Object o) {
        if (o == null) {
            attributes.remove(s);
        } else {
            attributes.put(s, o);
        }
    }

    @Override
    public void removeAttribute(String s) {
        attributes.remove(s);
    }

    /** Returns the most preferred Accept-Language locale, or the JVM default when none is usable. */
    @Override
    public Locale getLocale() {
        List<Locale> locales = parseAcceptLanguage();
        return locales.isEmpty() ? Locale.getDefault() : locales.get(0);
    }

    /** Returns Accept-Language locales by decreasing preference, or only the JVM default when none is usable. */
    @Override
    public Enumeration<Locale> getLocales() {
        List<Locale> locales = parseAcceptLanguage();
        if (locales.isEmpty()) {
            locales.add(Locale.getDefault());
        }
        return Collections.enumeration(locales);
    }

    @Override
    public boolean isSecure() {
        return securityContext.isSecure();
    }

    @Override
    public RequestDispatcher getRequestDispatcher(String s) {
        return null;
    }

    @Override
    public String getRealPath(String s) {
        // we are in an archive on a remote server
        return null;
    }

    @Override
    public int getRemotePort() {
        return 0;
    }

    @Override
    public String getLocalName() {
        return "lambda.amazonaws.com";
    }

    @Override
    public String getLocalAddr() {
        return null;
    }

    @Override
    public int getLocalPort() {
        return 0;
    }

    @Override
    public ServletContext getServletContext() {
        return AwsProxyServletContext.getInstance(request, lambdaContext);
    }

    @Override
    public AsyncContext startAsync() throws IllegalStateException {
        return null;
    }

    @Override
    public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
            throws IllegalStateException {
        return null;
    }

    @Override
    public boolean isAsyncStarted() {
        return false;
    }

    @Override
    public boolean isAsyncSupported() {
        return false;
    }

    @Override
    public AsyncContext getAsyncContext() {
        return null;
    }

    /** Every request built from a proxy event is an initial client request. */
    @Override
    public DispatcherType getDispatcherType() {
        return DispatcherType.REQUEST;
    }

    //-------------------------------------------------------------
    // Methods - Private
    //-------------------------------------------------------------

    // Returns the actual key under which the header is stored (case-insensitive match), or null.
    private String findHeaderKey(String key) {
        Map<String, String> headers = request.getHeaders();
        if (headers == null || key == null) {
            return null;
        }
        for (String requestHeaderKey : headers.keySet()) {
            // equalsIgnoreCase is locale-independent, unlike toLowerCase() with the default locale
            if (key.equalsIgnoreCase(requestHeaderKey)) {
                return requestHeaderKey;
            }
        }
        return null;
    }

    private String getHeaderCaseInsensitive(String key) {
        String headerKey = findHeaderKey(key);
        return headerKey == null ? null : request.getHeaders().get(headerKey);
    }

    private String getQueryStringParameterCaseInsensitive(String key) {
        Map<String, String> queryParameters = request.getQueryStringParameters();
        if (queryParameters == null || key == null) {
            return null;
        }

        for (Map.Entry<String, String> entry : queryParameters.entrySet()) {
            if (key.equalsIgnoreCase(entry.getKey())) {
                return entry.getValue();
            }
        }
        return null;
    }

    // urlEncodedFormParameters is a case-insensitive TreeMap when a form body was parsed.
    private String[] getFormBodyParameterCaseInsensitive(String key) {
        List<String> values = urlEncodedFormParameters.get(key);
        return values == null ? null : values.toArray(new String[values.size()]);
    }

    // True when a "name=value" content-type parameter is the charset parameter.
    private static boolean isCharsetParameter(String parameter) {
        String name = parameter.split(HEADER_KEY_VALUE_SEPARATOR, 2)[0].trim();
        return ENCODING_VALUE_KEY.equalsIgnoreCase(name);
    }

    // Charset declared by the request, or UTF-8 when it is absent or not supported by this JVM.
    private Charset getBodyCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // IllegalCharsetNameException and UnsupportedCharsetException both extend IllegalArgumentException.
            }
        }
        return DEFAULT_CHARSET;
    }

    // Raw body bytes: base64-decoded when the event says so, otherwise the body text in the request charset.
    private byte[] getBodyBytes() {
        String body = request.getBody();
        if (body == null) {
            return new byte[0];
        }
        if (request.isBase64Encoded()) {
            return Base64.getDecoder().decode(body);
        }
        return body.getBytes(getBodyCharset());
    }

    // Body as text: base64-decoded when the event says so; "" when there is no body.
    private String getBodyText() {
        String body = request.getBody();
        if (body == null) {
            return "";
        }
        if (request.isBase64Encoded()) {
            return new String(Base64.getDecoder().decode(body), getBodyCharset());
        }
        return body;
    }

    // Parses Accept-Language into locales ordered by decreasing "q" weight (stable for equal weights).
    private List<Locale> parseAcceptLanguage() {
        String localeHeader = getHeaderCaseInsensitive(HttpHeaders.ACCEPT_LANGUAGE);
        List<WeightedLocale> weighted = new ArrayList<>();
        if (localeHeader != null) {
            for (String range : localeHeader.split(HEADER_LIST_SEPARATOR)) {
                String[] parts = range.split(HEADER_VALUE_SEPARATOR);
                String tag = parts.length == 0 ? "" : parts[0].trim();
                if (tag.isEmpty() || ANY_LOCALE.equals(tag)) {
                    continue;
                }
                double quality = 1.0;
                for (int i = 1; i < parts.length; i++) {
                    String[] parameter = parts[i].split(HEADER_KEY_VALUE_SEPARATOR, 2);
                    if (parameter.length == 2 && QUALITY_VALUE_KEY.equalsIgnoreCase(parameter[0].trim())) {
                        try {
                            quality = Double.parseDouble(parameter[1].trim());
                        } catch (NumberFormatException ignored) {
                            // malformed weight: keep the default of 1.0
                        }
                    }
                }
                if (quality <= 0) {
                    continue; // q=0 means "not acceptable"
                }
                Locale locale = Locale.forLanguageTag(tag.replace('_', '-'));
                if (!locale.getLanguage().isEmpty()) {
                    weighted.add(new WeightedLocale(locale, quality));
                }
            }
        }
        weighted.sort((a, b) -> Double.compare(b.quality, a.quality));

        List<Locale> locales = new ArrayList<>(weighted.size());
        for (WeightedLocale weightedLocale : weighted) {
            locales.add(weightedLocale.locale);
        }
        return locales;
    }

    /**
     * Parses a multipart body into parts keyed by form field name (case-insensitive). Parsing is best effort:
     * on failure the stack trace is printed and the parts parsed so far are returned.
     */
    private Map<String, Part> getMultipartFormParametersMap() {
        if (!ServletFileUpload.isMultipartContent(this)) { // isMultipartContent also checks the content type
            return new HashMap<>();
        }

        Map<String, Part> output = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

        // parseRequest() needs a FileItemFactory; the no-arg ServletFileUpload constructor leaves it unset.
        // With the maximal size threshold every item stays in memory instead of going to a temporary file.
        ServletFileUpload upload = new ServletFileUpload(new DiskFileItemFactory(Integer.MAX_VALUE, null));
        try {
            List<FileItem> items = upload.parseRequest(this);
            for (FileItem item : items) {
                AwsProxyRequestPart newPart = new AwsProxyRequestPart(item.get());
                // Part.getName() is the form field name and getSubmittedFileName() the client's file name;
                // FileItem exposes them as getFieldName() and getName() respectively.
                newPart.setName(item.getFieldName());
                newPart.setSubmittedFileName(item.getName());
                newPart.setContentType(item.getContentType());
                newPart.setSize(item.getSize());

                if (item.getHeaders() != null) {
                    Iterator<String> headerNamesIterator = item.getHeaders().getHeaderNames();
                    while (headerNamesIterator.hasNext()) {
                        String headerName = headerNamesIterator.next();
                        Iterator<String> headerValuesIterator = item.getHeaders().getHeaders(headerName);
                        while (headerValuesIterator.hasNext()) {
                            newPart.addHeader(headerName, headerValuesIterator.next());
                        }
                    }
                }

                output.put(item.getFieldName(), newPart);
            }
        } catch (FileUploadException e) {
            // Best effort: a malformed multipart body must not prevent the request object from being built.
            e.printStackTrace();
        }
        return output;
    }

    /**
     * Parses an application/x-www-form-urlencoded POST body into a case-insensitive multimap. The body is
     * split on '&' and '=' first and each key and value is URL-decoded afterwards. Decoding the whole body
     * first would let encoded "%26"/"%3D" corrupt the split.
     */
    private Map<String, List<String>> getFormUrlEncodedParametersMap() {
        String contentType = getContentType();
        if (contentType == null) {
            return new HashMap<>();
        }
        boolean isFormContent = contentType.regionMatches(true, 0, MediaType.APPLICATION_FORM_URLENCODED, 0,
                MediaType.APPLICATION_FORM_URLENCODED.length());
        if (!isFormContent || !"post".equalsIgnoreCase(getMethod())) {
            return new HashMap<>();
        }

        String rawBodyContent;
        try {
            rawBodyContent = getBodyText();
        } catch (IllegalArgumentException e) {
            // body flagged as base64 but not decodable: expose no form parameters rather than failing construction
            return new HashMap<>();
        }
        String encoding = getBodyCharset().name();

        Map<String, List<String>> output = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (String parameter : rawBodyContent.split(FORM_DATA_SEPARATOR)) {
            if (parameter.isEmpty()) {
                continue;
            }
            // limit 2: keep "=" inside values and keep empty values such as "name="
            String[] parameterKeyValue = parameter.split(HEADER_KEY_VALUE_SEPARATOR, 2);
            String key = decodeFormComponent(parameterKeyValue[0], encoding);
            if (key.isEmpty()) {
                continue;
            }
            String value = parameterKeyValue.length > 1 ? decodeFormComponent(parameterKeyValue[1], encoding) : "";
            output.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
        }

        return output;
    }

    // URL-decodes one form key or value; malformed escapes keep the raw text instead of failing the request.
    private static String decodeFormComponent(String component, String encoding) {
        try {
            return URLDecoder.decode(component, encoding);
        } catch (UnsupportedEncodingException | IllegalArgumentException e) {
            return component;
        }
    }

    //-------------------------------------------------------------
    // Nested types
    //-------------------------------------------------------------

    /** A locale from Accept-Language paired with its "q" preference weight. */
    private static final class WeightedLocale {
        private final Locale locale;
        private final double quality;

        WeightedLocale(Locale locale, double quality) {
            this.locale = locale;
            this.quality = quality;
        }
    }
}