com.netflix.zuul.http.HttpServletRequestWrapper.java Source code

Java tutorial

Introduction

Here is the source code for com.netflix.zuul.http.HttpServletRequestWrapper.java

Source

/*
 * Copyright 2013 Netflix, Inc.
 *
 *      Licensed under the Apache License, Version 2.0 (the "License");
 *      you may not use this file except in compliance with the License.
 *      You may obtain a copy of the License at
 *
 *          http://www.apache.org/licenses/LICENSE-2.0
 *
 *      Unless required by applicable law or agreed to in writing, software
 *      distributed under the License is distributed on an "AS IS" BASIS,
 *      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *      See the License for the specific language governing permissions and
 *      limitations under the License.
 */
package com.netflix.zuul.http;

import com.netflix.zuul.constants.ZuulHeaders;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.util.HTTPRequestUtils;
import org.apache.commons.io.IOUtils;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.RequestDispatcher;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import java.io.*;
import java.net.URLDecoder;
import java.security.Principal;
import java.util.*;
import java.util.zip.*;

import static org.junit.Assert.*;
import static org.mockito.Mockito.when;

/**
 * This class implements the Wrapper or Decorator pattern.<br/>
 * Methods default to calling through to the wrapped request object,
 * except the ones that read the request's content (parameters, stream or reader).
 * <p/>
 * This class provides a buffered content reading that allows the methods
 * {@link #getReader()}, {@link #getInputStream()} and any of the getParameterXXX to be     called
 * safely and repeatedly with the same results.
 * <p/>
 * This class is intended to wrap relatively small HttpServletRequest instances.
 *
 * @author pgurov
 */
public class HttpServletRequestWrapper implements HttpServletRequest {

    private final static HashMap<String, String[]> EMPTY_MAP = new HashMap<String, String[]>();
    protected static final Logger LOG = LoggerFactory.getLogger(HttpServletRequestWrapper.class);

    private HttpServletRequest req;
    private byte[] contentData;
    private HashMap<String, String[]> parameters;

    public HttpServletRequestWrapper() {
        //a trick for Groovy
        throw new IllegalArgumentException(
                "Please use HttpServletRequestWrapper(HttpServletRequest request) constructor!");
    }

    private HttpServletRequestWrapper(HttpServletRequest request, byte[] contentData,
            HashMap<String, String[]> parameters) {
        req = request;
        this.contentData = contentData;
        this.parameters = parameters;
    }

    public HttpServletRequestWrapper(HttpServletRequest request) {
        if (request == null)
            throw new IllegalArgumentException("The HttpServletRequest is null!");
        req = request;
    }

    /**
     * Returns the wrapped HttpServletRequest.
     * Using the getParameterXXX(), getInputStream() or getReader() methods may interfere
     * with this class operation.
     *
     * @return The wrapped HttpServletRequest.
     */
    public HttpServletRequest getRequest() {
        try {
            parseRequest();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot parse the request!", e);
        }
        return new HttpServletRequestWrapper(req, contentData, parameters);
    }

    /**
     * This method is safe to use multiple times.
     * Changing the returned array will not interfere with this class operation.
     *
     * @return The cloned content data.
     */
    public byte[] getContentData() {
        return contentData.clone();
    }

    /**
     * This method is safe to use multiple times.
     * Changing the returned map or the array of any of the map's values will not
     * interfere with this class operation.
     *
     * @return The cloned parameters map.
     */

    public HashMap<String, String[]> getParameters() {
        if (parameters == null)
            return EMPTY_MAP;
        HashMap<String, String[]> map = new HashMap<String, String[]>(parameters.size() * 2);
        for (String key : parameters.keySet()) {
            map.put(key, parameters.get(key).clone());
        }
        return map;
    }

