co.cask.cdap.gateway.router.handlers.SecurityAuthenticationHttpHandler.java Source code

Java tutorial

Introduction

Here is the source code for co.cask.cdap.gateway.router.handlers.SecurityAuthenticationHttpHandler.java

Source

/*
 * Copyright © 2014 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.gateway.router.handlers;

import co.cask.cdap.common.conf.CConfiguration;
import co.cask.cdap.common.conf.Constants;
import co.cask.cdap.common.logging.AuditLogEntry;
import co.cask.cdap.security.auth.AccessTokenTransformer;
import co.cask.cdap.security.auth.TokenState;
import co.cask.cdap.security.auth.TokenValidator;
import co.cask.cdap.security.server.GrantAccessToken;
import com.google.common.base.Charsets;
import com.google.common.base.Stopwatch;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.ning.org.jboss.netty.handler.codec.http.HttpConstants;
import org.apache.twill.discovery.Discoverable;
import org.apache.twill.discovery.DiscoveryServiceClient;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBufferIndexFinder;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.WriteCompletionEvent;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;

/**
 * Security handler that intercepts HTTP messages and validates the access token in
 * the Authorization header field.
 *
 * <p>Requests whose URI matches the configured bypass pattern are passed through without
 * validation. Requests with a valid token are forwarded upstream with rewritten identity
 * headers; all other requests are answered directly with a 401 response. The handler also
 * collects an {@link AuditLogEntry} per request and writes it to the "http-access" logger.
 */
public class SecurityAuthenticationHttpHandler extends SimpleChannelHandler {
    private static final Logger LOG = LoggerFactory.getLogger(SecurityAuthenticationHttpHandler.class);
    private static final Logger AUDIT_LOG = LoggerFactory.getLogger("http-access");

    private final TokenValidator tokenValidator;
    private final AccessTokenTransformer accessTokenTransformer;
    private final DiscoveryServiceClient discoveryServiceClient;
    private final Iterable<Discoverable> discoverables;
    private final CConfiguration configuration;
    private final String realm;
    // null when no bypass regular expression is configured
    private final Pattern bypassPattern;

    /**
     * Creates the handler.
     *
     * @param realm realm name reported in the WWW-Authenticate header of 401 responses
     * @param tokenValidator validator used to check the access token of each request
     * @param configuration configuration used to read the bypass regex and the auth server protocol/port
     * @param accessTokenTransformer transformer applied to valid tokens before the request is forwarded
     * @param discoveryServiceClient client used to discover the external authentication service
     * @throws IllegalArgumentException if the configured bypass regular expression is invalid
     */
    public SecurityAuthenticationHttpHandler(String realm, TokenValidator tokenValidator,
            CConfiguration configuration, AccessTokenTransformer accessTokenTransformer,
            DiscoveryServiceClient discoveryServiceClient) {
        this.realm = realm;
        this.tokenValidator = tokenValidator;
        this.accessTokenTransformer = accessTokenTransformer;
        this.discoveryServiceClient = discoveryServiceClient;
        this.discoverables = discoveryServiceClient.discover(Constants.Service.EXTERNAL_AUTHENTICATION);
        this.configuration = configuration;
        this.bypassPattern = createMatcher(
                configuration.get(Constants.Security.Router.BYPASS_AUTHENTICATION_REGEX));
    }

    /**
     * Compiles the bypass regular expression.
     *
     * @param s the regular expression, or {@code null} if none is configured
     * @return the compiled pattern, or {@code null} if {@code s} is {@code null}
     * @throws IllegalArgumentException if {@code s} is not a valid regular expression
     */
    private Pattern createMatcher(String s) {
        if (s == null) {
            return null;
        }
        try {
            return Pattern.compile(s);
        } catch (PatternSyntaxException e) {
            throw new IllegalArgumentException(String.format("Invalid regular expression for %s",
                    Constants.Security.Router.BYPASS_AUTHENTICATION_REGEX), e);
        }
    }

    /**
     * Checks whether the request may skip authentication.
     *
     * @param req the incoming request
     * @return {@code true} if a bypass pattern is configured and the whole request URI matches it
     */
    private boolean matchBypassPattern(HttpRequest req) {
        return bypassPattern != null && bypassPattern.matcher(req.getUri()).matches();
    }

