edu.byu.wso2.apim.extensions.JWTDecoder.java Source code

Java tutorial

Introduction

Here is the source code for edu.byu.wso2.apim.extensions.JWTDecoder.java

Source

/*
*  Copyright (c) 2005-2010, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
*
*  WSO2 Inc. licenses this file to you under the Apache License,
*  Version 2.0 (the "License"); you may not use this file except
*  in compliance with the License.
*  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package edu.byu.wso2.apim.extensions;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.axiom.util.base64.Base64Utils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.oltu.oauth2.jwt.JWTException;
import org.apache.oltu.oauth2.jwt.JWTProcessor;
import org.apache.synapse.ManagedLifecycle;
import org.apache.synapse.MessageContext;
import org.apache.synapse.SynapseException;
import org.apache.synapse.SynapseLog;
import org.apache.synapse.core.SynapseEnvironment;
import org.apache.synapse.core.axis2.Axis2MessageContext;
import org.apache.synapse.mediators.AbstractMediator;
import org.wso2.carbon.context.CarbonContext;
import org.wso2.carbon.core.util.KeyStoreManager;

/**
 * Custom mediator that extracts the claims from the JWT sent by the WSO2 API Manager in the
 * {@code x-jwt-assertion} transport header and makes them available as properties in the Synapse
 * message context.
 *
 * <p>Only claims in the {@code http://wso2.org/claims/}, {@code urn:scim:schemas:core:1.0:} and
 * {@code http://byu.edu/claims/} dialects are copied; the dialect prefix is removed to form the
 * property name. For the SCIM and BYU dialects a dotted local name such as {@code name.givenName}
 * is further reduced to its second segment ({@code givenName}). Claims in any other dialect are
 * ignored.
 *
 * <p>NOTE(review): the JWT signature is currently NOT verified before its claims are used, because
 * {@link #init(SynapseEnvironment)} does not load a keystore yet. TODO confirm this is acceptable
 * for the deployment (i.e. clients cannot inject their own {@code x-jwt-assertion} header).
 */
public class JWTDecoder extends AbstractMediator implements ManagedLifecycle {

    private static final Log log = LogFactory.getLog(JWTDecoder.class);

    /** Transport header that carries the JWT assertion. */
    private static final String JWT_HEADER = "x-jwt-assertion";

    private static final String CLAIM_URI = "http://wso2.org/claims/";
    private static final String SCIM_CLAIM_URI = "urn:scim:schemas:core:1.0:";
    private static final String BYU_CLAIM_URI = "http://byu.edu/claims/";

    /** Explicit charset so decoding does not depend on the platform default. */
    private static final Charset UTF8 = Charset.forName("UTF-8");

    /**
     * Captures the quoted value after the third ':' of the decoded JWT header, e.g. the x5t value in
     * {@code {"typ":"JWT","alg":"SHA256withRSA","x5t":"..."}}. Assumes a flat header whose last
     * field is the certificate thumbprint — TODO confirm against the issuing gateway.
     */
    private static final Pattern CERT_THUMB_PATTERN = Pattern.compile("^[^:]*:[^:]*:[^:]*:\"(.+)\"}$");

    /** Super-tenant keystore used for signature verification; never loaded yet (see init). */
    private KeyStore keyStore;

    /**
     * Lifecycle hook called when the mediator is initialised.
     *
     * @param synapseEnvironment the Synapse environment (currently unused)
     */
    public void init(SynapseEnvironment synapseEnvironment) {
        if (log.isInfoEnabled()) {
            log.info("Initializing JWTDecoder Mediator");
        }
        // TODO: determine which keystore file to use and load it into keyStore (for example a
        // KeyStore.getInstance("JKS") loaded from the file system, closing the stream in a finally
        // block). Until then keyStore stays null and signature verification cannot succeed.
    }

