edu.lternet.pasta.gatekeeper.GatekeeperFilter.java Source code

Java tutorial

Introduction

Here is the source code for edu.lternet.pasta.gatekeeper.GatekeeperFilter.java

Source

/*
 *
 * $Date$ $Author$ $Revision$
 *
 * Copyright 2010 the University of New Mexico.
 *
 * This work was supported by National Science Foundation Cooperative Agreements
 * #DEB-0832652 and #DEB-0936498.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at http://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package edu.lternet.pasta.gatekeeper;

import java.io.*;
import java.security.*;
import java.security.cert.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.HttpHeaders;

import org.apache.log4j.Logger;

import org.apache.commons.codec.binary.Base64;

import edu.lternet.pasta.common.security.access.UnauthorizedException;
import edu.lternet.pasta.common.security.auth.AuthSystemDef;
import edu.lternet.pasta.common.security.auth.KnbAuthSystem;
import edu.lternet.pasta.common.security.auth.SymmetricEncrypter;
import edu.lternet.pasta.common.security.token.AuthToken;
import edu.lternet.pasta.common.security.token.AuthTokenFactory;
import edu.lternet.pasta.common.security.token.AuthTokenWithPassword;
import edu.lternet.pasta.common.security.token.BasicAuthToken;

/**
 * <p>
 * The Gatekeeper web service handles all authentication from incoming requests.
 * </p>
 *
 * <p>
 * If the user submits only BASIC authentication credentials, the credentials
 * are checked against the LTER LDAP server and, on success, a signed token is
 * generated and returned to the client as a cookie on the response.
 * </p>
 *
 * <p>
 * If the user submits a token, the token will be used provided its signature
 * verifies and it has not exceeded its time to live. A token whose signature
 * does not verify is rejected with HTTP 400 (Bad Request); an expired token is
 * rejected with HTTP 401 (Unauthorized).
 * </p>
 *
 * <p>
 * If no credentials or tokens are submitted, a token for the special user
 * public will be created and the remainder of the query will be done as
 * public. No token cookie is returned to the client in this case.
 * </p>
 *
 * <p>
 * In every successful case the downstream request sees an unsigned "internal"
 * token cookie, and the BASIC {@code Authorization} header is hidden from it.
 * </p>
 *
 * @webservicename Gatekeeper
 * @baseurl https://pasta.lternet.edu/
 */
public final class GatekeeperFilter implements Filter {

    private static final Logger logger = Logger.getLogger(GatekeeperFilter.class);
    private FilterConfig filterConfig;
    private static final int BAD_REQUEST_CODE = 400;
    private static final int UNAUTHORIZED_CODE = 401;

    /* Replacement written to the diagnostic log for credential headers. */
    private static final String MASKED_HEADER_VALUE = "********";

    /*
     * EXTERNAL cookies are returned to the client and carry a digital
     * signature appended to the token string; INTERNAL cookies are handed to
     * the downstream request unsigned.
     */
    private enum CookieUse {
        EXTERNAL, INTERNAL
    }

    /**
     * Overridden init method that sets the filterConfig.
     */
    @Override
    public void init(FilterConfig config) throws ServletException {
        filterConfig = config;
    }

    /**
     * Overridden destroy method that frees the filterConfig.
     */
    @Override
    public void destroy() {
        filterConfig = null;
    }

