// Java tutorial
/* * Copyright (c) 2010, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. * * WSO2 Inc. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except * in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.wso2.carbon.identity.sso.saml.util; import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.xerces.impl.Constants; import org.apache.xerces.util.SecurityManager; import org.joda.time.DateTime; import org.opensaml.Configuration; import org.opensaml.DefaultBootstrap; import org.opensaml.saml2.core.Assertion; import org.opensaml.saml2.core.EncryptedAssertion; import org.opensaml.saml2.core.Issuer; import org.opensaml.saml2.core.LogoutRequest; import org.opensaml.saml2.core.LogoutResponse; import org.opensaml.saml2.core.RequestAbstractType; import org.opensaml.saml2.core.Response; import org.opensaml.saml2.core.impl.AuthnRequestImpl; import org.opensaml.saml2.core.impl.IssuerBuilder; import org.opensaml.xml.ConfigurationException; import org.opensaml.xml.XMLObject; import org.opensaml.xml.io.Marshaller; import org.opensaml.xml.io.MarshallerFactory; import org.opensaml.xml.io.Unmarshaller; import org.opensaml.xml.io.UnmarshallerFactory; import org.opensaml.xml.security.SecurityException; import org.opensaml.xml.security.x509.X509Credential; import org.opensaml.xml.signature.SignableXMLObject; import org.opensaml.xml.util.Base64; import org.osgi.framework.BundleContext; import org.osgi.service.http.HttpService; import 
org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.bootstrap.DOMImplementationRegistry; import org.w3c.dom.ls.DOMImplementationLS; import org.w3c.dom.ls.LSOutput; import org.w3c.dom.ls.LSSerializer; import org.wso2.carbon.context.PrivilegedCarbonContext; import org.wso2.carbon.context.RegistryType; import org.wso2.carbon.core.util.KeyStoreManager; import org.wso2.carbon.identity.application.common.model.ClaimMapping; import org.wso2.carbon.identity.application.common.model.FederatedAuthenticatorConfig; import org.wso2.carbon.identity.application.common.model.IdentityProvider; import org.wso2.carbon.identity.application.common.model.SAML2SSOFederatedAuthenticatorConfig; import org.wso2.carbon.identity.application.common.util.IdentityApplicationConstants; import org.wso2.carbon.identity.base.IdentityConstants; import org.wso2.carbon.identity.base.IdentityException; import org.wso2.carbon.identity.core.model.SAMLSSOServiceProviderDO; import org.wso2.carbon.identity.core.persistence.IdentityPersistenceManager; import org.wso2.carbon.identity.core.util.IdentityTenantUtil; import org.wso2.carbon.identity.core.util.IdentityUtil; import org.wso2.carbon.identity.sso.saml.SAMLSSOConstants; import org.wso2.carbon.identity.sso.saml.SSOServiceProviderConfigManager; import org.wso2.carbon.identity.sso.saml.builders.DefaultResponseBuilder; import org.wso2.carbon.identity.sso.saml.builders.ErrorResponseBuilder; import org.wso2.carbon.identity.sso.saml.builders.ResponseBuilder; import org.wso2.carbon.identity.sso.saml.builders.X509CredentialImpl; import org.wso2.carbon.identity.sso.saml.builders.assertion.SAMLAssertionBuilder; import org.wso2.carbon.identity.sso.saml.builders.encryption.SSOEncrypter; import org.wso2.carbon.identity.sso.saml.builders.signature.SSOSigner; import org.wso2.carbon.identity.sso.saml.dto.SAMLSSOAuthnReqDTO; import org.wso2.carbon.identity.sso.saml.exception.IdentitySAML2SSOException; import 
org.wso2.carbon.identity.sso.saml.session.SSOSessionPersistenceManager; import org.wso2.carbon.identity.sso.saml.validators.SAML2HTTPRedirectSignatureValidator; import org.wso2.carbon.idp.mgt.IdentityProviderManagementException; import org.wso2.carbon.idp.mgt.IdentityProviderManager; import org.wso2.carbon.registry.core.Registry; import org.wso2.carbon.registry.core.service.RegistryService; import org.wso2.carbon.registry.core.service.TenantRegistryLoader; import org.wso2.carbon.user.api.UserStoreException; import org.wso2.carbon.user.core.service.RealmService; import org.wso2.carbon.utils.ConfigurationContextService; import org.wso2.carbon.utils.multitenancy.MultitenantConstants; import javax.xml.XMLConstants; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; import java.security.KeyStore; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.zip.DataFormatException; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; import java.util.zip.Inflater; import java.util.zip.InflaterInputStream; public class SAMLSSOUtil { private static final char[] charMapping = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p' }; private static final Set<Character> UNRESERVED_CHARACTERS = new HashSet<>(); private static final ThreadLocal<Boolean> isSaaSApplication = new ThreadLocal<>(); private static final ThreadLocal<String> userTenantDomainThreadLocal = new 
ThreadLocal<>(); private static final String DefaultAssertionBuilder = "org.wso2.carbon.identity.sso.saml.builders.assertion.DefaultSAMLAssertionBuilder"; private static final String SECURITY_MANAGER_PROPERTY = Constants.XERCES_PROPERTY_PREFIX + Constants.SECURITY_MANAGER_PROPERTY; private static final int ENTITY_EXPANSION_LIMIT = 0; static { for (char c = 'a'; c <= 'z'; c++) UNRESERVED_CHARACTERS.add(Character.valueOf(c)); for (char c = 'A'; c <= 'A'; c++) UNRESERVED_CHARACTERS.add(Character.valueOf(c)); for (char c = '0'; c <= '9'; c++) UNRESERVED_CHARACTERS.add(Character.valueOf(c)); UNRESERVED_CHARACTERS.add(Character.valueOf('-')); UNRESERVED_CHARACTERS.add(Character.valueOf('.')); UNRESERVED_CHARACTERS.add(Character.valueOf('_')); UNRESERVED_CHARACTERS.add(Character.valueOf('~')); } private static Log log = LogFactory.getLog(SAMLSSOUtil.class); private static RegistryService registryService; private static TenantRegistryLoader tenantRegistryLoader; private static BundleContext bundleContext; private static RealmService realmService; private static ConfigurationContextService configCtxService; private static HttpService httpService; private static boolean isBootStrapped = false; private static Random random = new Random(); private static int singleLogoutRetryCount = 5; private static long singleLogoutRetryInterval = 60000; private static String responseBuilderClassName = null; private static SAMLAssertionBuilder samlAssertionBuilder = null; private static SSOEncrypter ssoEncrypter = null; private static SSOSigner ssoSigner = null; private static SAML2HTTPRedirectSignatureValidator samlHTTPRedirectSignatureValidator = null; private static ThreadLocal tenantDomainInThreadLocal = new ThreadLocal(); private SAMLSSOUtil() { } public static boolean isSaaSApplication() { if (isSaaSApplication == null) { // this is the default behavior. 
return true; } Boolean value = isSaaSApplication.get(); if (value != null) { return value; } return false; } public static void setIsSaaSApplication(boolean isSaaSApp) { isSaaSApplication.set(isSaaSApp); } public static void removeSaaSApplicationThreaLocal() { isSaaSApplication.remove(); } public static String getUserTenantDomain() { if (userTenantDomainThreadLocal == null) { // this is the default behavior. return null; } return userTenantDomainThreadLocal.get(); } public static void setUserTenantDomain(String tenantDomain) throws UserStoreException, IdentityException { tenantDomain = validateTenantDomain(tenantDomain); if (tenantDomain != null) { userTenantDomainThreadLocal.set(tenantDomain); } } public static void removeUserTenantDomainThreaLocal() { userTenantDomainThreadLocal.remove(); } public static BundleContext getBundleContext() { return SAMLSSOUtil.bundleContext; } public static void setBundleContext(BundleContext bundleContext) { SAMLSSOUtil.bundleContext = bundleContext; } public static RegistryService getRegistryService() { return registryService; } public static void setRegistryService(RegistryService registryService) { SAMLSSOUtil.registryService = registryService; } public static TenantRegistryLoader getTenantRegistryLoader() { return tenantRegistryLoader; } public static void setTenantRegistryLoader(TenantRegistryLoader tenantRegistryLoader) { SAMLSSOUtil.tenantRegistryLoader = tenantRegistryLoader; } public static RealmService getRealmService() { return realmService; } public static void setRealmService(RealmService realmService) { SAMLSSOUtil.realmService = realmService; } public static ConfigurationContextService getConfigCtxService() { return configCtxService; } public static void setConfigCtxService(ConfigurationContextService configCtxService) { SAMLSSOUtil.configCtxService = configCtxService; } public static HttpService getHttpService() { return httpService; } public static void setHttpService(HttpService httpService) { 
SAMLSSOUtil.httpService = httpService; } /** * Constructing the AuthnRequest Object from a String * * @param authReqStr Decoded AuthReq String * @return AuthnRequest Object * @throws org.wso2.carbon.identity.base.IdentityException */ public static XMLObject unmarshall(String authReqStr) throws IdentityException { InputStream inputStream = null; try { doBootstrap(); DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); documentBuilderFactory.setNamespaceAware(true); documentBuilderFactory.setExpandEntityReferences(false); documentBuilderFactory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true); SecurityManager securityManager = new SecurityManager(); securityManager.setEntityExpansionLimit(ENTITY_EXPANSION_LIMIT); documentBuilderFactory.setAttribute(SECURITY_MANAGER_PROPERTY, securityManager); DocumentBuilder docBuilder = documentBuilderFactory.newDocumentBuilder(); docBuilder.setEntityResolver(new CarbonEntityResolver()); inputStream = new ByteArrayInputStream(authReqStr.trim().getBytes(StandardCharsets.UTF_8)); Document document = docBuilder.parse(inputStream); Element element = document.getDocumentElement(); UnmarshallerFactory unmarshallerFactory = Configuration.getUnmarshallerFactory(); Unmarshaller unmarshaller = unmarshallerFactory.getUnmarshaller(element); return unmarshaller.unmarshall(element); } catch (Exception e) { log.error("Error in constructing AuthRequest from the encoded String", e); throw new IdentityException("Error in constructing AuthRequest from the encoded String ", e); } finally { if (inputStream != null) { try { inputStream.close(); } catch (IOException e) { log.error("Error while closing the stream", e); } } } } /** * Serialize the Auth. Request * * @param xmlObject * @return serialized auth. 
req */ public static String marshall(XMLObject xmlObject) throws IdentityException { ByteArrayOutputStream byteArrayOutputStrm = null; try { doBootstrap(); System.setProperty("javax.xml.parsers.DocumentBuilderFactory", "org.apache.xerces.jaxp.DocumentBuilderFactoryImpl"); MarshallerFactory marshallerFactory = org.opensaml.xml.Configuration.getMarshallerFactory(); Marshaller marshaller = marshallerFactory.getMarshaller(xmlObject); Element element = marshaller.marshall(xmlObject); byteArrayOutputStrm = new ByteArrayOutputStream(); DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance(); DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS"); LSSerializer writer = impl.createLSSerializer(); LSOutput output = impl.createLSOutput(); output.setByteStream(byteArrayOutputStrm); writer.write(element, output); return byteArrayOutputStrm.toString("UTF-8"); } catch (Exception e) { log.error("Error Serializing the SAML Response"); throw new IdentityException("Error Serializing the SAML Response", e); } finally { if (byteArrayOutputStrm != null) { try { byteArrayOutputStrm.close(); } catch (IOException e) { log.error("Error while closing the stream", e); } } } } /** * Encoding the response * * @param xmlString String to be encoded * @return encoded String */ public static String encode(String xmlString) { // Encoding the message String encodedRequestMessage = Base64.encodeBytes(xmlString.getBytes(StandardCharsets.UTF_8), Base64.DONT_BREAK_LINES); return encodedRequestMessage.trim(); } /** * Decoding and deflating the encoded AuthReq * * @param encodedStr encoded AuthReq * @return decoded AuthReq */ public static String decode(String encodedStr) throws IdentityException { try { org.apache.commons.codec.binary.Base64 base64Decoder = new org.apache.commons.codec.binary.Base64(); byte[] xmlBytes = encodedStr.getBytes("UTF-8"); byte[] base64DecodedByteArray = base64Decoder.decode(xmlBytes); try { Inflater inflater = new 
Inflater(true); inflater.setInput(base64DecodedByteArray); byte[] xmlMessageBytes = new byte[5000]; int resultLength = inflater.inflate(xmlMessageBytes); if (inflater.getRemaining() > 0) { throw new RuntimeException("didn't allocate enough space to hold " + "decompressed data"); } inflater.end(); String decodedString = new String(xmlMessageBytes, 0, resultLength, "UTF-8"); if (log.isDebugEnabled()) { log.debug("Request message " + decodedString); } return decodedString; } catch (DataFormatException e) { ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(base64DecodedByteArray); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); InflaterInputStream iis = new InflaterInputStream(byteArrayInputStream); byte[] buf = new byte[1024]; int count = iis.read(buf); while (count != -1) { byteArrayOutputStream.write(buf, 0, count); count = iis.read(buf); } iis.close(); String decodedStr = new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8); if (log.isDebugEnabled()) { log.debug("Request message " + decodedStr, e); } return decodedStr; } } catch (IOException e) { throw new IdentityException("Error when decoding the SAML Request.", e); } } public static String decodeForPost(String encodedStr) throws IdentityException { try { org.apache.commons.codec.binary.Base64 base64Decoder = new org.apache.commons.codec.binary.Base64(); byte[] xmlBytes = encodedStr.getBytes("UTF-8"); byte[] base64DecodedByteArray = base64Decoder.decode(xmlBytes); String decodedString = new String(base64DecodedByteArray, "UTF-8"); if (log.isDebugEnabled()) { log.debug("Request message " + decodedString); } return decodedString; } catch (IOException e) { throw new IdentityException("Error when decoding the SAML Request.", e); } } /** * Get the Issuer * * @return Issuer */ public static Issuer getIssuer() throws IdentityException { return getIssuerFromTenantDomain(getTenantDomainFromThreadLocal()); } public static Issuer 
getIssuerFromTenantDomain(String tenantDomain) throws IdentityException { Issuer issuer = new IssuerBuilder().buildObject(); String idPEntityId = null; IdentityProvider identityProvider; int tenantId; if (StringUtils.isEmpty(tenantDomain) || "null".equals(tenantDomain)) { tenantDomain = MultitenantConstants.SUPER_TENANT_DOMAIN_NAME; tenantId = MultitenantConstants.SUPER_TENANT_ID; } else { try { tenantId = SAMLSSOUtil.getRealmService().getTenantManager().getTenantId(tenantDomain); } catch (UserStoreException e) { throw new IdentityException("Error occurred while retrieving tenant id from tenant domain", e); } if (MultitenantConstants.INVALID_TENANT_ID == tenantId) { throw new IdentityException("Invalid tenant domain - '" + tenantDomain + "'"); } } IdentityTenantUtil.initializeRegistry(tenantId, tenantDomain); try { identityProvider = IdentityProviderManager.getInstance().getResidentIdP(tenantDomain); } catch (IdentityProviderManagementException e) { throw new IdentityException( "Error occurred while retrieving Resident Identity Provider information for tenant " + tenantDomain, e); } FederatedAuthenticatorConfig[] authnConfigs = identityProvider.getFederatedAuthenticatorConfigs(); for (FederatedAuthenticatorConfig config : authnConfigs) { if (IdentityApplicationConstants.Authenticator.SAML2SSO.NAME.equals(config.getName())) { SAML2SSOFederatedAuthenticatorConfig samlFedAuthnConfig = new SAML2SSOFederatedAuthenticatorConfig( config); idPEntityId = samlFedAuthnConfig.getIdpEntityId(); } } if (idPEntityId == null) { idPEntityId = IdentityUtil.getProperty(IdentityConstants.ServerConfig.ENTITY_ID); } issuer.setValue(idPEntityId); issuer.setFormat(SAMLSSOConstants.NAME_ID_POLICY_ENTITY); return issuer; } public static void doBootstrap() { if (!isBootStrapped) { try { DefaultBootstrap.bootstrap(); isBootStrapped = true; } catch (ConfigurationException e) { log.error("Error in bootstrapping the OpenSAML2 library", e); } } } /** * Sign the SAML Assertion * * @param response 
* @param signatureAlgorithm * @param digestAlgorithm * @param cred * @return * @throws IdentityException */ public static Assertion setSignature(Assertion response, String signatureAlgorithm, String digestAlgorithm, X509Credential cred) throws IdentityException { return (Assertion) doSetSignature(response, signatureAlgorithm, digestAlgorithm, cred); } /** * Sign the SAML Response message * * @param response * @param signatureAlgorithm * @param digestAlgorithm * @param cred * @return * @throws IdentityException */ public static Response setSignature(Response response, String signatureAlgorithm, String digestAlgorithm, X509Credential cred) throws IdentityException { return (Response) doSetSignature(response, signatureAlgorithm, digestAlgorithm, cred); } /** * Sign the SAML LogoutResponse message * * @param response * @param signatureAlgorithm * @param digestAlgorithm * @param cred * @return * @throws IdentityException */ public static LogoutResponse setSignature(LogoutResponse response, String signatureAlgorithm, String digestAlgorithm, X509Credential cred) throws IdentityException { return (LogoutResponse) doSetSignature(response, signatureAlgorithm, digestAlgorithm, cred); } /** * Sign SAML Logout Request message * * @param request * @param signatureAlgorithm * @param digestAlgorithm * @param cred * @return * @throws IdentityException */ public static LogoutRequest setSignature(LogoutRequest request, String signatureAlgorithm, String digestAlgorithm, X509Credential cred) throws IdentityException { return (LogoutRequest) doSetSignature(request, signatureAlgorithm, digestAlgorithm, cred); } /** * Generic method to sign SAML Logout Request * * @param request * @param signatureAlgorithm * @param digestAlgorithm * @param cred * @return * @throws IdentityException */ private static SignableXMLObject doSetSignature(SignableXMLObject request, String signatureAlgorithm, String digestAlgorithm, X509Credential cred) throws IdentityException { doBootstrap(); try { synchronized 
(Runtime.getRuntime().getClass()) { ssoSigner = (SSOSigner) Class.forName(IdentityUtil.getProperty("SSOService.SAMLSSOSigner").trim()) .newInstance(); ssoSigner.init(); } return ssoSigner.setSignature(request, signatureAlgorithm, digestAlgorithm, cred); } catch (ClassNotFoundException e) { throw new IdentityException("Class not found: " + IdentityUtil.getProperty("SSOService.SAMLSSOSigner"), e); } catch (InstantiationException e) { throw new IdentityException( "Error while instantiating class: " + IdentityUtil.getProperty("SSOService.SAMLSSOSigner"), e); } catch (IllegalAccessException e) { throw new IdentityException( "Illegal access to class: " + IdentityUtil.getProperty("SSOService.SAMLSSOSigner"), e); } catch (Exception e) { throw new IdentityException("Error while signing the XML object.", e); } } public static EncryptedAssertion setEncryptedAssertion(Assertion assertion, String encryptionAlgorithm, String alias, String domainName) throws IdentityException { doBootstrap(); try { X509Credential cred = SAMLSSOUtil.getX509CredentialImplForTenant(domainName, alias); synchronized (Runtime.getRuntime().getClass()) { ssoEncrypter = (SSOEncrypter) Class .forName(IdentityUtil.getProperty("SSOService.SAMLSSOEncrypter").trim()).newInstance(); ssoEncrypter.init(); } return ssoEncrypter.doEncryptedAssertion(assertion, cred, alias, encryptionAlgorithm); } catch (ClassNotFoundException e) { throw new IdentityException( "Class not found: " + IdentityUtil.getProperty("SSOService.SAMLSSOEncrypter"), e); } catch (InstantiationException e) { throw new IdentityException( "Error while instantiating class: " + IdentityUtil.getProperty("SSOService.SAMLSSOEncrypter"), e); } catch (IllegalAccessException e) { throw new IdentityException( "Illegal access to class: " + IdentityUtil.getProperty("SSOService.SAMLSSOEncrypter"), e); } catch (Exception e) { throw new IdentityException("Error while signing the SAML Response message.", e); } } public static Assertion 
buildSAMLAssertion(SAMLSSOAuthnReqDTO authReqDTO, DateTime notOnOrAfter, String sessionId) throws IdentityException { doBootstrap(); String assertionBuilderClass = null; try { assertionBuilderClass = IdentityUtil.getProperty("SSOService.SAMLSSOAssertionBuilder").trim(); if (StringUtils.isBlank(assertionBuilderClass)) { assertionBuilderClass = DefaultAssertionBuilder; } } catch (Exception e) { if (log.isDebugEnabled()) { log.debug("SAMLSSOAssertionBuilder configuration is set to default builder ", e); } assertionBuilderClass = DefaultAssertionBuilder; } try { synchronized (Runtime.getRuntime().getClass()) { samlAssertionBuilder = (SAMLAssertionBuilder) Class.forName(assertionBuilderClass).newInstance(); samlAssertionBuilder.init(); } return samlAssertionBuilder.buildAssertion(authReqDTO, notOnOrAfter, sessionId); } catch (ClassNotFoundException e) { throw new IdentityException("Class not found: " + assertionBuilderClass, e); } catch (InstantiationException e) { throw new IdentityException("Error while instantiating class: " + assertionBuilderClass, e); } catch (IllegalAccessException e) { throw new IdentityException("Illegal access to class: " + assertionBuilderClass, e); } catch (Exception e) { throw new IdentityException("Error while building the saml assertion", e); } } public static String createID() { byte[] bytes = new byte[20]; // 160 bits random.nextBytes(bytes); char[] chars = new char[40]; for (int i = 0; i < bytes.length; i++) { int left = (bytes[i] >> 4) & 0x0f; int right = bytes[i] & 0x0f; chars[i * 2] = charMapping[left]; chars[i * 2 + 1] = charMapping[right]; } return String.valueOf(chars); } /** * Generate the key store name from the domain name * * @param tenantDomain tenant domain name * @return key store file name */ public static String generateKSNameFromDomainName(String tenantDomain) { String ksName = tenantDomain.trim().replace(".", "-"); return ksName + ".jks"; } /** * Get the X509CredentialImpl object for a particular tenant * * @param 
tenantDomain * @param alias * @return X509CredentialImpl object containing the public certificate of * that tenant * @throws org.wso2.carbon.identity.sso.saml.exception.IdentitySAML2SSOException Error when creating X509CredentialImpl object */ public static X509CredentialImpl getX509CredentialImplForTenant(String tenantDomain, String alias) throws IdentitySAML2SSOException { if (tenantDomain == null || tenantDomain.trim().isEmpty() || alias == null || alias.trim().isEmpty()) { throw new IllegalArgumentException( "Invalid parameters; domain name : " + tenantDomain + ", " + "alias : " + alias); } int tenantId; try { tenantId = realmService.getTenantManager().getTenantId(tenantDomain); } catch (org.wso2.carbon.user.api.UserStoreException e) { String errorMsg = "Error getting the tenant ID for the tenant domain : " + tenantDomain; throw new IdentitySAML2SSOException(errorMsg, e); } KeyStoreManager keyStoreManager; // get an instance of the corresponding Key Store Manager instance keyStoreManager = KeyStoreManager.getInstance(tenantId); X509CredentialImpl credentialImpl = null; KeyStore keyStore; try { if (tenantId != -1234) {// for tenants, load private key from their generated key store keyStore = keyStoreManager.getKeyStore(generateKSNameFromDomainName(tenantDomain)); } else { // for super tenant, load the default pub. cert using the // config. in carbon.xml keyStore = keyStoreManager.getPrimaryKeyStore(); } java.security.cert.X509Certificate cert = (java.security.cert.X509Certificate) keyStore .getCertificate(alias); credentialImpl = new X509CredentialImpl(cert); } catch (Exception e) { String errorMsg = "Error instantiating an X509CredentialImpl object for the public certificate of " + tenantDomain; throw new IdentitySAML2SSOException(errorMsg, e); } return credentialImpl; } /** * Validates the request message's signature. Validates the signature of * both HTTP POST Binding and HTTP Redirect Binding. 
* * @param authnReqDTO * @return */ public static boolean validateAuthnRequestSignature(SAMLSSOAuthnReqDTO authnReqDTO) { if (log.isDebugEnabled()) { log.debug("Validating SAML Request signature"); } String domainName = authnReqDTO.getTenantDomain(); if (authnReqDTO.isStratosDeployment()) { domainName = MultitenantConstants.SUPER_TENANT_DOMAIN_NAME; } String alias = authnReqDTO.getCertAlias(); RequestAbstractType request = null; try { String decodedReq = null; if (authnReqDTO.getQueryString() != null) { decodedReq = SAMLSSOUtil.decode(authnReqDTO.getRequestMessageString()); } else { decodedReq = SAMLSSOUtil.decodeForPost(authnReqDTO.getRequestMessageString()); } request = (RequestAbstractType) SAMLSSOUtil.unmarshall(decodedReq); } catch (IdentityException e) { if (log.isDebugEnabled()) { log.debug( "Signature Validation failed for the SAMLRequest : Failed to unmarshall the SAML Assertion", e); } } try { if (authnReqDTO.getQueryString() != null) { // DEFLATE signature in Redirect Binding return validateDeflateSignature(authnReqDTO.getQueryString(), authnReqDTO.getIssuer(), alias, domainName); } else { // XML signature in SAML Request message for POST Binding return validateXMLSignature(request, alias, domainName); } } catch (IdentityException e) { if (log.isDebugEnabled()) { log.debug("Signature Validation failed for the SAMLRequest : Failed to validate the SAML Assertion", e); } return false; } } /** * Validates the signature of the LogoutRequest message. 
* TODO : for stratos deployment, super tenant key should be used * * @param logoutRequest * @param alias * @param subject * @param httpRequest * @param isHTTPRedirectBinding * @return */ public static boolean validateLogoutRequestSignature(LogoutRequest logoutRequest, String alias, String subject, String queryString) throws IdentityException { String domainName = getTenantDomainFromThreadLocal(); if (queryString != null) { return validateDeflateSignature(queryString, logoutRequest.getIssuer().getValue(), alias, domainName); } else { return validateXMLSignature(logoutRequest, alias, domainName); } } /** * Signature validation for HTTP Redirect Binding * * @param authnReqDTO * @param samlRequest * @param alias * @param domainName * @return */ public static boolean validateDeflateSignature(String queryString, String issuer, String alias, String domainName) throws IdentityException { try { synchronized (Runtime.getRuntime().getClass()) { samlHTTPRedirectSignatureValidator = (SAML2HTTPRedirectSignatureValidator) Class .forName(IdentityUtil.getProperty("SSOService.SAML2HTTPRedirectSignatureValidator").trim()) .newInstance(); samlHTTPRedirectSignatureValidator.init(); } return samlHTTPRedirectSignatureValidator.validateSignature(queryString, issuer, alias, domainName); } catch (SecurityException e) { log.error("Error validating deflate signature", e); return false; } catch (IdentitySAML2SSOException e) { log.warn( "Signature validation failed for the SAML Message : Failed to construct the X509CredentialImpl for the alias " + alias, e); return false; } catch (ClassNotFoundException e) { throw new IdentityException("Class not found: " + IdentityUtil.getProperty("SSOService.SAML2HTTPRedirectSignatureValidator"), e); } catch (InstantiationException e) { throw new IdentityException("Error while instantiating class: " + IdentityUtil.getProperty("SSOService.SAML2HTTPRedirectSignatureValidator"), e); } catch (IllegalAccessException e) { throw new IdentityException("Illegal access 
to class: " + IdentityUtil.getProperty("SSOService.SAML2HTTPRedirectSignatureValidator"), e); } } /** * Validate the signature of an assertion * * @param request SAML Assertion, this could be either a SAML Request or a * LogoutRequest * @param alias Certificate alias against which the signature is validated. * @param domainName domain name of the subject * @return true, if the signature is valid. */ public static boolean validateXMLSignature(RequestAbstractType request, String alias, String domainName) throws IdentityException { boolean isSignatureValid = false; if (request.getSignature() != null) { try { X509Credential cred = SAMLSSOUtil.getX509CredentialImplForTenant(domainName, alias); synchronized (Runtime.getRuntime().getClass()) { ssoSigner = (SSOSigner) Class .forName(IdentityUtil.getProperty("SSOService.SAMLSSOSigner").trim()).newInstance(); ssoSigner.init(); } return ssoSigner.validateXMLSignature(request, cred, alias); } catch (IdentitySAML2SSOException e) { if (log.isDebugEnabled()) { log.debug( "Signature validation failed for the SAML Message : Failed to construct the X509CredentialImpl for the alias " + alias, e); } } catch (IdentityException e) { if (log.isDebugEnabled()) { log.debug("Signature Validation Failed for the SAML Assertion : Signature is invalid.", e); } } catch (ClassNotFoundException e) { throw new IdentityException( "Class not found: " + IdentityUtil.getProperty("SSOService.SAMLSSOSigner"), e); } catch (InstantiationException e) { throw new IdentityException( "Error while instantiating class: " + IdentityUtil.getProperty("SSOService.SAMLSSOSigner"), e); } catch (IllegalAccessException e) { throw new IdentityException( "Illegal access to class: " + IdentityUtil.getProperty("SSOService.SAMLSSOSigner"), e); } catch (Exception e) { if (log.isDebugEnabled()) { log.debug("Error while validating XML signature.", e); } } } return isSignatureValid; } /** * Return a Array of Claims containing requested attributes and values * * @param 
authnReqDTO * @return Map with attributes and values * @throws IdentityException */ public static Map<String, String> getAttributes(SAMLSSOAuthnReqDTO authnReqDTO) throws IdentityException { int index = 0; // trying to get the Service Provider Configurations SSOServiceProviderConfigManager spConfigManager = SSOServiceProviderConfigManager.getInstance(); SAMLSSOServiceProviderDO spDO = spConfigManager.getServiceProvider(authnReqDTO.getIssuer()); if (spDO == null) { IdentityPersistenceManager persistenceManager = IdentityPersistenceManager.getPersistanceManager(); Registry registry = (Registry) PrivilegedCarbonContext.getThreadLocalCarbonContext() .getRegistry(RegistryType.SYSTEM_CONFIGURATION); spDO = persistenceManager.getServiceProvider(registry, authnReqDTO.getIssuer()); } if (!authnReqDTO.isIdPInitSSOEnabled()) { AuthnRequestImpl request = null; try { request = (AuthnRequestImpl) SAMLSSOUtil .unmarshall(SAMLSSOUtil.decode(authnReqDTO.getRequestMessageString())); } catch (IdentityException e) { request = (AuthnRequestImpl) SAMLSSOUtil .unmarshall(SAMLSSOUtil.decodeForPost(authnReqDTO.getRequestMessageString())); if (log.isDebugEnabled()) { log.debug("Error while decoding authentication request.", e); } } if (request.getAttributeConsumingServiceIndex() == null) { //SP has not provide a AttributeConsumingServiceIndex in the authnReqDTO if (StringUtils.isNotBlank(spDO.getAttributeConsumingServiceIndex())) { if (spDO.isEnableAttributesByDefault()) { index = Integer.parseInt(spDO.getAttributeConsumingServiceIndex()); } else { return null; } } else { return null; } } else { //SP has provide a AttributeConsumingServiceIndex in the authnReqDTO index = request.getAttributeConsumingServiceIndex(); } } else { if (StringUtils.isNotBlank(spDO.getAttributeConsumingServiceIndex())) { if (spDO.isEnableAttributesByDefault()) { index = Integer.parseInt(spDO.getAttributeConsumingServiceIndex()); } else { return null; } } else { return null; } } /* * IMPORTANT : checking if the 
consumer index in the request matches the * given id to the SP */ if (spDO.getAttributeConsumingServiceIndex() == null || "".equals(spDO.getAttributeConsumingServiceIndex()) || index != Integer.parseInt(spDO.getAttributeConsumingServiceIndex())) { if (log.isDebugEnabled()) { log.debug("Invalid AttributeConsumingServiceIndex in AuthnRequest"); } return Collections.emptyMap(); } Map<String, String> claimsMap = new HashMap<String, String>(); if (authnReqDTO.getUser().getUserAttributes() != null) { for (Map.Entry<ClaimMapping, String> entry : authnReqDTO.getUser().getUserAttributes().entrySet()) { claimsMap.put(entry.getKey().getRemoteClaim().getClaimUri(), entry.getValue()); } } return claimsMap; } /** * build the error response * * @param id * @param statusCodes * @param statusMsg * @return decoded response * @throws IdentityException */ public static String buildErrorResponse(String id, List<String> statusCodes, String statusMsg, String destination) throws IdentityException { ErrorResponseBuilder respBuilder = new ErrorResponseBuilder(); Response response = respBuilder.buildResponse(id, statusCodes, statusMsg, destination); return SAMLSSOUtil.encode(SAMLSSOUtil.marshall(response)); } public static int getSAMLResponseValidityPeriod() { if (StringUtils.isNotBlank( IdentityUtil.getProperty(IdentityConstants.ServerConfig.SAML_RESPONSE_VALIDITY_PERIOD))) { return Integer.parseInt( IdentityUtil.getProperty(IdentityConstants.ServerConfig.SAML_RESPONSE_VALIDITY_PERIOD).trim()); } else { return 5; } } public static int getSingleLogoutRetryCount() { return singleLogoutRetryCount; } public static void setSingleLogoutRetryCount(int singleLogoutRetryCount) { SAMLSSOUtil.singleLogoutRetryCount = singleLogoutRetryCount; } public static long getSingleLogoutRetryInterval() { return singleLogoutRetryInterval; } public static void setSingleLogoutRetryInterval(long singleLogoutRetryInterval) { SAMLSSOUtil.singleLogoutRetryInterval = singleLogoutRetryInterval; } public static 
ResponseBuilder getResponseBuilder() { if (responseBuilderClassName == null || "".equals(responseBuilderClassName)) { return new DefaultResponseBuilder(); } else { try { // Bundle class loader will cache the loaded class and returned // the already loaded instance, hence calling this method // multiple times doesn't cost. Class clazz = Thread.currentThread().getContextClassLoader().loadClass(responseBuilderClassName); return (ResponseBuilder) clazz.newInstance(); } catch (ClassNotFoundException e) { log.error("Error while instantiating the SAMLResponseBuilder ", e); } catch (InstantiationException e) { log.error("Error while instantiating the SAMLResponseBuilder ", e); } catch (IllegalAccessException e) { log.error("Error while instantiating the SAMLResponseBuilder ", e); } } return null; } public static void setResponseBuilder(String responseBuilder) { responseBuilderClassName = responseBuilder; } /** * This check if the status code is 2XX, check value between 200 and 300 * * @param status * @return */ public static boolean isHttpSuccessStatusCode(int status) { return status >= 200 && status < 300; } public static boolean isHttpRedirectStatusCode(int status) { return status == 302 || status == 303; } public static String getUserNameFromOpenID(String openid) throws IdentityException { String caller = null; String path = null; URI uri = null; String contextPath = "/openid/"; try { uri = new URI(openid); path = uri.getPath(); } catch (URISyntaxException e) { throw new IdentityException("Invalid OpenID", e); } caller = path.substring(path.indexOf(contextPath) + contextPath.length(), path.length()); return caller; } /** * Find the OpenID corresponding to the given user name. * * @param userName User name * @return OpenID corresponding the given user name. * @throws org.wso2.carbon.identity.base.IdentityException */ public static String getOpenID(String userName) throws IdentityException { return generateOpenID(userName); } /** * Generate OpenID for a given user. 
* * @param user User * @return Generated OpenID * @throws org.wso2.carbon.identity.base.IdentityException */ public static String generateOpenID(String user) throws IdentityException { String openIDUserUrl = null; String openID = null; URI uri = null; URL url = null; openIDUserUrl = IdentityUtil.getProperty(IdentityConstants.ServerConfig.OPENID_USER_PATTERN); user = normalizeUrlEncoding(user); openID = openIDUserUrl + user; try { uri = new URI(openID); } catch (URISyntaxException e) { throw new IdentityException("Invalid OpenID URL :" + openID, e); } try { url = uri.normalize().toURL(); if (url.getQuery() != null || url.getRef() != null) { throw new IdentityException("Invalid user name for OpenID :" + openID); } } catch (MalformedURLException e) { throw new IdentityException("Malformed OpenID URL :" + openID, e); } openID = url.toString(); return openID; } private static String normalizeUrlEncoding(String text) { if (text == null) return null; int len = text.length(); StringBuilder normalized = new StringBuilder(len); for (int i = 0; i < len; i++) { char current = text.charAt(i); if (current == '%' && i < len - 2) { String percentCode = text.substring(i, i + 3).toUpperCase(); try { String str = URLDecoder.decode(percentCode, "ISO-8859-1"); char chr = str.charAt(0); if (UNRESERVED_CHARACTERS.contains(Character.valueOf(chr))) normalized.append(chr); else normalized.append(percentCode); } catch (UnsupportedEncodingException e) { normalized.append(percentCode); if (log.isDebugEnabled()) { log.debug("Unsupported Encoding exception while decoding percent code.", e); } } i += 2; } else { normalized.append(current); } } return normalized.toString(); } public static void removeSession(String sessionId, String issuer) { SSOSessionPersistenceManager ssoSessionPersistenceManager = SSOSessionPersistenceManager .getPersistenceManager(); String sessionIndex = ssoSessionPersistenceManager.getSessionIndexFromTokenId(sessionId); 
SSOSessionPersistenceManager.removeSessionInfoDataFromCache(sessionIndex); SSOSessionPersistenceManager.removeSessionIndexFromCache(sessionId); } public static void setTenantDomainInThreadLocal(String tenantDomain) throws UserStoreException, IdentityException { tenantDomain = validateTenantDomain(tenantDomain); if (tenantDomain != null) { SAMLSSOUtil.tenantDomainInThreadLocal.set(tenantDomain); } } public static String getTenantDomainFromThreadLocal() { if (SAMLSSOUtil.tenantDomainInThreadLocal == null) { // this is the default behavior. return null; } return (String) SAMLSSOUtil.tenantDomainInThreadLocal.get(); } public static void removeTenantDomainFromThreadLocal() { SAMLSSOUtil.tenantDomainInThreadLocal.remove(); } public static String validateTenantDomain(String tenantDomain) throws UserStoreException, IdentityException { if (tenantDomain != null && !tenantDomain.trim().isEmpty() && !"null".equalsIgnoreCase(tenantDomain.trim())) { int tenantID = SAMLSSOUtil.getRealmService().getTenantManager().getTenantId(tenantDomain); if (tenantID == -1) { String message = "Invalid tenant domain : " + tenantDomain; if (log.isDebugEnabled()) { log.debug(message); } throw new IdentityException(message); } else { return tenantDomain; } } return null; } /** * build the error response * * @param status * @param message * @return decoded response * @throws org.wso2.carbon.identity.base.IdentityException */ public static String buildErrorResponse(String status, String message, String destination) throws IdentityException, IOException { ErrorResponseBuilder respBuilder = new ErrorResponseBuilder(); List<String> statusCodeList = new ArrayList<String>(); statusCodeList.add(status); Response response = respBuilder.buildResponse(null, statusCodeList, message, destination); String resp = SAMLSSOUtil.marshall(response); return compressResponse(resp); } /** * Compresses the response String * * @param response * @return * @throws IOException */ public static String compressResponse(String 
response) throws IOException { Deflater deflater = new Deflater(Deflater.DEFLATED, true); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(byteArrayOutputStream, deflater); try { deflaterOutputStream.write(response.getBytes(StandardCharsets.UTF_8)); return Base64.encodeBytes(byteArrayOutputStream.toByteArray(), Base64.DONT_BREAK_LINES); } finally { deflaterOutputStream.close(); } } }