com.streamsets.pipeline.lib.http.HttpReceiverServlet.java Source code

Java tutorial

Introduction

Here is the source code for com.streamsets.pipeline.lib.http.HttpReceiverServlet.java

Source

/*
 * Copyright 2017 StreamSets Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.streamsets.pipeline.lib.http;

import com.codahale.metrics.Meter;
import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
import com.streamsets.pipeline.api.Stage;
import com.streamsets.pipeline.api.StageException;
import com.streamsets.pipeline.api.impl.Utils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.iq80.snappy.SnappyFramedInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPInputStream;

@SuppressWarnings({ "squid:S2226", "squid:S1989", "squid:S1948" })
public class HttpReceiverServlet extends HttpServlet {
    private static final Logger LOG = LoggerFactory.getLogger(HttpReceiverServlet.class);

    private final HttpReceiver receiver;
    private final BlockingQueue<Exception> errorQueue;
    private final Meter invalidRequestMeter;
    protected final Meter errorRequestMeter;
    protected final Meter requestMeter;
    private final Timer requestTimer;
    private volatile boolean shuttingDown;

    public HttpReceiverServlet(Stage.Context context, HttpReceiver receiver, BlockingQueue<Exception> errorQueue) {
        this.receiver = receiver;
        this.errorQueue = errorQueue;
        invalidRequestMeter = context.createMeter("invalidRequests");
        errorRequestMeter = context.createMeter("errorRequests");
        requestMeter = context.createMeter("requests");
        requestTimer = context.createTimer("requests");
    }

    protected HttpReceiver getReceiver() {
        return receiver;
    }

    // From https://stackoverflow.com/a/31928740/33905
    @VisibleForTesting
    protected static Map<String, String[]> getQueryParameters(HttpServletRequest request) {
        Map<String, String[]> queryParameters = new HashMap<>();
        String queryString = request.getQueryString();

        if (StringUtils.isEmpty(queryString)) {
            return queryParameters;
        }

        String[] parameters = queryString.split("&");

        for (String parameter : parameters) {
            String[] keyValuePair = parameter.split("=");
            String[] values = queryParameters.get(keyValuePair[0]);
            values = ArrayUtils.add(values, keyValuePair.length == 1 ? "" : keyValuePair[1]); //length is one if no value is available.
            queryParameters.put(keyValuePair[0], values);
        }
        return queryParameters;
    }

    @VisibleForTesting
    protected boolean validateAppId(HttpServletRequest req, HttpServletResponse res)
            throws ServletException, IOException {
        boolean valid = false;
        String ourAppId = null;
        try {
            ourAppId = getReceiver().getAppId().get();
        } catch (StageException e) {
            throw new IOException("Cant resolve credential value", e);
        }
        String requestor = req.getRemoteAddr() + ":" + req.getRemotePort();
        String reqAppId = req.getHeader(HttpConstants.X_SDC_APPLICATION_ID_HEADER);

        if (reqAppId == null && receiver.isAppIdViaQueryParamAllowed()) {
            reqAppId = getQueryParameters(req).get(HttpConstants.SDC_APPLICATION_ID_QUERY_PARAM)[0];
        }

        if (reqAppId == null) {
            LOG.warn("Request from '{}' missing appId, rejected", requestor);
            res.sendError(HttpServletResponse.SC_FORBIDDEN, "Missing 'appId'");
        } else if (!ourAppId.equals(reqAppId)) {
            LOG.warn("Request from '{}' invalid appId '{}', rejected", requestor, reqAppId);
            res.sendError(HttpServletResponse.SC_FORBIDDEN, "Invalid 'appId'");
        } else {
            valid = true;
        }
        return valid;
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException {
        if (validateAppId(req, res)) {
            LOG.debug("Validation from '{}', OK", req.getRemoteAddr());
            res.setHeader(HttpConstants.X_SDC_PING_HEADER, HttpConstants.X_SDC_PING_VALUE);
            res.setStatus(HttpServletResponse.SC_OK);
        }
    }

    @VisibleForTesting
    boolean validatePostRequest(HttpServletRequest req, HttpServletResponse res)
            throws ServletException, IOException {
        boolean valid = false;
        if (validateAppId(req, res)) {
            String compression = req.getHeader(HttpConstants.X_SDC_COMPRESSION_HEADER);
            if (compression == null) {
                valid = true;
            } else {
                switch (compression) {
                case HttpConstants.SNAPPY_COMPRESSION:
                    valid = true;
                    break;
                default:
                    String requestor = req.getRemoteAddr() + ":" + req.getRemotePort();
                    LOG.warn("Invalid compression '{}' in request from '{}', returning error", compression,
                            requestor);
                    res.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE,
                            "Unsupported compression: " + compression);
                    break;
                }
            }
        }
        return valid && getReceiver().validate(req, res);
    }

    @Override
    protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String requestor = req.getRemoteAddr() + ":" + req.getRemotePort();
        if (isShuttingDown()) {
            LOG.debug("Shutting down, discarding incoming request from '{}'", requestor);
            resp.setStatus(HttpServletResponse.SC_GONE);
        } else {
            if (validatePostRequest(req, resp)) {
                long start = System.currentTimeMillis();
                LOG.debug("Request accepted from '{}'", requestor);
                try (InputStream in = req.getInputStream()) {
                    InputStream is = in;
                    String compression = req.getHeader(HttpConstants.X_SDC_COMPRESSION_HEADER);
                    if (compression == null) {
                        compression = req.getHeader(HttpConstants.CONTENT_ENCODING_HEADER);
                    }
                    if (compression != null) {
                        switch (compression) {
                        case HttpConstants.SNAPPY_COMPRESSION:
                            is = new SnappyFramedInputStream(is, true);
                            break;
                        case HttpConstants.GZIP_COMPRESSION:
                            is = new GZIPInputStream(is);
                            break;
                        default:
                            throw new IOException(
                                    Utils.format("It shouldn't happen, unexpected compression '{}'", compression));
                        }
                    }
                    LOG.debug("Processing request from '{}'", requestor);
                    processRequest(req, is, resp);
                } catch (Exception ex) {
                    errorQueue.offer(ex);
                    errorRequestMeter.mark();
                    LOG.warn("Error while processing request payload from '{}': {}", requestor, ex.toString(), ex);
                    resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, ex.toString());
                } finally {
                    requestTimer.update(System.currentTimeMillis() - start, TimeUnit.MILLISECONDS);
                }

            } else {
                invalidRequestMeter.mark();
            }
        }
    }

    protected void processRequest(HttpServletRequest req, InputStream is, HttpServletResponse resp)
            throws IOException {
        if (getReceiver().process(req, is, resp)) {
            resp.setStatus(HttpServletResponse.SC_OK);
            requestMeter.mark();
        } else {
            resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Record(s) didn't reach all destinations");
            errorRequestMeter.mark();
        }
    }

    @VisibleForTesting
    boolean isShuttingDown() {
        return shuttingDown;
    }

    public void setShuttingDown() {
        shuttingDown = true;
    }

}