    private void parseRequest() throws IOException {
        if (parameters != null)
            return; //already parsed

        HashMap<String, List<String>> mapA = new HashMap<String, List<String>>();
        List<String> list;

        Map<String, List<String>> query = HTTPRequestUtils.getInstance().getQueryParams();
        if (query != null) {
            for (String key : query.keySet()) {
                list = query.get(key);
                mapA.put(key, list);
            }
        }

        if (req.getContentLength() > 0) {
            byte[] data = new byte[req.getContentLength()];
            int len = 0, totalLen = 0;
            InputStream is = req.getInputStream();
            while (totalLen < data.length) {
                totalLen += (len = is.read(data, totalLen, data.length - totalLen));
                if (len < 1)
                    throw new IOException(
                            "Cannot read more than " + totalLen + (totalLen == 1 ? " byte!" : " bytes!"));
            }
            contentData = data;

            String enc = req.getCharacterEncoding();

            if (enc == null)
                enc = "UTF-8";
            String s = new String(data, enc), name, value;
            StringTokenizer st = new StringTokenizer(s, "&");
            int i;

            boolean decode = req.getContentType() != null
                    && req.getContentType().equalsIgnoreCase("application/x-www-form-urlencoded");
            while (st.hasMoreTokens()) {
                s = st.nextToken();
                i = s.indexOf("=");
                if (i > 0 && s.length() > i + 1) {
                    name = s.substring(0, i);
                    value = s.substring(i + 1);
                    if (decode) {
                        try {
                            name = URLDecoder.decode(name, "UTF-8");
                        } catch (Exception e) {
                        }
                        try {
                            value = URLDecoder.decode(value, "UTF-8");
                        } catch (Exception e) {
                        }
                    }
                    list = mapA.get(name);
                    if (list == null) {
                        list = new LinkedList<String>();
                        mapA.put(name, list);
                    }
                    list.add(value);
                }
            }

        } else if (req.getContentLength() == -1) {
            final String transferEncoding = req.getHeader(ZuulHeaders.TRANSFER_ENCODING);
            if (transferEncoding != null && transferEncoding.equals(ZuulHeaders.CHUNKED))
                RequestContext.getCurrentContext().setChunkedRequestBody();
        }

        HashMap<String, String[]> map = new HashMap<String, String[]>(mapA.size() * 2);
        for (String key : mapA.keySet()) {
            list = mapA.get(key);
            map.put(key, list.toArray(new String[list.size()]));
        }

        parameters = map;

    }

    /**
     * This method is safe to call multiple times.
     * Calling it will not interfere with getParameterXXX() or getReader().
     * Every time a new ServletInputStream is returned that reads data from the begining.
     *
     * @return A new ServletInputStream.
     */
    public ServletInputStream getInputStream() throws IOException {
        parseRequest();

        if (RequestContext.getCurrentContext().isChunkedRequestBody()) {
            return req.getInputStream();
        } else {
            return new ServletInputStreamWrapper(contentData);
        }
    }

    /**
     * This method is safe to call multiple times.
     * Calling it will not interfere with getParameterXXX() or getInputStream().
     * Every time a new BufferedReader is returned that reads data from the begining.
     *
     * @return A new BufferedReader with the wrapped request's character encoding (or UTF-8 if null).
     */
    public BufferedReader getReader() throws IOException {
        parseRequest();

        String enc = req.getCharacterEncoding();
        if (enc == null)
            enc = "UTF-8";
        return new BufferedReader(new InputStreamReader(new ByteArrayInputStream(contentData), enc));
    }

    /**
     * This method is safe to execute multiple times.
     *
     * @see javax.servlet.ServletRequest#getParameter(java.lang.String)
     */
    public String getParameter(String name) {
        try {
            parseRequest();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot parse the request!", e);
        }
        if (parameters == null)
            return null;
        String[] values = parameters.get(name);
        if (values == null || values.length == 0)
            return null;
        return values[0];
    }

    /**
     * This method is safe.
     *
     * @see {@link #getParameters()}
     * @see javax.servlet.ServletRequest#getParameterMap()
     */
    @SuppressWarnings("unchecked")
    public Map getParameterMap() {
        try {
            parseRequest();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot parse the request!", e);
        }
        return getParameters();
    }

    /**
     * This method is safe to execute multiple times.
     *
     * @see javax.servlet.ServletRequest#getParameterNames()
     */
    @SuppressWarnings("unchecked")
    public Enumeration getParameterNames() {
        try {
            parseRequest();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot parse the request!", e);
        }
        return new Enumeration<String>() {
            private String[] arr = getParameters().keySet().toArray(new String[0]);
            private int idx = 0;

            public boolean hasMoreElements() {
                return idx < arr.length;
            }

            public String nextElement() {
                return arr[idx++];
            }

        };
    }

