io.lavagna.web.security.CSFRFilter.java Source code

Java tutorial

Introduction

Here is the source code for io.lavagna.web.security.CSFRFilter.java

Source

/**
 * This file is part of lavagna.
 *
 * lavagna is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * lavagna is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with lavagna.  If not, see <http://www.gnu.org/licenses/>.
 */
package io.lavagna.web.security;

import static org.apache.commons.lang3.tuple.ImmutablePair.of;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
import java.util.regex.Pattern;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Filter implementing CSRF protection with a per-session synchronizer token.
 * <p>
 * On every request a random token is ensured in the HTTP session (attribute {@code CSRFToken.CSRF_TOKEN}) and is
 * returned to the client in the {@code X-CSRF-TOKEN} response header. Requests whose method is not GET, HEAD or
 * OPTIONS must send the same token back, either in the {@code X-CSRF-TOKEN} request header or in the {@code _csrf}
 * request parameter. The only exception is POST requests to the websocket fallback path ({@code /api/socket/...}).
 * Requests that fail the check are rejected with {@code 403 Forbidden} and are not passed down the chain.
 */
public class CSFRFilter extends AbstractBaseFilter {

    private static final String CSRF_TOKEN_HEADER = "X-CSRF-TOKEN";
    private static final String CSRF_FORM_PARAMETER = "_csrf";

    // Methods exempt from the check. The non-capturing group makes both anchors apply to every alternative;
    // without it the pattern would read ("^GET" | "HEAD" | "OPTIONS$").
    private static final Pattern CSRF_METHOD_DONT_CHECK = Pattern.compile("^(?:GET|HEAD|OPTIONS)$");

    // Websocket fallback path, relative to the context path; POSTs to it are not checked.
    private static final Pattern WEBSOCKET_FALLBACK = Pattern.compile("^/api/socket/.*$");

    private static final Logger LOG = LogManager.getLogger();

    @Override
    protected void doFilterInternal(HttpServletRequest req, HttpServletResponse resp, FilterChain chain)
            throws IOException, ServletException {

        // Make sure the session owns a token. Expose it so the client can echo it back on state-changing requests.
        HttpSession session = req.getSession();
        String token = (String) session.getAttribute(CSRFToken.CSRF_TOKEN);
        if (token == null) {
            token = UUID.randomUUID().toString();
            session.setAttribute(CSRFToken.CSRF_TOKEN, token);
        }
        resp.setHeader(CSRF_TOKEN_HEADER, token);

        if (mustCheckCSRF(req)) {
            ImmutablePair<Boolean, ImmutablePair<Integer, String>> res = checkCSRF(req);
            if (!res.left) {
                // Log the reason and target only; never log the token values.
                LOG.info("wrong csrf for {} {}: {}", req.getMethod(), req.getRequestURI(), res.right.right);
                resp.sendError(res.right.left, res.right.right);
                return;
            }
        }

        // Token check passed or not required: continue with the rest of the chain.
        chain.doFilter(req, resp);
    }

    /**
     * Tells whether the request must carry a valid CSRF token.
     *
     * @param request
     *            the incoming request
     * @return {@code false} for GET, HEAD and OPTIONS requests and for POSTs to the websocket fallback endpoint;
     *         {@code true} otherwise
     */
    private boolean mustCheckCSRF(HttpServletRequest request) {

        // The websocket fallback transport posts without the token, so it is exempted explicitly.
        if ("POST".equals(request.getMethod()) && WEBSOCKET_FALLBACK
                .matcher(StringUtils.removeStart(request.getRequestURI(), request.getContextPath())).matches()) {
            return false;
        }

        return !CSRF_METHOD_DONT_CHECK.matcher(request.getMethod()).matches();
    }

    /**
     * Compares the token sent by the client (header first, then form parameter) with the one stored in the session.
     *
     * @param request
     *            the incoming request
     * @return {@code (true, null)} when the tokens match. Otherwise {@code (false, (403, reason))}, where the reason
     *         explains whether the client token was missing, the session token was missing, or they differed.
     */
    private static ImmutablePair<Boolean, ImmutablePair<Integer, String>> checkCSRF(HttpServletRequest request) {
        String expectedToken = (String) request.getSession().getAttribute(CSRFToken.CSRF_TOKEN);
        String token = request.getHeader(CSRF_TOKEN_HEADER);
        if (token == null) {
            token = request.getParameter(CSRF_FORM_PARAMETER);
        }

        if (token == null) {
            return of(false, of(HttpServletResponse.SC_FORBIDDEN, "missing token in header or parameter"));
        }
        if (expectedToken == null) {
            return of(false, of(HttpServletResponse.SC_FORBIDDEN, "missing token from session"));
        }
        // Constant-time comparison, so response timing does not reveal how much of the token was guessed.
        if (!safeArrayEquals(token.getBytes(StandardCharsets.UTF_8), expectedToken.getBytes(StandardCharsets.UTF_8))) {
            return of(false, of(HttpServletResponse.SC_FORBIDDEN, "token is not equal to expected"));
        }

        return of(true, null);
    }

    // ------------------------------------------------------------------------
    // this function has been imported from KeyCzar.

    /*
     * Copyright 2008 Google Inc.
     * 
     * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
     * with the License. You may obtain a copy of the License at
     * 
     * http://www.apache.org/licenses/LICENSE-2.0
     * 
     * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
     * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for
     * the specific language governing permissions and limitations under the License.
     */

    /**
     * An array comparison that is safe from timing attacks. If two arrays are of equal length, this code will always
     * check all elements, rather than exiting once it encounters a differing byte.
     * 
     * @param a1
     *            An array to compare
     * @param a2
     *            Another array to compare
     * @return True if these arrays are both null or if they have equal length and equal bytes in all elements
     */
    private static boolean safeArrayEquals(byte[] a1, byte[] a2) {
        if (a1 == null || a2 == null) {
            return a1 == a2;
        }
        if (a1.length != a2.length) {
            return false;
        }
        byte result = 0;
        for (int i = 0; i < a1.length; i++) {
            // Any differing bit leaves a 1 in the accumulator; no early exit.
            result |= a1[i] ^ a2[i];
        }
        return result == 0;
    }
}