io.apiman.gateway.platforms.servlet.connectors.HttpApiConnection.java Source code

Java tutorial

Introduction

Here is the source code for io.apiman.gateway.platforms.servlet.connectors.HttpApiConnection.java

Source

/*
 * Copyright 2014 JBoss Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.apiman.gateway.platforms.servlet.connectors;

import io.apiman.common.config.options.BasicAuthOptions;
import io.apiman.gateway.engine.IApiConnection;
import io.apiman.gateway.engine.IApiConnectionResponse;
import io.apiman.gateway.engine.async.AsyncResultImpl;
import io.apiman.gateway.engine.async.IAsyncHandler;
import io.apiman.gateway.engine.async.IAsyncResultHandler;
import io.apiman.gateway.engine.auth.RequiredAuthType;
import io.apiman.gateway.engine.beans.Api;
import io.apiman.gateway.engine.beans.ApiRequest;
import io.apiman.gateway.engine.beans.ApiResponse;
import io.apiman.gateway.engine.beans.exceptions.ConnectorException;
import io.apiman.gateway.engine.io.ByteBuffer;
import io.apiman.gateway.engine.io.IApimanBuffer;
import io.apiman.gateway.platforms.servlet.GatewayThreadContext;
import io.apiman.gateway.platforms.servlet.connectors.ok.OkUrlFactory;
import io.apiman.gateway.platforms.servlet.connectors.ssl.SSLSessionStrategy;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.io.IOUtils;

import com.squareup.okhttp.OkHttpClient;

/**
 * Models a live connection to a back end API.
 *
 * @author eric.wittmann@redhat.com
 */
public class HttpApiConnection implements IApiConnection, IApiConnectionResponse {

    private static final Set<String> SUPPRESSED_REQUEST_HEADERS = new HashSet<>();
    private static final Set<String> SUPPRESSED_RESPONSE_HEADERS = new HashSet<>();
    static {
        SUPPRESSED_REQUEST_HEADERS.add("Transfer-Encoding"); //$NON-NLS-1$
        SUPPRESSED_REQUEST_HEADERS.add("Content-Length"); //$NON-NLS-1$
        SUPPRESSED_REQUEST_HEADERS.add("X-API-Key"); //$NON-NLS-1$
        SUPPRESSED_REQUEST_HEADERS.add("Host"); //$NON-NLS-1$

        SUPPRESSED_RESPONSE_HEADERS.add("OkHttp-Received-Millis"); //$NON-NLS-1$
        SUPPRESSED_RESPONSE_HEADERS.add("OkHttp-Response-Source"); //$NON-NLS-1$
        SUPPRESSED_RESPONSE_HEADERS.add("OkHttp-Selected-Protocol"); //$NON-NLS-1$
        SUPPRESSED_RESPONSE_HEADERS.add("OkHttp-Sent-Millis"); //$NON-NLS-1$
    }

    private ApiRequest request;
    private Api api;
    private RequiredAuthType requiredAuthType;
    private SSLSessionStrategy sslStrategy;
    private IAsyncResultHandler<IApiConnectionResponse> responseHandler;

    private boolean connected;
    private HttpURLConnection connection;
    private OutputStream outputStream;

    private IAsyncHandler<IApimanBuffer> bodyHandler;
    private IAsyncHandler<Void> endHandler;

    private ApiResponse response;
    final private OkHttpClient client;

    /**
     * Constructor.
     *
     * @param client the http client to use
     * @param sslStrategy the SSL strategy
     * @param request the request
     * @param api the API
     * @param requiredAuthType the authorization type
     * @param handler the result handler
     * @throws ConnectorException when unable to connect
     */
    public HttpApiConnection(OkHttpClient client, ApiRequest request, Api api, RequiredAuthType requiredAuthType,
            SSLSessionStrategy sslStrategy, IAsyncResultHandler<IApiConnectionResponse> handler)
            throws ConnectorException {
        this.client = client;
        this.request = request;
        this.api = api;
        this.requiredAuthType = requiredAuthType;
        this.sslStrategy = sslStrategy;
        this.responseHandler = handler;

        try {
            connect();
        } catch (Exception e) {
            handler.handle(AsyncResultImpl.<IApiConnectionResponse>create(e));
        }
    }

