nl.b3p.viewer.util.IPAuthenticationFilter.java Source code

Java tutorial

Introduction

Here is the source code for nl.b3p.viewer.util.IPAuthenticationFilter.java

Source

/*
 * Copyright (C) 2016 B3Partners B.V.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package nl.b3p.viewer.util;

import java.io.IOException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.persistence.EntityManager;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpSession;
import nl.b3p.viewer.config.security.User;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hibernate.Hibernate;
import org.stripesstuff.stripersist.Stripersist;

/**
 * Servlet filter that authenticates otherwise anonymous requests by the
 * client's IP address.
 * <p>
 * When the incoming request carries no user principal, the filter loads all
 * {@link User}s and looks for those whose configured IP addresses (literal, or
 * containing {@code *} wildcards per dotted segment) match the client address.
 * If exactly one user matches, the request passed down the chain is wrapped so
 * that {@code getUserPrincipal()}, {@code getRemoteUser()} and
 * {@code isUserInRole()} answer on behalf of that user. The outcome is cached
 * in the HTTP session for {@link #MAX_TIME_USER_CACHE} milliseconds.
 *
 * @author Meine Toonen meinetoonen@b3partners.nl
 */
public class IPAuthenticationFilter implements Filter {

    private static final Log log = LogFactory.getLog(IPAuthenticationFilter.class);

    // The filter configuration object we are associated with. If this value
    // is null, this filter instance is not currently configured.
    private FilterConfig filterConfig = null;

    /** Session key: the client address used for the most recent lookup. */
    private static final String IP_CHECK = IPAuthenticationFilter.class + "_IP_CHECK";
    /** Session key: the user that was matched by IP address. */
    private static final String USER_CHECK = IPAuthenticationFilter.class + "_USER_CHECK";
    /** Session key: System.currentTimeMillis() at the moment a user was matched. */
    private static final String TIME_USER_CHECKED = IPAuthenticationFilter.class + "_TIME_USER_CHECKED";

    /** Maximum age, in milliseconds, of a cached user before it is looked up again. */
    private static final int MAX_TIME_USER_CACHE = 20000;

    public IPAuthenticationFilter() {
    }

    /**
     * Passes authenticated requests through untouched; for anonymous requests
     * tries to resolve a user by IP address and exposes that user through a
     * wrapped request.
     *
     * @param r The servlet request we are processing
     * @param response The servlet response we are creating
     * @param chain The filter chain we are processing
     *
     * @exception IOException if an input/output error occurs
     * @exception ServletException if a servlet error occurs
     */
    @Override
    public void doFilter(ServletRequest r, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) r;
        HttpSession session = request.getSession();

        if (request.getUserPrincipal() != null) {
            // Already authenticated by other means: nothing to do here.
            chain.doFilter(request, response);
            return;
        }

        final User user = resolveUser(request, session);

        RequestWrapper wrappedRequest = new RequestWrapper(request) {
            @Override
            public Principal getUserPrincipal() {
                return user != null ? user : super.getUserPrincipal();
            }

            @Override
            public String getRemoteUser() {
                return user != null ? user.getName() : super.getRemoteUser();
            }

            @Override
            public boolean isUserInRole(String role) {
                return user != null ? user.checkRole(role) : super.isUserInRole(role);
            }
        };