    /**
     * Authenticates the request and, on success, passes a wrapped request
     * carrying an internal auth-token cookie down the filter chain.
     * <p>
     * Authentication failures are answered directly: an
     * {@link IllegalStateException} produces HTTP 400, while an
     * {@link UnauthorizedException} or {@link IllegalArgumentException}
     * produces HTTP 401. Only the authentication step is guarded this way;
     * exceptions raised further down the chain propagate unchanged instead of
     * being misreported as authentication errors.
     *
     * @param request  ServletRequest representing the incoming user http(s)
     *                 request.
     * @param response ServletResponse representing the associated response
     *                 that will eventually be passed on to the next servlet.
     * @param chain    the remainder of the filter chain.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {

        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse res = (HttpServletResponse) response;

        // Output HttpServletRequest diagnostic information
        logger.info("Request URL: " + req.getMethod() + " - " + req.getRequestURL().toString());
        doDiagnostics(req);

        Cookie internalCookie;
        try {
            internalCookie = hasAuthToken(req.getCookies()) ? doCookie(req) : doHeader(req, res);
        } catch (IllegalStateException e) {
            res.setStatus(BAD_REQUEST_CODE);
            PrintWriter out = res.getWriter();
            out.println(e);
            return;
        } catch (UnauthorizedException e) {
            res.setStatus(UNAUTHORIZED_CODE);
            PrintWriter out = res.getWriter();
            out.println(e.getMessage());
            return;
        } catch (IllegalArgumentException e) {
            res.setStatus(UNAUTHORIZED_CODE);
            PrintWriter out = res.getWriter();
            out.println(e.getMessage());
            return;
        }

        // Deliberately outside the try: downstream failures are not auth errors.
        chain.doFilter(new PastaRequestWrapper(req, internalCookie), res);
    }

    /*
     * Processes an incoming authentication-token cookie whose value has the
     * form "<token>-<base64 signature>". Returns the unsigned internal cookie
     * for the downstream request.
     *
     * Throws IllegalStateException when the cookie is missing, malformed, or
     * its signature does not verify; UnauthorizedException when the token has
     * expired.
     */
    private Cookie doCookie(HttpServletRequest req)
            throws IllegalArgumentException, IllegalStateException, UnauthorizedException {

        String authTokenStr = retrieveAuthTokenString(req.getCookies());
        if (authTokenStr == null) {
            String gripe = "Authentication token not found!";
            throw new IllegalStateException(gripe);
        }

        // Split on the LAST '-': the standard Base64 alphabet used for the
        // signature never contains '-'. A value lacking either part is
        // rejected instead of failing with ArrayIndexOutOfBoundsException.
        int separator = authTokenStr.lastIndexOf('-');
        if (separator <= 0 || separator == authTokenStr.length() - 1) {
            String gripe = "Authentication token is not valid!";
            throw new IllegalStateException(gripe);
        }
        String authToken = authTokenStr.substring(0, separator);
        byte[] signature = Base64.decodeBase64(authTokenStr.substring(separator + 1));

        if (!isValidSignature(authToken, signature)) {
            String gripe = "Authentication token is not valid!";
            throw new IllegalStateException(gripe);
        }

        AuthToken token = AuthTokenFactory.makeCookieAuthToken(authToken);
        assertTimeToLive(token);

        return makeAuthTokenCookie(token, CookieUse.INTERNAL);
    }

    /*
     * Processes an incoming BASIC authentication header, or its absence (the
     * "public" user). When a header is present, a signed external token
     * cookie is added to the response. Returns the unsigned internal cookie
     * for the downstream request.
     */
    private Cookie doHeader(HttpServletRequest req, HttpServletResponse res) {
        String rawHeader = req.getHeader(HttpHeaders.AUTHORIZATION);
        boolean publicUser = (rawHeader == null || rawHeader.isEmpty());
        AuthToken authToken = makeAuthenticated(rawHeader);

        // Only return authToken (in cookie) if real user
        if (!publicUser) {
            Cookie externalCookie = makeAuthTokenCookie(authToken, CookieUse.EXTERNAL);
            res.addCookie(externalCookie);
        }

        return makeAuthTokenCookie(authToken, CookieUse.INTERNAL);
    }

    /*
     * Throws UnauthorizedException if the token is null or its expiration
     * time (milliseconds since the epoch) is not in the future.
     */
    private void assertTimeToLive(AuthToken attrlist) throws UnauthorizedException {

        if (attrlist == null) {
            String s = "Token not found.";
            throw new UnauthorizedException(s);
        }
        long ttl = attrlist.getExpirationDate() - System.currentTimeMillis();
        if (ttl < 1) {
            String s = "Token has expired.";
            throw new UnauthorizedException(s);
        }
    }

    /* True when the request carries a cookie named after the configured token name. */
    private boolean hasAuthToken(Cookie[] cookies) {
        return retrieveAuthTokenString(cookies) != null;
    }

    /*
     * Decrypts a symmetric-encrypted token string into an AuthToken.
     * Throws IllegalStateException (with the decryption failure as its cause)
     * for a null, empty, or undecryptable token.
     */
    private AuthToken decryptToken(String tokenStr) throws IllegalStateException {

        String errorMsg = "Invalid AuthToken Submitted.";

        if (tokenStr == null || tokenStr.isEmpty()) {
            throw new IllegalStateException(errorMsg);
        }

        String decrypted;
        try {
            decrypted = SymmetricEncrypter.decrypt(tokenStr, ConfigurationListener.getPrivateKey());
        } catch (IllegalArgumentException e) {
            throw new IllegalStateException(errorMsg, e);
        }

        return AuthTokenFactory.makeCookieAuthToken(decrypted);
    }