    /**
     * Connects to the back end system.
     */
    private void connect() throws ConnectorException {
        try {
            Set<String> suppressedHeaders = new HashSet<>(SUPPRESSED_REQUEST_HEADERS);

            String endpoint = api.getEndpoint();
            if (endpoint.endsWith("/")) { //$NON-NLS-1$
                endpoint = endpoint.substring(0, endpoint.length() - 1);
            }
            if (request.getDestination() != null) {
                endpoint += request.getDestination();
            }
            if (request.getQueryParams() != null && !request.getQueryParams().isEmpty()) {
                String delim = "?"; //$NON-NLS-1$
                for (Entry<String, String> entry : request.getQueryParams()) {
                    endpoint += delim + entry.getKey();
                    if (entry.getValue() != null) {
                        endpoint += "=" + URLEncoder.encode(entry.getValue(), "UTF-8"); //$NON-NLS-1$ //$NON-NLS-2$
                    }
                    delim = "&"; //$NON-NLS-1$
                }
            }
            URL url = new URL(endpoint);
            OkUrlFactory factory = new OkUrlFactory(client);
            connection = factory.open(url);

            boolean isSsl = connection instanceof HttpsURLConnection;

            if (requiredAuthType == RequiredAuthType.MTLS && !isSsl) {
                throw new ConnectorException(
                        "Mutually authenticating TLS requested, but insecure endpoint protocol was indicated."); //$NON-NLS-1$
            }

            if (requiredAuthType == RequiredAuthType.BASIC) {
                BasicAuthOptions options = new BasicAuthOptions(api.getEndpointProperties());
                if (options.getUsername() != null && options.getPassword() != null) {
                    if (options.isRequireSSL() && !isSsl) {
                        throw new ConnectorException(
                                "Endpoint security requested (BASIC auth) but endpoint is not secure (SSL)."); //$NON-NLS-1$
                    }

                    String up = options.getUsername() + ':' + options.getPassword();
                    StringBuilder builder = new StringBuilder();
                    builder.append("Basic "); //$NON-NLS-1$
                    builder.append(Base64.encodeBase64String(up.getBytes()));
                    connection.setRequestProperty("Authorization", builder.toString()); //$NON-NLS-1$
                    suppressedHeaders.add("Authorization"); //$NON-NLS-1$
                }
            }

            if (isSsl) {
                HttpsURLConnection https = (HttpsURLConnection) connection;
                SSLSocketFactory socketFactory = sslStrategy.getSocketFactory();
                https.setSSLSocketFactory(socketFactory);
                https.setHostnameVerifier(sslStrategy.getHostnameVerifier());
            }

            setConnectTimeout(connection);
            setReadTimeout(connection);
            if (request.getType().equalsIgnoreCase("PUT") || request.getType().equalsIgnoreCase("POST")) { //$NON-NLS-1$ //$NON-NLS-2$
                connection.setDoOutput(true);
            } else {
                connection.setDoOutput(false);
            }
            connection.setDoInput(true);
            connection.setUseCaches(false);
            connection.setRequestMethod(request.getType());

            // Set the request headers
            for (Entry<String, String> entry : request.getHeaders()) {
                String hname = entry.getKey();
                String hval = entry.getValue();
                if (!suppressedHeaders.contains(hname)) {
                    connection.setRequestProperty(hname, hval);
                }
            }

            // Set or reset mandatory headers
            connection.setRequestProperty("Host", url.getHost() + determinePort(url)); //$NON-NLS-1$
            connection.connect();
            connected = true;
        } catch (IOException e) {
            throw new ConnectorException(e);
        }
    }

    /**
     * If the endpoint properties includes a connect timeout override, then 
     * set it here.
     * @param connection
     */
    private void setConnectTimeout(HttpURLConnection connection) {
        try {
            Map<String, String> endpointProperties = this.api.getEndpointProperties();
            if (endpointProperties.containsKey("timeouts.connect")) { //$NON-NLS-1$
                int connectTimeoutMs = new Integer(endpointProperties.get("timeouts.connect")); //$NON-NLS-1$
                connection.setConnectTimeout(connectTimeoutMs);
            }
        } catch (Throwable t) {
        }
    }

    /**
     * If the endpoint properties includes a read timeout override, then 
     * set it here.
     * @param connection
     */
    private void setReadTimeout(HttpURLConnection connection) {
        try {
            Map<String, String> endpointProperties = this.api.getEndpointProperties();
            if (endpointProperties.containsKey("timeouts.read")) { //$NON-NLS-1$
                int connectTimeoutMs = new Integer(endpointProperties.get("timeouts.read")); //$NON-NLS-1$
                connection.setReadTimeout(connectTimeoutMs);
            }
        } catch (Throwable t) {
        }
    }

