org.wso2.carbon.identity.application.authenticator.samlsso.util.SSOUtils.java Source code

Java tutorial

Introduction

Here is the source code for org.wso2.carbon.identity.application.authenticator.samlsso.util.SSOUtils.java

Source

/*
 * Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
 *
 * WSO2 Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.wso2.carbon.identity.application.authenticator.samlsso.util;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.xml.security.c14n.Canonicalizer;
import org.apache.xml.security.signature.XMLSignature;
import org.opensaml.common.impl.SAMLObjectContentReference;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.LogoutRequest;
import org.opensaml.saml2.core.RequestAbstractType;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.XMLObjectBuilder;
import org.opensaml.xml.io.Marshaller;
import org.opensaml.xml.io.MarshallerFactory;
import org.opensaml.xml.io.MarshallingException;
import org.opensaml.xml.security.SigningUtil;
import org.opensaml.xml.security.x509.X509Credential;
import org.opensaml.xml.signature.KeyInfo;
import org.opensaml.xml.signature.Signature;
import org.opensaml.xml.signature.SignatureException;
import org.opensaml.xml.signature.Signer;
import org.opensaml.xml.signature.X509Data;
import org.opensaml.xml.util.Base64;
import org.w3c.dom.Element;
import org.w3c.dom.bootstrap.DOMImplementationRegistry;
import org.w3c.dom.ls.DOMImplementationLS;
import org.w3c.dom.ls.LSOutput;
import org.w3c.dom.ls.LSSerializer;
import org.wso2.carbon.identity.application.authenticator.samlsso.exception.SAMLSSOException;
import org.wso2.carbon.identity.application.authenticator.samlsso.manager.X509CredentialImpl;
import org.wso2.carbon.identity.application.common.util.IdentityApplicationConstants;
import org.wso2.carbon.identity.application.common.util.IdentityApplicationManagementUtil;

import javax.xml.namespace.QName;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.security.SecureRandom;
import java.security.cert.CertificateEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.zip.DataFormatException;
import java.util.zip.Inflater;
import java.util.zip.InflaterInputStream;

public class SSOUtils {
    private SSOUtils() {

    }

    private static Log log = LogFactory.getLog(SSOUtils.class);

    /**
     * Generates a unique Id for Authentication Requests
     *
     * @return generated unique ID
     */

    public static String createID() {

        byte[] bytes = new byte[20]; // 160 bit

        new Random().nextBytes(bytes);

        char[] charMapping = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p' };

        char[] chars = new char[40];

        for (int i = 0; i < bytes.length; i++) {
            int left = (bytes[i] >> 4) & 0x0f;
            int right = bytes[i] & 0x0f;
            chars[i * 2] = charMapping[left];
            chars[i * 2 + 1] = charMapping[right];
        }

        return String.valueOf(chars);
    }

    /**
     * Sign the SAML Request message
     * 
     * @param request
     * @param signatureAlgorithm
     * @param digestAlgorithm
     * @param includeCert
     * @param credential
     * @return
     * @throws SAMLSSOException
     */
    public static void setSignature(RequestAbstractType request, String signatureAlgorithm, String digestAlgorithm,
            boolean includeCert, X509Credential x509Credential) throws SAMLSSOException {

        if (request == null) {
            throw new IllegalArgumentException("Request cannot be null");
        }
        if (x509Credential == null) {
            throw new IllegalArgumentException("X509Credential cannot be null");
        }
        if (x509Credential.getEntityCertificate() == null) {
            throw new SAMLSSOException("IdP certificate is needed for AuthnRequest signing in POST binding");
        }
        //TODO use StringUtils.isBlank
        if (StringUtils.isEmpty(signatureAlgorithm)) {
            signatureAlgorithm = IdentityApplicationManagementUtil.getXMLSignatureAlgorithms()
                    .get(IdentityApplicationConstants.XML.SignatureAlgorithm.RSA_SHA1);
        }
        if (StringUtils.isEmpty(digestAlgorithm)) {
            digestAlgorithm = IdentityApplicationManagementUtil.getXMLDigestAlgorithms()
                    .get(IdentityApplicationConstants.XML.DigestAlgorithm.SHA1);
        }

        Signature signature = (Signature) buildXMLObject(Signature.DEFAULT_ELEMENT_NAME);
        signature.setSigningCredential(x509Credential);
        signature.setSignatureAlgorithm(signatureAlgorithm);
        signature.setCanonicalizationAlgorithm(Canonicalizer.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);

        if (includeCert) {
            KeyInfo keyInfo = (KeyInfo) buildXMLObject(KeyInfo.DEFAULT_ELEMENT_NAME);
            X509Data data = (X509Data) buildXMLObject(X509Data.DEFAULT_ELEMENT_NAME);
            org.opensaml.xml.signature.X509Certificate cert = (org.opensaml.xml.signature.X509Certificate) buildXMLObject(
                    org.opensaml.xml.signature.X509Certificate.DEFAULT_ELEMENT_NAME);
            String value = null;
            try {
                value = org.apache.xml.security.utils.Base64
                        .encode(x509Credential.getEntityCertificate().getEncoded());
            } catch (CertificateEncodingException e) {
                throw new SAMLSSOException("Error getting the certificate to include in the signature", e);
            }
            cert.setValue(value);
            data.getX509Certificates().add(cert);
            keyInfo.getX509Datas().add(data);
            signature.setKeyInfo(keyInfo);
        }

        request.setSignature(signature);
        ((SAMLObjectContentReference) signature.getContentReferences().get(0)).setDigestAlgorithm(digestAlgorithm);

        List<Signature> signatureList = new ArrayList<Signature>();
        signatureList.add(signature);

        // Marshall and Sign
        MarshallerFactory marshallerFactory = org.opensaml.xml.Configuration.getMarshallerFactory();
        Marshaller marshaller = marshallerFactory.getMarshaller(request);
        try {
            marshaller.marshall(request);
        } catch (MarshallingException e) {
            throw new SAMLSSOException("Error while marshalling the SAML Request for signing", e);
        }

        org.apache.xml.security.Init.init();
        try {
            Signer.signObjects(signatureList);
        } catch (SignatureException e) {
            throw new SAMLSSOException("Error while signing the SAML Request", e);
        }
    }

    public static void addSignatureToHTTPQueryString(StringBuilder httpQueryString, String signatureAlgorithmURI,
            X509Credential credential) throws SAMLSSOException {
        try {
            httpQueryString.append("&SigAlg=");
            httpQueryString.append(URLEncoder.encode(signatureAlgorithmURI, "UTF-8").trim());

            byte[] rawSignature = SigningUtil.signWithURI(credential, signatureAlgorithmURI,
                    httpQueryString.toString().getBytes("UTF-8"));

            String base64Signature = Base64.encodeBytes(rawSignature, Base64.DONT_BREAK_LINES);

            if (log.isDebugEnabled()) {
                log.debug("Generated digital signature value (base64-encoded) {} " + base64Signature);
            }

            httpQueryString.append("&Signature=" + URLEncoder.encode(base64Signature, "UTF-8").trim());

        } catch (org.opensaml.xml.security.SecurityException e) {
            throw new SAMLSSOException("Unable to sign query string", e);
        } catch (UnsupportedEncodingException e) {
            // UTF-8 encoding is required to be supported by all JVMs
            throw new SAMLSSOException("Error while adding signature to HTTP query string", e);
        }
    }

    /**
     * Builds SAML Elements
     *
     * @param objectQName
     * @return
     * @throws SAMLSSOException
     */
    private static XMLObject buildXMLObject(QName objectQName) throws SAMLSSOException {
        XMLObjectBuilder builder = org.opensaml.xml.Configuration.getBuilderFactory().getBuilder(objectQName);
        if (builder == null) {
            throw new SAMLSSOException("Unable to retrieve builder for object QName " + objectQName);
        }
        return builder.buildObject(objectQName.getNamespaceURI(), objectQName.getLocalPart(),
                objectQName.getPrefix());
    }

    /**
     * Decoding and deflating the encoded AuthReq
     *
     * @param encodedStr encoded AuthReq
     * @return decoded AuthReq
     */
    public static String decode(String encodedStr) throws SAMLSSOException {
        try {
            if (log.isDebugEnabled()) {
                log.debug(" >> encoded string in the SSOUtils/decode : " + encodedStr);
            }
            org.apache.commons.codec.binary.Base64 base64Decoder = new org.apache.commons.codec.binary.Base64();
            byte[] xmlBytes = encodedStr.getBytes("UTF-8");
            byte[] base64DecodedByteArray = base64Decoder.decode(xmlBytes);

            try {
                //TODO if the request came in POST, inflating is wrong
                Inflater inflater = new Inflater(true);
                inflater.setInput(base64DecodedByteArray);
                byte[] xmlMessageBytes = new byte[5000];
                int resultLength = inflater.inflate(xmlMessageBytes);

                if (inflater.getRemaining() > 0) {
                    throw new RuntimeException("didn't allocate enough space to hold " + "decompressed data");
                }

                inflater.end();
                String decodedString = new String(xmlMessageBytes, 0, resultLength, "UTF-8");
                if (log.isDebugEnabled()) {
                    log.debug("Request message " + decodedString);
                }
                return decodedString;

            } catch (DataFormatException e) {
                ByteArrayInputStream bais = new ByteArrayInputStream(base64DecodedByteArray);
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                InflaterInputStream iis = new InflaterInputStream(bais);
                byte[] buf = new byte[1024];
                int count = iis.read(buf);
                while (count != -1) {
                    baos.write(buf, 0, count);
                    count = iis.read(buf);
                }
                iis.close();
                String decodedStr = new String(baos.toByteArray(), Charset.forName("UTF-8"));
                if (log.isDebugEnabled()) {
                    log.debug("Request message " + decodedStr);
                }
                return decodedStr;
            }
        } catch (IOException e) {
            throw new SAMLSSOException("Error when decoding the SAML Request.", e);
        }

    }

    public static String decodeForPost(String encodedStr) throws SAMLSSOException {
        try {
            org.apache.commons.codec.binary.Base64 base64Decoder = new org.apache.commons.codec.binary.Base64();
            byte[] xmlBytes = encodedStr.getBytes("UTF-8");
            byte[] base64DecodedByteArray = base64Decoder.decode(xmlBytes);

            String decodedString = new String(base64DecodedByteArray, "UTF-8");
            if (log.isDebugEnabled()) {
                log.debug("Request message " + decodedString);
            }
            return decodedString;

        } catch (IOException e) {
            throw new SAMLSSOException("Error when decoding the SAML Request.", e);
        }
    }

    /**
     * Serializing a SAML2 object into a String
     *
     * @param xmlObject object that needs to serialized.
     * @return serialized object
     * @throws SAMLSSOException
     */
    public static String marshall(XMLObject xmlObject) throws SAMLSSOException {
        try {

            System.setProperty("javax.xml.parsers.DocumentBuilderFactory",
                    "org.apache.xerces.jaxp.DocumentBuilderFactoryImpl");

            MarshallerFactory marshallerFactory = org.opensaml.xml.Configuration.getMarshallerFactory();
            Marshaller marshaller = marshallerFactory.getMarshaller(xmlObject);
            Element element = marshaller.marshall(xmlObject);

            ByteArrayOutputStream byteArrayOutputStrm = new ByteArrayOutputStream();
            DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();
            DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS");
            LSSerializer writer = impl.createLSSerializer();
            LSOutput output = impl.createLSOutput();
            output.setByteStream(byteArrayOutputStrm);
            writer.write(element, output);
            return byteArrayOutputStrm.toString();
        } catch (Exception e) {
            log.error("Error Serializing the SAML Response");
            throw new SAMLSSOException("Error Serializing the SAML Response", e);
        }
    }

    /**
     * Encoding the response
     *
     * @param xmlString String to be encoded
     * @return encoded String
     */
    public static String encode(String xmlString) {
        String encodedRequestMessage = Base64.encodeBytes(xmlString.getBytes(), Base64.DONT_BREAK_LINES);
        return encodedRequestMessage.trim();
    }

    public static boolean isAuthnRequestSigned(Map<String, String> properties) {
        if (properties != null) {
            String prop = properties.get(IdentityApplicationConstants.Authenticator.SAML2SSO.IS_AUTHN_REQ_SIGNED);
            if (prop != null) {
                return Boolean.parseBoolean(prop);
            }
        }
        return false;
    }

    public static boolean isLogoutEnabled(Map<String, String> properties) {
        if (properties != null) {
            String prop = properties.get(IdentityApplicationConstants.Authenticator.SAML2SSO.IS_LOGOUT_ENABLED);
            if (prop != null) {
                return Boolean.parseBoolean(prop);
            }
        }
        return false;
    }

    public static boolean isLogoutRequestSigned(Map<String, String> properties) {
        if (properties != null) {
            String prop = properties.get(IdentityApplicationConstants.Authenticator.SAML2SSO.IS_LOGOUT_REQ_SIGNED);
            if (prop != null) {
                return Boolean.parseBoolean(prop);
            }
        }
        return false;
    }

    public static boolean isAuthnResponseSigned(Map<String, String> properties) {
        if (properties != null) {
            String prop = properties.get(IdentityApplicationConstants.Authenticator.SAML2SSO.IS_AUTHN_RESP_SIGNED);
            if (prop != null) {
                return Boolean.parseBoolean(prop);
            }
        }
        return false;
    }

    public static boolean isAssertionSigningEnabled(Map<String, String> properties) {
        if (properties != null) {
            String prop = properties
                    .get(IdentityApplicationConstants.Authenticator.SAML2SSO.IS_ENABLE_ASSERTION_SIGNING);
            if (prop != null) {
                return Boolean.parseBoolean(prop);
            }
        }
        return false;
    }

    public static boolean isAssertionEncryptionEnabled(Map<String, String> properties) {
        if (properties != null) {
            String prop = properties
                    .get(IdentityApplicationConstants.Authenticator.SAML2SSO.IS_ENABLE_ASSERTION_ENCRYPTION);
            if (prop != null) {
                return Boolean.parseBoolean(prop);
            }
        }
        return false;
    }

    public static Map<String, String> getQueryMap(String query) {
        String[] params = query.split("&");
        Map<String, String> map = new HashMap<String, String>();
        for (String param : params) {
            String[] paramSplitArr = param.split("=");
            String name = paramSplitArr[0];
            String value = "";
            if (paramSplitArr.length > 1) {
                value = paramSplitArr[1];
            }
            map.put(name, value);
        }
        return map;
    }
}