    /**
     * Intercepts the HTTP request and validates the access token in its Authorization header.
     *
     * <p>If the token is valid, the Authorization header is rewritten and the user id / user IP
     * headers are set on the request. Otherwise a 401 response with a JSON body listing the
     * discovered authentication server URIs is written and the channel is closed once the
     * write completes.
     *
     * @param ctx channel handler context delegated from the messageReceived callback
     * @param msg intercepted HTTP request
     * @param inboundChannel channel on which the 401 response is written
     * @param logEntry audit log entry to populate with client IP, request line, response and user details
     * @return {@code true} if the HTTP message has a valid access token
     * @throws Exception if token handling or waiting for service discovery fails
     */
    private boolean validateSecuredInterception(ChannelHandlerContext ctx, HttpRequest msg, Channel inboundChannel,
            AuditLogEntry logEntry) throws Exception {
        String auth = msg.getHeader(HttpHeaders.Names.AUTHORIZATION);
        String accessToken = null;

        /*
         * Parse the access token from authorization header.  The header will be in the form:
         *     Authorization: Bearer ACCESSTOKEN
         *
         * where ACCESSTOKEN is the base64 encoded serialized AccessToken instance.
         */
        if (auth != null) {
            // Trim once so that the space index and the substring refer to the same string;
            // otherwise leading whitespace would shift the token boundary.
            String trimmedAuth = auth.trim();
            int spIndex = trimmedAuth.indexOf(' ');
            if (spIndex != -1) {
                accessToken = trimmedAuth.substring(spIndex + 1).trim();
            }
        }

        logEntry.setClientIP(((InetSocketAddress) ctx.getChannel().getRemoteAddress()).getAddress());
        logEntry.setRequestLine(msg.getMethod(), msg.getUri(), msg.getProtocolVersion());

        TokenState tokenState = tokenValidator.validate(accessToken);
        if (!tokenState.isValid()) {
            HttpResponse httpResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1,
                    HttpResponseStatus.UNAUTHORIZED);
            logEntry.setResponseCode(HttpResponseStatus.UNAUTHORIZED.getCode());

            JsonObject jsonObject = new JsonObject();
            if (tokenState == TokenState.MISSING) {
                httpResponse.addHeader(HttpHeaders.Names.WWW_AUTHENTICATE,
                        String.format("Bearer realm=\"%s\"", realm));
                LOG.debug("Authentication failed due to missing token");

            } else {
                httpResponse.addHeader(HttpHeaders.Names.WWW_AUTHENTICATE,
                        String.format("Bearer realm=\"%s\" error=\"invalid_token\"" + " error_description=\"%s\"",
                                realm, tokenState.getMsg()));
                jsonObject.addProperty("error", "invalid_token");
                jsonObject.addProperty("error_description", tokenState.getMsg());
                LOG.debug("Authentication failed due to invalid token, reason={};", tokenState);
            }
            JsonArray externalAuthenticationURIs = new JsonArray();

            // Wait (bounded) for the authentication service to be discovered
            stopWatchWait(externalAuthenticationURIs);

            jsonObject.add("auth_uri", externalAuthenticationURIs);

            ChannelBuffer content = ChannelBuffers.wrappedBuffer(jsonObject.toString().getBytes(Charsets.UTF_8));
            httpResponse.setContent(content);
            int contentLength = content.readableBytes();
            httpResponse.setHeader(HttpHeaders.Names.CONTENT_LENGTH, contentLength);
            httpResponse.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/json;charset=UTF-8");
            logEntry.setResponseContentLength(Long.valueOf(contentLength));
            ChannelFuture writeFuture = Channels.future(inboundChannel);
            Channels.write(ctx, writeFuture, httpResponse);
            // Close the connection once the 401 response has been written
            writeFuture.addListener(ChannelFutureListener.CLOSE);
            return false;
        } else {
            AccessTokenTransformer.AccessTokenIdentifierPair accessTokenIdentifierPair = accessTokenTransformer
                    .transform(accessToken);
            String userName = accessTokenIdentifierPair.getAccessTokenIdentifierObj().getUsername();
            logEntry.setUserName(userName);
            msg.setHeader(HttpHeaders.Names.AUTHORIZATION,
                    "CDAP-verified " + accessTokenIdentifierPair.getAccessTokenIdentifierStr());
            msg.setHeader(Constants.Security.Headers.USER_ID, userName);
            msg.setHeader(Constants.Security.Headers.USER_IP,
                    ((InetSocketAddress) ctx.getChannel().getRemoteAddress()).getAddress().getHostAddress());
            return true;
        }
    }

    /**
     * Populates the given array with the token URLs of the discovered external authentication
     * servers. If none is discovered yet, polls every 200 milliseconds and gives up once the
     * stopwatch reports 2 elapsed seconds, in which case the array is left unchanged.
     *
     * <p>Note that this sleeps on the calling thread while waiting.
     *
     * @param externalAuthenticationURIs the array that should be populated with the discovered
     *                                   external auth server URIs
     * @throws Exception if the thread is interrupted while waiting or the configuration cannot be read
     */
    private void stopWatchWait(JsonArray externalAuthenticationURIs) throws Exception {
        boolean done = false;
        Stopwatch stopwatch = new Stopwatch();
        stopwatch.start();
        String protocol;
        int port;
        if (configuration.getBoolean(Constants.Security.SSL_ENABLED)) {
            protocol = "https";
            port = configuration.getInt(Constants.Security.AuthenticationServer.SSL_PORT);
        } else {
            protocol = "http";
            port = configuration.getInt(Constants.Security.AUTH_SERVER_BIND_PORT);
        }

        do {
            // The host comes from discovery, while the port comes from the configuration
            for (Discoverable d : discoverables) {
                String url = String.format("%s://%s:%d/%s", protocol, d.getSocketAddress().getHostName(), port,
                        GrantAccessToken.Paths.GET_TOKEN);
                externalAuthenticationURIs.add(new JsonPrimitive(url));
                done = true;
            }
            if (!done) {
                TimeUnit.MILLISECONDS.sleep(200);
            }
        } while (!done && stopwatch.elapsedTime(TimeUnit.SECONDS) < 2L);
    }

    /**
     * Handles inbound messages. HTTP requests get a fresh audit log entry attached to the context
     * and are forwarded upstream only if they match the bypass pattern or carry a valid token;
     * any other message type is delegated to the superclass.
     *
     * @param ctx the channel handler context
     * @param event the received message event
     * @throws Exception if validation or forwarding fails
     */
    @Override
    public void messageReceived(ChannelHandlerContext ctx, final MessageEvent event) throws Exception {
        Object msg = event.getMessage();
        if (!(msg instanceof HttpRequest)) {
            super.messageReceived(ctx, event);
        } else {
            AuditLogEntry logEntry = new AuditLogEntry();
            ctx.setAttachment(logEntry);
            HttpRequest req = (HttpRequest) msg;
            if (matchBypassPattern(req) || validateSecuredInterception(ctx, req, event.getChannel(), logEntry)) {
                Channels.fireMessageReceived(ctx, msg, event.getRemoteAddress());
            }
            // we write the response directly for authentication failure, so nothing to do for else
        }
    }

    /**
     * Inspects outbound messages to record the response code and content length in the audit
     * log entry, then passes the write on to the superclass.
     *
     * @param ctx the channel handler context
     * @param e the outbound message event
     * @throws Exception if the superclass fails to handle the write
     */
    @Override
    public void writeRequested(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        AuditLogEntry logEntry = getLogEntry(ctx);
        Object message = e.getMessage();
        if (message instanceof HttpResponse) {
            HttpResponse response = (HttpResponse) message;
            logEntry.setResponseCode(response.getStatus().getCode());
            if (response.containsHeader(HttpHeaders.Names.CONTENT_LENGTH)) {
                String lengthString = response.getHeader(HttpHeaders.Names.CONTENT_LENGTH);
                try {
                    logEntry.setResponseContentLength(Long.valueOf(lengthString));
                } catch (NumberFormatException nfe) {
                    LOG.warn("Invalid value for content length in HTTP response message: {}", lengthString, nfe);
                }
            }
        } else if (message instanceof ChannelBuffer) {
            // for chunked responses the response code will only be present on the first chunk
            // so we only look for it the first time around
            if (logEntry.getResponseCode() == null) {
                ChannelBuffer channelBuffer = (ChannelBuffer) message;
                logEntry.setResponseCode(findResponseCode(channelBuffer));
                if (logEntry.getResponseCode() != null) {
                    // we currently only look for a Content-Length header in the first buffer on an HTTP response
                    // this is a limitation of the implementation that simplifies header parsing
                    logEntry.setResponseContentLength(findContentLength(channelBuffer));
                }
            }
        } else {
            LOG.debug("Unhandled response message type: {}", message.getClass());
        }
        super.writeRequested(ctx, e);
    }

    /**
     * Writes the audit log entry at trace level the first time a write completes for it.
     *
     * @param ctx the channel handler context
     * @param e the write completion event
     * @throws Exception declared by the overridden method
     */
    @Override
    public void writeComplete(ChannelHandlerContext ctx, WriteCompletionEvent e) throws Exception {
        // NOTE(review): this override does not call super.writeComplete(ctx, e), so the event is
        // not passed on to other handlers from here - confirm that this is intended.
        AuditLogEntry logEntry = getLogEntry(ctx);
        if (!logEntry.isLogged()) {
            AUDIT_LOG.trace(logEntry.toString());
            logEntry.setLogged(true);
        }
    }

    /**
     * Extracts the HTTP response code from a raw response buffer.
     *
     * @param buffer buffer expected to start with an HTTP status line
     * @return the parsed response code, or {@code null} if it cannot be located or parsed
     */
    private Integer findResponseCode(ChannelBuffer buffer) {
        Integer responseCode = null;

        // we assume that the response code should follow the first space in the first line of the response
        int indx = buffer.indexOf(buffer.readerIndex(), buffer.writerIndex(),
                ChannelBufferIndexFinder.LINEAR_WHITESPACE);
        if (indx >= 0 && indx < buffer.writerIndex() - 4) {
            // read the whitespace plus the three digits that follow it, then trim the whitespace off
            String codeString = buffer.toString(indx, 4, Charsets.UTF_8).trim();
            try {
                responseCode = Integer.valueOf(codeString);
            } catch (NumberFormatException nfe) {
                LOG.warn("Invalid value for HTTP response code: {}", codeString, nfe);
            }
        } else {
            LOG.debug("Invalid index for space in response: index={}, buffer size={}", indx,
                    buffer.readableBytes());
        }
        return responseCode;
    }

    /**
     * Looks for a Content-Length header in a raw response buffer and parses its value.
     *
     * @param buffer buffer to search between its reader and writer index
     * @return the content length, or {@code null} if the header is absent or its value is not a number
     */
    private Long findContentLength(ChannelBuffer buffer) {
        Long contentLength = null;
        int bufferEnd = buffer.writerIndex();
        int index = buffer.indexOf(buffer.readerIndex(), bufferEnd, CONTENT_LENGTH_FINDER);
        if (index >= 0) {
            // find the following ':'
            int colonIndex = buffer.indexOf(index, bufferEnd, HttpConstants.COLON);
            int eolIndex = buffer.indexOf(index, bufferEnd, ChannelBufferIndexFinder.CRLF);
            // a missing end of line yields a negative eolIndex, which fails this check
            if (colonIndex > 0 && colonIndex < eolIndex) {
                String lengthString = buffer.toString(colonIndex + 1, eolIndex - (colonIndex + 1), Charsets.UTF_8)
                        .trim();
                try {
                    contentLength = Long.valueOf(lengthString);
                } catch (NumberFormatException nfe) {
                    LOG.warn("Invalid value for content length in HTTP response message: {}", lengthString, nfe);
                }
            }
        }
        return contentLength;
    }

    /**
     * Returns the audit log entry attached to the context, creating and attaching a new one if
     * the attachment is missing or of a different type.
     *
     * @param ctx the channel handler context
     * @return the audit log entry attached to {@code ctx}
     */
    private AuditLogEntry getLogEntry(ChannelHandlerContext ctx) {
        Object entryObject = ctx.getAttachment();
        AuditLogEntry logEntry;
        // instanceof is false for null, so no separate null check is needed
        if (entryObject instanceof AuditLogEntry) {
            logEntry = (AuditLogEntry) entryObject;
        } else {
            logEntry = new AuditLogEntry();
            ctx.setAttachment(logEntry);
        }
        return logEntry;
    }

    /**
     * Index finder that matches the position where the bytes of the Content-Length header name
     * start. The comparison is byte-exact (case sensitive).
     */
    private static final ChannelBufferIndexFinder CONTENT_LENGTH_FINDER = new ChannelBufferIndexFinder() {
        private final byte[] headerName = HttpHeaders.Names.CONTENT_LENGTH.getBytes(Charsets.UTF_8);

        @Override
        public boolean find(ChannelBuffer buffer, int guessedIndex) {
            // only bytes before the writer index hold written data; do not match beyond it
            if (buffer.writerIndex() - guessedIndex < headerName.length) {
                return false;
            }

            for (int i = 0; i < headerName.length; i++) {
                if (headerName[i] != buffer.getByte(guessedIndex + i)) {
                    return false;
                }
            }
            return true;
        }
    };
}