    /**
     * Extracts the port information fromthe given URL.
     * @param url a URL
     * @return the port configured in the URL, or empty string if no port specified
     */
    private String determinePort(URL url) {
        return (url.getPort() == -1) ? "" : ":" + url.getPort(); //$NON-NLS-1$ //$NON-NLS-2$
    }

    /**
     * @see io.apiman.gateway.engine.io.IReadStream#bodyHandler(io.apiman.gateway.engine.async.IAsyncHandler)
     */
    @Override
    public void bodyHandler(IAsyncHandler<IApimanBuffer> bodyHandler) {
        this.bodyHandler = bodyHandler;
    }

    /**
     * @see io.apiman.gateway.engine.io.IReadStream#endHandler(io.apiman.gateway.engine.async.IAsyncHandler)
     */
    @Override
    public void endHandler(IAsyncHandler<Void> endHandler) {
        this.endHandler = endHandler;
    }

    /**
     * @see io.apiman.gateway.engine.io.IReadStream#getHead()
     */
    @Override
    public ApiResponse getHead() {
        return response;
    }

    /**
     * @see io.apiman.gateway.engine.io.IStream#isFinished()
     */
    @Override
    public boolean isFinished() {
        return !connected;
    }

    /**
     * @see io.apiman.gateway.engine.IApiConnection#isConnected()
     */
    @Override
    public boolean isConnected() {
        return connected;
    }

    /**
     * @see io.apiman.gateway.engine.io.IAbortable#abort()
     */
    @Override
    public void abort() {
        try {
            if (!connected) {
                throw new IOException("Not connected."); //$NON-NLS-1$
            }
            if (connection != null) {
                try {
                    IOUtils.closeQuietly(outputStream);
                    IOUtils.closeQuietly(connection.getInputStream());
                } catch (Exception e) {
                }
                try {
                    connected = false;
                    connection.disconnect();
                } catch (Exception e) {
                }
            }
        } catch (IOException e) {
            // TODO log this error but don't rethrow it
        }
    }

    /**
     * @see io.apiman.gateway.engine.io.IWriteStream#write(io.apiman.gateway.engine.io.IApimanBuffer)
     */
    @Override
    public void write(IApimanBuffer chunk) {
        try {
            if (!connected) {
                throw new IOException("Not connected."); //$NON-NLS-1$
            }
            if (outputStream == null) {
                outputStream = connection.getOutputStream();
            }
            if (chunk instanceof ByteBuffer) {
                byte[] buffer = (byte[]) chunk.getNativeBuffer();
                outputStream.write(buffer, 0, chunk.length());
            } else {
                outputStream.write(chunk.getBytes());
            }
        } catch (IOException e) {
            // TODO log this error.
            throw new ConnectorException(e);
        }
    }

    /**
     * @see io.apiman.gateway.engine.io.IWriteStream#end()
     */
    @Override
    public void end() {
        try {
            if (!connected) {
                throw new IOException("Not connected."); //$NON-NLS-1$
            }
            IOUtils.closeQuietly(outputStream);
            outputStream = null;
            // Process the response, convert to an ApiResponse object, and return it
            response = GatewayThreadContext.getApiResponse();
            Map<String, List<String>> headerFields = connection.getHeaderFields();
            for (String headerName : headerFields.keySet()) {
                if (headerName != null && !SUPPRESSED_RESPONSE_HEADERS.contains(headerName)) {
                    response.getHeaders().add(headerName, connection.getHeaderField(headerName));
                }
            }
            response.setCode(connection.getResponseCode());
            response.setMessage(connection.getResponseMessage());
            responseHandler.handle(AsyncResultImpl.<IApiConnectionResponse>create(this));
        } catch (Exception e) {
            // TODO log this error
            throw new ConnectorException(e);
        }
    }

    /**
     * @see io.apiman.gateway.engine.io.ISignalReadStream#transmit()
     */
    @Override
    public void transmit() {
        try {
            if (!connected) {
                throw new IOException("Not connected."); //$NON-NLS-1$
            }
            InputStream is = connection.getInputStream();
            ByteBuffer buffer = new ByteBuffer(2048);
            int numBytes = buffer.readFrom(is);
            while (numBytes != -1) {
                bodyHandler.handle(buffer);
                numBytes = buffer.readFrom(is);
            }
            IOUtils.closeQuietly(is);
            connection.disconnect();
            connected = false;
            endHandler.handle(null);
        } catch (Throwable e) {
            // At this point we're sort of screwed, because we've already sent the response to
            // the originating client - and we're in the process of sending the body data. So
            // I guess the only thing to do is abort() the connection and cross our fingers.
            if (connected) {
                abort();
            }
            throw new RuntimeException(e);
        }
    }

}