    /**
     * Reads the JWT from the {@code x-jwt-assertion} transport header and copies its recognised
     * claims into the Synapse context.
     *
     * <p>If the header is missing or empty the message is turned into a 401 response and
     * {@code handleException} is invoked.
     *
     * @param synapseContext the current Synapse message context (must be an Axis2MessageContext)
     * @return {@code true} when the claims were processed
     * @throws SynapseException if the JWT cannot be parsed
     */
    public boolean mediate(MessageContext synapseContext) {
        SynapseLog synLog = getLog(synapseContext);

        if (synLog.isTraceOrDebugEnabled()) {
            synLog.traceOrDebug("Start : JWTDecoder mediator");
            if (synLog.isTraceTraceEnabled()) {
                synLog.traceTrace("Message : " + synapseContext.getEnvelope());
            }
        }

        org.apache.axis2.context.MessageContext axis2MessageContext = ((Axis2MessageContext) synapseContext)
                .getAxis2MessageContext();
        String jwtAssertion = getJwtAssertion(axis2MessageContext);

        // Incoming request does not contain the JWT assertion: answer with 401 - Unauthorized.
        // isEmpty() (not ==) so an empty header value is rejected too.
        if (jwtAssertion == null || jwtAssertion.isEmpty()) {
            rejectAsUnauthorized("JWT assertion not found in the message header", null, synapseContext);
            // handleException is expected to throw; this return only satisfies the compiler.
            return false;
        }

        // NOTE(review): verifySignature(jwtAssertion, synapseContext) is intentionally not called
        // because keyStore is never loaded in init(); the claims below are trusted unverified.
        try {
            if (log.isDebugEnabled()) {
                log.debug("JWT assertion is : " + jwtAssertion);
            }
            JWTProcessor processor = new JWTProcessor().process(jwtAssertion);
            Map<String, Object> claims = processor.getPayloadClaims();
            if (claims != null) {
                for (Map.Entry<String, Object> claimEntry : claims.entrySet()) {
                    setClaimProperty(synapseContext, claimEntry.getKey(), claimEntry.getValue());
                }
            }
        } catch (JWTException e) {
            log.error(e.getMessage(), e);
            throw new SynapseException(e.getMessage(), e);
        }

        if (synLog.isTraceOrDebugEnabled()) {
            synLog.traceOrDebug("End : JWTDecoder mediator");
        }

        return true;
    }