    /**
     * This method is safe to execute multiple times.
     * Changing the returned array will not interfere with this class operation.
     *
     * @see javax.servlet.ServletRequest#getParameterValues(java.lang.String)
     */
    public String[] getParameterValues(String name) {
        try {
            parseRequest();
        } catch (IOException e) {
            throw new IllegalStateException("Cannot parse the request!", e);
        }
        if (parameters == null)
            return null;
        String[] arr = parameters.get(name);
        if (arr == null)
            return null;
        return arr.clone();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getAuthType()
    */
    public String getAuthType() {
        return req.getAuthType();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getContextPath()
    */
    public String getContextPath() {
        return req.getContextPath();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getCookies()
    */
    public Cookie[] getCookies() {
        return req.getCookies();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getDateHeader(java.lang.String)
    */
    public long getDateHeader(String name) {
        return req.getDateHeader(name);
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getHeader(java.lang.String)
    */
    public String getHeader(String name) {
        return req.getHeader(name);
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getHeaderNames()
    */
    @SuppressWarnings("unchecked")
    public Enumeration getHeaderNames() {
        return req.getHeaderNames();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getHeaders(java.lang.String)
    */
    @SuppressWarnings("unchecked")
    public Enumeration getHeaders(String name) {
        return req.getHeaders(name);
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getIntHeader(java.lang.String)
    */
    public int getIntHeader(String name) {
        return req.getIntHeader(name);
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getMethod()
    */
    public String getMethod() {
        return req.getMethod();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getPathInfo()
    */
    public String getPathInfo() {
        return req.getPathInfo();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getPathTranslated()
    */
    public String getPathTranslated() {
        return req.getPathTranslated();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getQueryString()
    */
    public String getQueryString() {
        return req.getQueryString();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getRemoteUser()
    */
    public String getRemoteUser() {
        return req.getRemoteUser();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getRequestURI()
    */
    public String getRequestURI() {
        return req.getRequestURI();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getRequestURL()
    */
    public StringBuffer getRequestURL() {
        return req.getRequestURL();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getRequestedSessionId()
    */
    public String getRequestedSessionId() {
        return req.getRequestedSessionId();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getServletPath()
    */
    public String getServletPath() {
        return req.getServletPath();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getSession()
    */
    public HttpSession getSession() {
        return req.getSession();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getSession(boolean)
    */
    public HttpSession getSession(boolean create) {
        return req.getSession(create);
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#getUserPrincipal()
    */
    public Principal getUserPrincipal() {
        return req.getUserPrincipal();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#isRequestedSessionIdFromCookie()
    */
    public boolean isRequestedSessionIdFromCookie() {
        return req.isRequestedSessionIdFromCookie();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#isRequestedSessionIdFromURL()
    */
    public boolean isRequestedSessionIdFromURL() {
        return req.isRequestedSessionIdFromURL();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#isRequestedSessionIdFromUrl()
    */
    @SuppressWarnings("deprecation")
    public boolean isRequestedSessionIdFromUrl() {
        return req.isRequestedSessionIdFromUrl();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#isRequestedSessionIdValid()
    */
    public boolean isRequestedSessionIdValid() {
        return req.isRequestedSessionIdValid();
    }

    /* (non-Javadoc)
    * @see javax.servlet.http.HttpServletRequest#isUserInRole(java.lang.String)
    */
    public boolean isUserInRole(String role) {
        return req.isUserInRole(role);
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getAttribute(java.lang.String)
    */
    public Object getAttribute(String name) {
        return req.getAttribute(name);
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getAttributeNames()
    */
    @SuppressWarnings("unchecked")
    public Enumeration getAttributeNames() {
        return req.getAttributeNames();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getCharacterEncoding()
    */
    public String getCharacterEncoding() {
        return req.getCharacterEncoding();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getContentLength()
    */
    public int getContentLength() {
        return req.getContentLength();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getContentType()
    */
    public String getContentType() {
        return req.getContentType();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getLocalAddr()
    */
    public String getLocalAddr() {
        return req.getLocalAddr();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getLocalName()
    */
    public String getLocalName() {
        return req.getLocalName();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getLocalPort()
    */
    public int getLocalPort() {
        return req.getLocalPort();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getLocale()
    */
    public Locale getLocale() {
        return req.getLocale();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getLocales()
    */
    @SuppressWarnings("unchecked")
    public Enumeration getLocales() {
        return req.getLocales();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getProtocol()
    */
    public String getProtocol() {
        return req.getProtocol();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getRealPath(java.lang.String)
    */
    @SuppressWarnings("deprecation")
    public String getRealPath(String path) {
        return req.getRealPath(path);
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getRemoteAddr()
    */
    public String getRemoteAddr() {
        return req.getRemoteAddr();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getRemoteHost()
    */
    public String getRemoteHost() {
        return req.getRemoteHost();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getRemotePort()
    */
    public int getRemotePort() {
        return req.getRemotePort();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getRequestDispatcher(java.lang.String)
    */
    public RequestDispatcher getRequestDispatcher(String path) {
        return req.getRequestDispatcher(path);
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getScheme()
    */
    public String getScheme() {
        return req.getScheme();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getServerName()
    */
    public String getServerName() {
        return req.getServerName();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#getServerPort()
    */
    public int getServerPort() {
        return req.getServerPort();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#isSecure()
    */
    public boolean isSecure() {
        return req.isSecure();
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#removeAttribute(java.lang.String)
    */
    public void removeAttribute(String name) {
        req.removeAttribute(name);
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#setAttribute(java.lang.String, java.lang.Object)
    */
    public void setAttribute(String name, Object value) {
        req.setAttribute(name, value);
    }

    /* (non-Javadoc)
    * @see javax.servlet.ServletRequest#setCharacterEncoding(java.lang.String)
    */
    public void setCharacterEncoding(String env) throws UnsupportedEncodingException {
        req.setCharacterEncoding(env);
    }

    public static final class UnitTest {

        @Mock
        HttpServletRequest request;

        @Before
        public void before() {
            RequestContext.getCurrentContext().unset();
            MockitoAnnotations.initMocks(this);

            RequestContext.getCurrentContext().setRequest(request);
        }

        private void body(byte[] body) throws IOException {
            when(request.getInputStream()).thenReturn(new ServletInputStreamWrapper(body));
            when(request.getContentLength()).thenReturn(body.length);
        }

        @Test
        public void handlesDuplicateParams() {
            when(request.getQueryString()).thenReturn("path=one&key1=val1&path=two");
            final HttpServletRequestWrapper w = new HttpServletRequestWrapper(request);

            // getParameters doesn't call parseRequest internally, not sure why
            // so I'm forcing it here
            w.getParameterMap();

            final Map<String, String[]> params = w.getParameters();
            assertFalse("params should not be empty", params.isEmpty());
            final String[] paths = params.get("path");
            assertTrue("paths param should not be empty", paths.length > 0);
            assertEquals("one", paths[0]);
            assertEquals("two", paths[1]);
        }

        @Test
        public void handlesPlainRequestBody() throws IOException {
            final String body = "hello";
            body(body.getBytes());

            final HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request);
            assertEquals(body, IOUtils.toString(wrapper.getInputStream()));
        }

        @Test
        public void handlesGzipRequestBody() throws IOException {
            // creates string, gzips into byte array which will be mocked as InputStream of request
            final String body = "hello";
            final byte[] bodyBytes = body.getBytes();
            // in this case the compressed stream is actually larger - need to allocate enough space
            final ByteArrayOutputStream byteOutStream = new ByteArrayOutputStream(0);
            final GZIPOutputStream gzipOutStream = new GZIPOutputStream(byteOutStream);
            gzipOutStream.write(bodyBytes);
            gzipOutStream.finish();
            gzipOutStream.flush();
            body(byteOutStream.toByteArray());

            final HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request);
            assertEquals(body, IOUtils.toString(new GZIPInputStream(wrapper.getInputStream())));
        }

        @Test
        public void handlesZipRequestBody() throws IOException {

            final String body = "hello";
            final byte[] bodyBytes = body.getBytes();

            final ByteArrayOutputStream byteOutStream = new ByteArrayOutputStream(0);
            ZipOutputStream zOutput = new ZipOutputStream(byteOutStream);

            zOutput.putNextEntry(new ZipEntry("f1"));
            zOutput.write(bodyBytes);
            zOutput.finish();
            zOutput.flush();
            body(byteOutStream.toByteArray());

            final HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request);

            assertEquals(body, readZipInputStream(wrapper.getInputStream()));

        }

        public String readZipInputStream(InputStream input) throws IOException {

            byte[] uploadedBytes = getBytesFromInputStream(input);
            input.close();

            /* try to read it as a zip file */
            String uploadFileTxt = null;
            ZipInputStream zInput = new ZipInputStream(new ByteArrayInputStream(uploadedBytes));
            ZipEntry zipEntry = zInput.getNextEntry();
            if (zipEntry != null) {
                // we have a ZipEntry, so this is a zip file
                while (zipEntry != null) {
                    byte[] fileBytes = getBytesFromInputStream(zInput);
                    uploadFileTxt = new String(fileBytes);

                    zipEntry = zInput.getNextEntry();
                }
            }
            return uploadFileTxt;
        }

        private byte[] getBytesFromInputStream(InputStream input) throws IOException {
            int v = 0;
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            while ((v = input.read()) != -1) {
                bos.write(v);
            }
            bos.close();
            return bos.toByteArray();
        }

    }

}