    /*
     * Returns the value of the first cookie named after the configured token
     * name, or null when there are no cookies or none matches.
     */
    private String retrieveAuthTokenString(Cookie[] cookies) {

        if (cookies == null) {
            return null;
        }
        String tokenName = ConfigurationListener.getTokenName();
        for (Cookie c : cookies) {
            if (c.getName().equals(tokenName)) {
                return c.getValue();
            }
        }
        return null;
    }

    /*
     * Builds an AuthToken from a BASIC authentication header. A null or empty
     * header is treated as the public user. Any other user is authenticated
     * against the LTER LDAP server and given the configured auth group.
     *
     * Throws UnauthorizedException when LDAP authentication fails.
     */
    private AuthToken makeAuthenticated(String rawHeader) {

        String publicUserName = ConfigurationListener.getPublicUser();
        String tmpHeader;
        if (rawHeader == null || rawHeader.isEmpty()) {
            tmpHeader = BasicAuthToken.makeTokenString(publicUserName, publicUserName);
        } else {
            tmpHeader = rawHeader;
        }

        KnbAuthSystem knb = new KnbAuthSystem(ConfigurationListener.getLdapKeyStore());

        AuthTokenWithPassword basicToken = AuthTokenFactory.makeAuthTokenWithPassword(tmpHeader);
        String user = basicToken.getUserId();
        String password = basicToken.getPassword();

        Set<String> groups = new HashSet<String>();
        if (!user.equals(publicUserName)) {

            if (!knb.authenticate(user, password)) {
                String s = "The user '" + user + "' could not be authenticated " + "using the LTER LDAP server.";
                throw new UnauthorizedException(s); // Handle this better
            }
            // Groups are not looked up in LDAP (knb.getGroups is not used);
            // every authenticated user receives the configured auth group.
            groups.add(ConfigurationListener.getAuthGroup());
        }
        AuthSystemDef authSystem = knb.getAuthSystemDef();
        long expirationDate = System.currentTimeMillis() + ConfigurationListener.getTokenTtl();

        return AuthTokenFactory.makeCookieAuthToken(user, authSystem, expirationDate, groups);
    }

    /*
     * Wraps the token in a cookie named after the configured token name. For
     * EXTERNAL use the value is "<token>-<base64 signature>". The cookie's
     * max age is the time remaining until the token expires.
     */
    private Cookie makeAuthTokenCookie(AuthToken attrlist, CookieUse use) {

        String cookieValue = attrlist.getTokenString();

        if (use == CookieUse.EXTERNAL) {
            // Generate digital signature and add to token string
            byte[] signature = generateSignature(cookieValue);
            cookieValue = cookieValue + "-" + Base64.encodeBase64String(signature);
        }

        logger.debug("Cookie value: " + cookieValue);

        Cookie c = new Cookie(ConfigurationListener.getTokenName(), cookieValue);
        // setMaxAge takes seconds from now, not an absolute timestamp.
        c.setMaxAge(secondsUntil(attrlist.getExpirationDate()));
        return c;
    }