        try {
            chain.doFilter(wrappedRequest, response);
        } catch (IOException | ServletException e) {
            // Log the actual exception (with stack trace) before propagating it.
            log.error("Error processing chain", e);
            throw e;
        }
    }

    /**
     * Determines the user belonging to the client's IP address, using the
     * session as a short-lived cache.
     *
     * @param request the current request, used to obtain the client address
     * @param session the session holding the cached lookup result
     * @return the single user matching the client address, or {@code null}
     * when no user or more than one user matches
     */
    private User resolveUser(HttpServletRequest request, HttpSession session) {
        boolean neverChecked = session.getAttribute(IP_CHECK) == null
                && session.getAttribute(USER_CHECK) == null;
        if (!neverChecked && !isCacheExpired(session)) {
            return (User) session.getAttribute(USER_CHECK);
        }

        String ipAddress = getIp(request);
        session.setAttribute(IP_CHECK, ipAddress);

        User matched = null;
        Stripersist.requestInit();
        try {
            EntityManager em = Stripersist.getEntityManager();
            List<User> users = em.createQuery("from User", User.class).getResultList();
            List<User> possibleUsers = new ArrayList<User>();

            for (User candidate : users) {
                if (checkValidIpAddress(ipAddress, candidate)) {
                    possibleUsers.add(candidate);
                }
            }

            if (possibleUsers.isEmpty()) {
                log.debug("No eligible users found for ip");
            } else if (possibleUsers.size() == 1) {
                matched = possibleUsers.get(0);
                matched.setAuthenticatedByIp(true);
                // Load the groups now: the user is kept in the session and
                // used after the Stripersist request has been completed.
                Hibernate.initialize(matched.getGroups());
                session.setAttribute(USER_CHECK, matched);
                session.setAttribute(TIME_USER_CHECKED, System.currentTimeMillis());
            } else {
                // An ambiguous match is treated the same as no match.
                log.debug("Too many eligible users found for ip.");
            }
        } finally {
            // Always complete the Stripersist request, also when the lookup fails.
            Stripersist.requestComplete();
        }
        return matched;
    }

    /**
     * Return the filter configuration object for this filter.
     * @return filter configuration object
     */
    public FilterConfig getFilterConfig() {
        return (this.filterConfig);
    }

    /**
     * Set the filter configuration object for this filter.
     *
     * @param filterConfig The filter configuration object
     */
    public void setFilterConfig(FilterConfig filterConfig) {
        this.filterConfig = filterConfig;
    }

    /**
     * Destroy method for this filter
     */
    @Override
    public void destroy() {
    }

    /**
     * Init method for this filter.
     *
     * @param filterConfig filter configuration object
     */
    @Override
    public void init(FilterConfig filterConfig) {
        this.filterConfig = filterConfig;
    }

    /**
     * Return a String representation of this object.
     */
    @Override
    public String toString() {
        if (filterConfig == null) {
            return ("IPAuthenticationFilter()");
        }
        StringBuilder sb = new StringBuilder("IPAuthenticationFilter(");
        sb.append(filterConfig);
        sb.append(")");
        return (sb.toString());

    }

    /**
     * Checks the client address against the IP addresses configured for a
     * user. Configured addresses may contain an asterisk to denote a range.
     *
     * @param remoteAddress the client address as resolved by getIp
     * @param user the user whose configured addresses are checked
     * @return {@code true} when at least one configured address matches
     */
    private boolean checkValidIpAddress(String remoteAddress, User user) {

        for (String ipAddress : (Set<String>) user.getIps()) {

            log.debug("Check ip: " + ipAddress + " against: " + remoteAddress);

            if (ipAddress.contains("*") && isRemoteAddressWithinIpRange(ipAddress, remoteAddress)) {
                return true;
            }

            if (ipAddress.equalsIgnoreCase(remoteAddress)) {
                return true;
            }
        }
        log.info("IP address " + remoteAddress + " not allowed for user " + user.getName());

        return false;
    }

    /**
     * Resolves the client address: the first entry of the
     * {@code X-Forwarded-For} header when present and non-blank, otherwise the
     * remote address of the connection.
     * <p>
     * NOTE(review): {@code X-Forwarded-For} is supplied by the client unless a
     * trusted proxy overwrites it; since this value drives authentication,
     * confirm that the deployment strips or overwrites the header.
     *
     * @param request the current request
     * @return the address to match against the users' configured addresses
     */
    private String getIp(HttpServletRequest request) {
        String forwardedFor = request.getHeader("X-Forwarded-For");
        if (forwardedFor != null) {
            int comma = forwardedFor.indexOf(',');
            String first = (comma >= 0 ? forwardedFor.substring(0, comma) : forwardedFor).trim();
            if (!first.isEmpty()) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Matches a remote address against a configured address that contains an
     * asterisk, e.g. {@code 10.0.0.*}. Both values are split on dots; every
     * segment of the configured address that is not an asterisk must equal the
     * segment at the same position in the remote address. Segments of the
     * remote address beyond the length of the configured address are ignored.
     * <p>
     * This function should only be called when ip contains an asterisk.
     *
     * @param ip the configured address with wildcard segments
     * @param remote the client address
     * @return {@code true} when all non-wildcard segments match; {@code false}
     * when either argument is {@code null}, a segment differs, or the remote
     * address has no segment at the position of a non-wildcard segment
     */
    protected boolean isRemoteAddressWithinIpRange(String ip, String remote) {
        if (ip == null || remote == null) {
            return false;
        }

        String[] arrIp = ip.split("\\.");
        String[] arrRemote = remote.split("\\.");

        if (arrIp.length < 1 || arrRemote.length < 1) {
            return false;
        }

        /* Check that every non-asterisk segment equals the corresponding
         segment of the remote address. */
        for (int i = 0; i < arrIp.length; i++) {
            if ("*".equals(arrIp[i])) {
                continue;
            }
            // A remote address with fewer segments (e.g. an IPv6 or malformed
            // address) cannot match; guard against reading past its end.
            if (i >= arrRemote.length || !arrIp[i].equalsIgnoreCase(arrRemote[i])) {
                return false;
            }
        }

        return true;
    }

    /**
     * Tells whether the user cached in the session must be looked up again.
     *
     * @param session the session holding the cache timestamp
     * @return {@code true} when there is no session, no timestamp, or the
     * timestamp is older than {@link #MAX_TIME_USER_CACHE} milliseconds
     */
    private boolean isCacheExpired(HttpSession session) {
        if (session == null) {
            return true;
        }
        Object checkedAt = session.getAttribute(TIME_USER_CHECKED);
        if (!(checkedAt instanceof Long)) {
            return true;
        }
        long age = System.currentTimeMillis() - (Long) checkedAt;
        return age > MAX_TIME_USER_CACHE;
    }

    /**
     * This request wrapper class extends the support class
     * HttpServletRequestWrapper, which implements all the methods in the
     * HttpServletRequest interface, as delegations to the wrapped request. You
     * only need to override the methods that you need to change. You can get
     * access to the wrapped request using the method getRequest()
     * <p>
     * In addition it allows request parameters to be overridden through
     * {@link #setParameter(String, String[])}; until that method is called
     * for the first time all parameter lookups go to the wrapped request.
     */
    class RequestWrapper extends HttpServletRequestWrapper {

        // Local copy of the parameters; null as long as setParameter has not
        // been called.
        protected Hashtable<String, String[]> localParams = null;

        public RequestWrapper(HttpServletRequest request) {
            super(request);
        }

        /**
         * Sets (or replaces) a request parameter on this wrapper. On first use
         * the parameters of the wrapped request are copied into a local table.
         *
         * @param name the parameter name
         * @param values the parameter values
         */
        public void setParameter(String name, String[] values) {
            if (localParams == null) {
                localParams = new Hashtable<String, String[]>();
                // Copy the parameters from the underlying request.
                localParams.putAll(getRequest().getParameterMap());
            }
            localParams.put(name, values);
        }

        @Override
        public String getParameter(String name) {
            if (localParams == null) {
                return getRequest().getParameter(name);
            }
            String[] values = localParams.get(name);
            // First value, or null when the parameter is absent or has no values.
            return (values == null || values.length == 0) ? null : values[0];
        }

        @Override
        public String[] getParameterValues(String name) {
            if (localParams == null) {
                return getRequest().getParameterValues(name);
            }
            return localParams.get(name);
        }

        @Override
        public Enumeration<String> getParameterNames() {
            if (localParams == null) {
                return getRequest().getParameterNames();
            }
            return localParams.keys();
        }

        @Override
        public Map<String, String[]> getParameterMap() {
            if (localParams == null) {
                return getRequest().getParameterMap();
            }
            return localParams;
        }
    }
}