tech.beshu.ror.acl.blocks.rules.impl.JwtAuthSyncRule.java Source code

Java tutorial

Introduction

Here is the source code for tech.beshu.ror.acl.blocks.rules.impl.JwtAuthSyncRule.java

Source

/*
 *    This file is part of ReadonlyREST.
 *
 *    ReadonlyREST is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    ReadonlyREST is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with ReadonlyREST.  If not, see http://www.gnu.org/licenses/
 */

package tech.beshu.ror.acl.blocks.rules.impl;

import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.SignatureException;
import org.apache.commons.codec.binary.Base64;
import tech.beshu.ror.acl.blocks.rules.AsyncRule;
import tech.beshu.ror.acl.blocks.rules.RuleExitResult;
import tech.beshu.ror.acl.blocks.rules.phantomtypes.Authentication;
import tech.beshu.ror.acl.definitions.DefinitionsFactory;
import tech.beshu.ror.acl.definitions.externalauthenticationservices.ExternalAuthenticationServiceClient;
import tech.beshu.ror.commons.domain.LoggedUser;
import tech.beshu.ror.commons.shims.es.ESContext;
import tech.beshu.ror.commons.shims.es.LoggerShim;
import tech.beshu.ror.requestcontext.__old_RequestContext;
import tech.beshu.ror.settings.rules.JwtAuthRuleSettings;

import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

public class JwtAuthSyncRule extends AsyncRule implements Authentication {

    private final LoggerShim logger;
    private final JwtAuthRuleSettings settings;
    private final ExternalAuthenticationServiceClient httpClient;

    public JwtAuthSyncRule(JwtAuthRuleSettings settings, ESContext context, DefinitionsFactory factory) {
        this.logger = context.logger(getClass());
        this.settings = settings;
        this.httpClient = settings.getExternalValidator().isPresent() ? factory.getClient(settings) : null;
    }

    private static Optional<String> extractToken(String authHeader) {
        if (authHeader.startsWith("Bearer ")) {
            String token = authHeader.substring(7).trim();
            return Optional.ofNullable(Strings.emptyToNull(token));
        }

        return Optional.ofNullable(Strings.emptyToNull(authHeader));
    }

    @Override
    public CompletableFuture<RuleExitResult> match(__old_RequestContext rc) {
        Optional<String> token = Optional.of(rc.getHeaders()).map(m -> m.get(settings.getHeaderName()))
                .flatMap(JwtAuthSyncRule::extractToken);

        /*
          JWT ALGO    FAMILY
          =======================
          NONE        None
            
          HS256       HMAC
          HS384       HMAC
          HS512       HMAC
            
          RS256       RSA
          RS384       RSA
          RS512       RSA
          PS256       RSA
          PS384       RSA
          PS512       RSA
            
          ES256       EC
          ES384       EC
          ES512       EC
        */

        if (!token.isPresent()) {
            logger.debug("Authorization header is missing or does not contain a bearer token");
            return CompletableFuture.completedFuture(NO_MATCH);
        }

        try {

            JwtParser parser = Jwts.parser();

            // Defaulting to HMAC for backward compatibility
            String algoFamily = settings.getAlgo().map(String::toUpperCase).orElse("HMAC");
            if (settings.getKey() == null) {
                algoFamily = "NONE";
            }

            if (!"NONE".equals(algoFamily)) {

                if ("RSA".equals(algoFamily)) {
                    try {
                        byte[] keyBytes = Base64.decodeBase64(settings.getKey());
                        KeyFactory kf = KeyFactory.getInstance("RSA");
                        PublicKey pubKey = kf.generatePublic(new X509EncodedKeySpec(keyBytes));
                        parser.setSigningKey(pubKey);
                    } catch (GeneralSecurityException gso) {
                        throw new RuntimeException(gso);
                    }
                }

                else if ("EC".equals(algoFamily)) {
                    try {
                        byte[] keyBytes = Base64.decodeBase64(settings.getKey());
                        KeyFactory kf = KeyFactory.getInstance("EC");
                        PublicKey pubKey = kf.generatePublic(new X509EncodedKeySpec(keyBytes));
                        parser.setSigningKey(pubKey);
                    } catch (GeneralSecurityException gso) {
                        throw new RuntimeException(gso);
                    }
                }

                else if ("HMAC".equals(algoFamily)) {
                    parser.setSigningKey(settings.getKey());
                } else {
                    throw new RuntimeException("unrecognised algorithm family " + algoFamily
                            + ". Should be either of: HMAC, EC, RSA, NONE");
                }
            }

            Claims jws;
            if (settings.getKey() != null) {
                jws = parser.parseClaimsJws(token.get()).getBody();
            } else {
                String[] ar = token.get().split("\\.");
                if (ar.length < 2) {
                    // token is not a valid JWT
                    return CompletableFuture.completedFuture(NO_MATCH);
                }
                String tokenNoSig = ar[0] + "." + ar[1] + ".";
                jws = parser.parseClaimsJwt(tokenNoSig).getBody();
            }

            Claims finalJws = jws;
            Optional<String> user = settings.getUserClaim().map(claim -> finalJws.get(claim, String.class));
            if (settings.getUserClaim().isPresent())
                if (!user.isPresent()) {
                    return CompletableFuture.completedFuture(NO_MATCH);
                } else {
                    rc.setLoggedInUser(new LoggedUser(user.get()));
                }

            Optional<Set<String>> roles = this.extractRoles(jws);
            if (settings.getRolesClaim().isPresent() && !roles.isPresent()) {
                return CompletableFuture.completedFuture(NO_MATCH);
            }
            if (!settings.getRoles().isEmpty()) {
                if (!roles.isPresent()) {
                    return CompletableFuture.completedFuture(NO_MATCH);
                } else {
                    Set<String> r = roles.get();
                    if (r.isEmpty() || Sets.intersection(r, settings.getRoles()).isEmpty())
                        return CompletableFuture.completedFuture(NO_MATCH);
                }
            }

            if (settings.getExternalValidator().isPresent()) {
                return httpClient.authenticate("x", token.get()).thenApply(resp -> resp ? MATCH : NO_MATCH);
            }
            return CompletableFuture.completedFuture(MATCH);

        } catch (ExpiredJwtException | UnsupportedJwtException | MalformedJwtException | SignatureException e) {
            return CompletableFuture.completedFuture(NO_MATCH);
        }
    }

    @Override
    public String getKey() {
        return settings.getName();
    }

    @SuppressWarnings("unchecked")
    private Optional<Set<String>> extractRoles(Claims jws) {
        // Get claim roles
        Optional<Object> rolesObj = settings.getRolesClaim().map(claim -> {
            String[] path = claim.split("[.]");
            if (path.length < 2)
                return jws.get(claim, Object.class);
            else {
                // Ok we need to parse all sub sequent path
                Object value = jws.get(path[0], Object.class);
                int i = 1;
                while (i < path.length && value != null && value instanceof Map<?, ?>) {
                    value = ((Map<String, Object>) value).get(path[i]);
                    i++;
                }
                return value;
            }
        });

        // Casting
        return rolesObj.flatMap(value -> {
            Set<String> set = new HashSet<>();

            if (value instanceof Collection<?>) {
                set.addAll((Collection<String>) value);
            } else if (value instanceof String) {
                set.add((String) value);
            }
            if (set.isEmpty())
                return Optional.empty();
            return Optional.of(set);
        });
    }
}