    /*
     * Converts an absolute expiration time (milliseconds since the epoch)
     * into the relative number of seconds expected by Cookie.setMaxAge,
     * clamped to [0, Integer.MAX_VALUE]. Zero tells the client to discard
     * the cookie.
     */
    private static int secondsUntil(long expirationMillis) {
        long remainingSeconds = (expirationMillis - System.currentTimeMillis()) / 1000L;
        if (remainingSeconds < 0L) {
            return 0;
        }
        if (remainingSeconds > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        return (int) remainingSeconds;
    }

    /**
     * dumpHeader iterates through all request headers and lists both the header
     * name and its contents to the designated logger.
     *
     * @param req
     *          the HttpServletRequest object.
     * @param noAuthPeek
     *          when true, the Authorization header value is masked. Header
     *          names are matched case-insensitively.
     * @return contentLength
     *          the content length that was specified in the
     *          request headers; null when absent or not a valid integer
     */
    private Integer dumpHeader(HttpServletRequest req, boolean noAuthPeek) {
        Integer contentLength = null;

        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Header: %n"));

        Enumeration<String> headerNames = req.getHeaderNames();
        // The container may refuse header access by returning null.
        if (headerNames != null) {
            while (headerNames.hasMoreElements()) {

                String headerName = headerNames.nextElement();
                String header = req.getHeader(headerName);

                if (noAuthPeek && HttpHeaders.AUTHORIZATION.equalsIgnoreCase(headerName)) {
                    header = MASKED_HEADER_VALUE;
                }

                if ("Content-Length".equalsIgnoreCase(headerName)) {
                    contentLength = parseContentLength(header);
                }

                sb.append(String.format("     %s: %s%n", headerName, header));
            }
        }

        logger.info(sb.toString());
        return contentLength;
    }

    /*
     * Parses a Content-Length header value, returning null (and logging a
     * warning) instead of throwing when it is missing or not an integer.
     */
    private static Integer parseContentLength(String value) {
        if (value == null) {
            return null;
        }
        try {
            return Integer.valueOf(value.trim());
        } catch (NumberFormatException e) {
            logger.warn("Ignoring malformed Content-Length header: " + value);
            return null;
        }
    }

    /**
     * dumpBody outputs the contents of the request message body to the
     * designated logger. Note that the use of this method will render the
     * request object inoperable for any subsequent calls, because the
     * request reader is closed.
     *
     * @param req
     *          the HttpServletRequest object.
     * @param contentLength
     *          the content length that was specified in the
     *          request headers, possibly null (in which case nothing is logged)
     */
    private void dumpBody(HttpServletRequest req, Integer contentLength) {

        if (contentLength == null) {
            return;
        }

        try {
            BufferedReader br = req.getReader();

            logger.info("Request message body:\n");

            if (br.markSupported()) {
                br.mark(contentLength + 1);

                String line;
                while ((line = br.readLine()) != null) {
                    logger.info(line);
                }

                br.reset();
            }

            br.close();

        } catch (IOException e) {
            logger.error("dumpBody: " + e, e);
        }
    }

    /*
     * Logs the remote address, request URI and (masked) request headers.
     */
    private void doDiagnostics(HttpServletRequest req) {

        String remoteAddr = req.getRemoteAddr();
        logger.info("Remote address: " + remoteAddr);

        String requestUri = req.getRequestURI();
        logger.info("Request URI: " + requestUri);

        // Never write BASIC credentials to the log.
        boolean noAuthPeek = true;

        Integer contentLength = dumpHeader(req, noAuthPeek);
        // Body dumping stays disabled: dumpBody closes the request reader,
        // leaving the body unusable for downstream servlets.
        //dumpBody(req, contentLength);
    }

    /*
     * Generates an MD5withRSA digital signature for tokenString using the
     * private key stored under the configured LTER keystore alias, and returns
     * the raw signature bytes (the caller Base64-encodes them). Returns null,
     * after logging, if the keystore cannot be read or signing fails.
     *
     * NOTE(review): MD5 is a weak digest; changing the algorithm must be done
     * together with isValidSignature and invalidates outstanding tokens.
     */
    private byte[] generateSignature(String tokenString) {

        byte[] signature = null;

        File ksFile = ConfigurationListener.getLterKeyStore();
        String ksType = ConfigurationListener.getLterKeyStoreType();
        String ksAlias = ConfigurationListener.getLterKeyStoreAlias();
        char[] storePass = ConfigurationListener.getLterStorePasswd().toCharArray();
        char[] keyPass = ConfigurationListener.getLterKeyPasswd().toCharArray();

        InputStream ksIn = null;
        try {
            KeyStore ks = KeyStore.getInstance(ksType);
            ksIn = new BufferedInputStream(new FileInputStream(ksFile));

            ks.load(ksIn, storePass);
            PrivateKey priv = (PrivateKey) ks.getKey(ksAlias, keyPass);

            Signature rsa = Signature.getInstance("MD5withRSA");
            rsa.initSign(priv);

            rsa.update(tokenString.getBytes());
            signature = rsa.sign();

        } catch (Exception e) {
            // Best effort: a failed signature yields null and, later, a token
            // that fails verification, rather than an error here.
            logger.error(e.getMessage(), e);
        } finally {
            closeQuietly(ksIn);
            // The password arrays are private copies; wipe them after use.
            Arrays.fill(storePass, '\0');
            Arrays.fill(keyPass, '\0');
        }

        return signature;
    }

    /*
     * Writes the raw signature bytes to a file named after the token string
     * inside the configured signature directory. Failures are logged, not
     * thrown. Not called from within this class.
     */
    private void writeSignature(String tokenString, byte[] signature) {

        String signatureDir = ConfigurationListener.getSignatureDir();
        String signatureFile = signatureDir + tokenString;

        OutputStream sigOut = null;
        try {
            sigOut = new FileOutputStream(signatureFile);
            sigOut.write(signature);
            sigOut.close();
            sigOut = null; // closed successfully; nothing left for finally
        } catch (IOException e) {
            logger.error("Gatekeeper.writeSignature: " + e.getMessage(), e);
        } finally {
            closeQuietly(sigOut);
        }
    }

    /*
     * Verifies an MD5withRSA signature of tokenString against the public key
     * in the configured LTER X.509 certificate. Any I/O or crypto failure is
     * logged and reported as an invalid signature.
     */
    private boolean isValidSignature(String tokenString, byte[] signature) {

        boolean isValid = false;

        File lterCert = ConfigurationListener.getLterCertificate();

        InputStream certIn = null;
        try {
            certIn = new FileInputStream(lterCert);
            CertificateFactory cf = CertificateFactory.getInstance("X.509");
            X509Certificate cert = (X509Certificate) cf.generateCertificate(certIn);

            PublicKey pubKey = cert.getPublicKey();

            Signature sig = Signature.getInstance("MD5withRSA");
            sig.initVerify(pubKey);

            sig.update(tokenString.getBytes());
            isValid = sig.verify(signature);

        } catch (FileNotFoundException e) {
            logger.error("Gatekeeper.validateSignature :" + e.getMessage(), e);
        } catch (GeneralSecurityException e) {
            // Covers CertificateException, NoSuchAlgorithmException,
            // InvalidKeyException and SignatureException (e.g. a garbled
            // signature supplied by the client).
            logger.error("Gatekeeper.validateSignature :" + e.getMessage(), e);
        } finally {
            closeQuietly(certIn);
        }

        return isValid;
    }

    /*
     * Closes the given stream if non-null, logging rather than propagating
     * any failure.
     */
    private static void closeQuietly(Closeable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (IOException e) {
            logger.warn("Unable to close stream: " + e.getMessage(), e);
        }
    }

    /**
     * Request wrapper handed to the rest of the filter chain. It hides the
     * {@code Authorization} header and exposes the internal auth-token cookie
     * through the cookie and header accessors. Header names are compared
     * case-insensitively, as HTTP requires.
     */
    public static class PastaRequestWrapper extends HttpServletRequestWrapper {

        private static final Logger logger = Logger.getLogger(PastaRequestWrapper.class);
        private Cookie cookie;

        public PastaRequestWrapper(HttpServletRequest request, Cookie cookie) {

            super(request);
            this.cookie = cookie;
        }

        /*
         * Returns null for Authorization. For Cookie, when the original
         * header is absent or empty, returns the internal cookie as
         * "name=value"; otherwise returns the original header.
         */
        @Override
        public String getHeader(String name) {

            if (HttpHeaders.AUTHORIZATION.equalsIgnoreCase(name)) {
                return null;
            }
            String header = super.getHeader(name);
            if (cookie != null && HttpHeaders.COOKIE.equalsIgnoreCase(name)
                    && (header == null || header.isEmpty())) {
                return cookie.getName() + "=" + cookie.getValue();
            }

            return header;
        }

        /*
         * Returns no values for Authorization, and appends the internal
         * cookie as "name=value" to the values of the Cookie header.
         */
        @Override
        public Enumeration<String> getHeaders(String name) {

            if (HttpHeaders.AUTHORIZATION.equalsIgnoreCase(name)) {
                return Collections.enumeration(new ArrayList<String>());
            }

            Enumeration<String> enumStr = super.getHeaders(name);

            if (cookie == null || !HttpHeaders.COOKIE.equalsIgnoreCase(name)) {
                return enumStr;
            }

            // The container may return null when header access is not allowed.
            List<String> list = (enumStr == null) ? new ArrayList<String>() : Collections.list(enumStr);
            list.add(cookie.getName() + "=" + cookie.getValue());
            return Collections.enumeration(list);
        }

        /* Adds "Cookie" to the header names when an internal cookie exists. */
        @Override
        public Enumeration<String> getHeaderNames() {

            Enumeration<String> enumStr = super.getHeaderNames();
            List<String> list = (enumStr == null) ? new ArrayList<String>() : Collections.list(enumStr);
            if (cookie != null && !containsIgnoreCase(list, HttpHeaders.COOKIE)) {
                list.add(HttpHeaders.COOKIE);
            }

            return Collections.enumeration(list);
        }

        /* Returns the original cookies with the internal cookie appended. */
        @Override
        public Cookie[] getCookies() {

            Cookie[] cookies = super.getCookies();
            if (cookie == null) {
                return cookies;
            }

            List<Cookie> list = (cookies == null) ? new ArrayList<Cookie>()
                    : new ArrayList<Cookie>(Arrays.asList(cookies));

            list.add(cookie);
            return list.toArray(new Cookie[list.size()]);
        }

        /* Case-insensitive membership test used for HTTP header names. */
        private static boolean containsIgnoreCase(List<String> values, String target) {
            for (String value : values) {
                if (target.equalsIgnoreCase(value)) {
                    return true;
                }
            }
            return false;
        }

    }

}