    /**
     * Returns the JWT assertion header value, or {@code null} when the transport headers are absent,
     * are not a Map, or the header value is not a String.
     */
    private static String getJwtAssertion(org.apache.axis2.context.MessageContext axis2MessageContext) {
        Object headerObj = axis2MessageContext
                .getProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS);
        if (!(headerObj instanceof Map)) {
            return null;
        }
        Object value = ((Map<?, ?>) headerObj).get(JWT_HEADER);
        return value instanceof String ? (String) value : null;
    }

    /**
     * Copies a single claim into the Synapse context when its URI belongs to a recognised dialect.
     * The prefix is removed with substring (not a regex split) so '.' in the URI is literal and a
     * claim with an empty local name is skipped instead of throwing.
     */
    private static void setClaimProperty(MessageContext synapseContext, String claimUri, Object value) {
        if (claimUri == null) {
            return;
        }
        String propName;
        if (claimUri.startsWith(CLAIM_URI)) {
            propName = claimUri.substring(CLAIM_URI.length());
        } else if (claimUri.startsWith(SCIM_CLAIM_URI)) {
            propName = secondDottedSegment(claimUri.substring(SCIM_CLAIM_URI.length()));
        } else if (claimUri.startsWith(BYU_CLAIM_URI)) {
            propName = secondDottedSegment(claimUri.substring(BYU_CLAIM_URI.length()));
        } else {
            return;
        }
        if (propName.isEmpty()) {
            if (log.isDebugEnabled()) {
                log.debug("Skipping claim with empty local name: " + claimUri);
            }
            return;
        }
        synapseContext.setProperty(propName, value);
        if (log.isDebugEnabled()) {
            log.debug("Getting claim :" + propName + " , " + value);
        }
    }

    /**
     * Returns {@code name} unchanged when it has no '.', otherwise its second dot-separated segment
     * (e.g. {@code givenName} for {@code name.givenName}); returns "" when there is no second
     * segment (e.g. {@code "name."}).
     */
    private static String secondDottedSegment(String name) {
        if (name.indexOf('.') < 0) {
            return name;
        }
        String[] parts = name.split("\\.");
        return parts.length > 1 ? parts[1] : "";
    }

    /**
     * Verifies the JWT signature (SHA256withRSA) using the certificate whose SHA-1 thumbprint is
     * embedded in the JWT header, looked up first in the super-tenant keystore and then, if that
     * fails, in the current tenant's keystore.
     *
     * @return {@code true} if a certificate was found and the signature verified
     */
    private boolean verifySignature(String jwtAssertion, MessageContext synapseContext) {
        String[] segments = jwtAssertion.split("\\.");
        if (segments.length < 3) {
            rejectAsUnauthorized("Malformed JWT: expected header, body and signature segments", null,
                    synapseContext);
            return false;
        }
        String base64EncodedHeader = segments[0];
        String base64EncodedBody = segments[1];
        String base64EncodedSignature = segments[2];

        String decodedHeader = new String(Base64Utils.decode(base64EncodedHeader), UTF8);
        byte[] decodedSignature = Base64Utils.decode(base64EncodedSignature);

        Matcher matcher = CERT_THUMB_PATTERN.matcher(decodedHeader);
        if (!matcher.find()) {
            rejectAsUnauthorized("Certificate thumbprint not found in JWT header", null, synapseContext);
            return false;
        }
        byte[] decodedCertThumb = Base64Utils.decode(matcher.group(1));

        try {
            boolean certFound = false;
            boolean isVerified = false;

            Certificate publicCert = getSuperTenantPublicKey(decodedCertThumb, synapseContext);
            if (publicCert != null) {
                certFound = true;
                isVerified = verifySignature(publicCert, decodedSignature, base64EncodedHeader,
                        base64EncodedBody, base64EncodedSignature);
            }
            // Fall back to the tenant keystore when the super-tenant check did not succeed.
            if (!isVerified) {
                publicCert = getTenantPublicKey(decodedCertThumb, synapseContext);
                if (publicCert != null) {
                    certFound = true;
                    isVerified = verifySignature(publicCert, decodedSignature, base64EncodedHeader,
                            base64EncodedBody, base64EncodedSignature);
                }
            }
            if (!certFound) {
                rejectAsUnauthorized("Couldn't find a public certificate to verify signature", null,
                        synapseContext);
            }
            return isVerified;
        } catch (GeneralSecurityException e) {
            handleSigVerificationException(e, synapseContext);
            return false;
        }
    }

    /** Looks up the certificate matching the thumbprint in the super-tenant keystore, or null. */
    private Certificate getSuperTenantPublicKey(byte[] decodedCertThumb, MessageContext synapseContext) {
        if (keyStore == null) {
            // init() does not load the keystore yet (see the TODO there).
            return null;
        }
        String alias = getAliasForX509CertThumb(keyStore, decodedCertThumb, synapseContext);
        if (alias != null) {
            // get the certificate associated with the given alias from the default keystore
            try {
                return keyStore.getCertificate(alias);
            } catch (KeyStoreException e) {
                if (log.isErrorEnabled()) {
                    log.error("Error when getting server public certificate: ", e);
                }
            }
        }
        return null;
    }

    /**
     * Looks up the certificate matching the thumbprint in the current tenant's keystore
     * ({@code <tenant-domain with '.' replaced by '-'>.jks}), or returns null.
     */
    private Certificate getTenantPublicKey(byte[] decodedCertThumb, MessageContext synapseContext) {
        SynapseLog synLog = getLog(synapseContext);

        int tenantId = CarbonContext.getThreadLocalCarbonContext().getTenantId();
        String tenantDomain = CarbonContext.getThreadLocalCarbonContext().getTenantDomain();

        if (synLog.isTraceOrDebugEnabled()) {
            synLog.traceOrDebug("Tenant Domain: " + tenantDomain);
        }
        if (tenantDomain == null) {
            return null;
        }

        KeyStore tenantKeyStore = null;
        KeyStoreManager tenantKSM = KeyStoreManager.getInstance(tenantId);
        String jksName = tenantDomain.trim().replace(".", "-") + ".jks";
        try {
            tenantKeyStore = tenantKSM.getKeyStore(jksName);
        } catch (Exception e) {
            // getKeyStore declares a bare Exception; treat any failure as "no tenant keystore".
            if (log.isErrorEnabled()) {
                log.error("Error getting keystore for " + tenantDomain, e);
            }
        }
        if (tenantKeyStore != null) {
            String alias = getAliasForX509CertThumb(tenantKeyStore, decodedCertThumb, synapseContext);
            if (alias != null) {
                // get the certificate associated with the given alias from the tenant's keystore
                try {
                    return tenantKeyStore.getCertificate(alias);
                } catch (KeyStoreException e) {
                    if (log.isErrorEnabled()) {
                        log.error("Error when getting tenants public certificate: " + tenantDomain, e);
                    }
                }
            }
        }

        return null;
    }

    /**
     * Checks a SHA256withRSA signature over {@code header + "." + body}.
     *
     * @param base64EncodedSignature unused; kept for signature compatibility
     */
    private boolean verifySignature(Certificate publicCert, byte[] decodedSignature, String base64EncodedHeader,
            String base64EncodedBody, String base64EncodedSignature)
            throws NoSuchAlgorithmException, InvalidKeyException, SignatureException {
        Signature verifySig = Signature.getInstance("SHA256withRSA");
        verifySig.initVerify(publicCert);
        verifySig.update((base64EncodedHeader + "." + base64EncodedBody).getBytes(UTF8));
        return verifySig.verify(decodedSignature);
    }

    /**
     * Returns the alias of the first X.509 certificate in {@code keyStore} whose lowercase hex
     * SHA-1 digest equals {@code thumb} (interpreted as text), or null if none matches.
     * Aliases without a certificate, or whose certificate cannot be encoded, are skipped rather
     * than ending the search.
     */
    private String getAliasForX509CertThumb(KeyStore keyStore, byte[] thumb, MessageContext synapseContext) {
        SynapseLog synLog = getLog(synapseContext);
        if (keyStore == null || thumb == null) {
            return null;
        }

        MessageDigest sha;
        try {
            sha = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            handleSigVerificationException(e, synapseContext);
            return null;
        }

        String expectedThumb = new String(thumb, UTF8);
        try {
            for (Enumeration<String> aliases = keyStore.aliases(); aliases.hasMoreElements();) {
                String alias = aliases.nextElement();
                Certificate[] chain = keyStore.getCertificateChain(alias);
                // No chain (e.g. a trusted-cert entry): fall back to the single certificate.
                Certificate cert = (chain == null || chain.length == 0) ? keyStore.getCertificate(alias) : chain[0];
                if (!(cert instanceof X509Certificate)) {
                    continue;
                }
                byte[] data;
                try {
                    // digest(byte[]) updates, completes and resets the digest for the next alias.
                    data = sha.digest(cert.getEncoded());
                } catch (CertificateEncodingException e) {
                    if (log.isWarnEnabled()) {
                        log.warn("Error encoding certificate for alias " + alias, e);
                    }
                    continue;
                }
                if (expectedThumb.equals(hexify(data))) {
                    if (synLog.isTraceOrDebugEnabled()) {
                        synLog.traceOrDebug("Found matching alias: " + alias);
                    }
                    return alias;
                }
            }
        } catch (KeyStoreException e) {
            if (log.isErrorEnabled()) {
                log.error("Error getting alias from keystore", e);
            }
        }
        return null;
    }

    /** Encodes {@code bytes} as a lowercase hexadecimal string. */
    private static String hexify(byte[] bytes) {
        char[] hexDigits = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' };

        StringBuilder buf = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            buf.append(hexDigits[(b & 0xf0) >> 4]);
            buf.append(hexDigits[b & 0x0f]);
        }
        return buf.toString();
    }

    /** Turns the message into a 401 response and reports {@code e}, keeping it as the cause. */
    private void handleSigVerificationException(Exception e, MessageContext synapseContext) {
        rejectAsUnauthorized(e.getMessage(), e, synapseContext);
    }

    /**
     * Converts the current message into a 401 - Unauthorized response and reports {@code err}
     * through {@code handleException} (with {@code cause} when one is available).
     */
    private void rejectAsUnauthorized(String err, Exception cause, MessageContext synapseContext) {
        synapseContext.setTo(null);
        synapseContext.setResponse(true);
        ((Axis2MessageContext) synapseContext).getAxis2MessageContext().setProperty("HTTP_SC", "401");
        if (cause == null) {
            handleException(err, synapseContext);
        } else {
            handleException(err, cause, synapseContext);
        }
    }

    /** Lifecycle hook called when the mediator is destroyed. */
    public void destroy() {
        if (log.isInfoEnabled()) {
            log.info("Destroying JWTDecoder Mediator");
        }
    }
}