Java tutorial
/* * Copyright (c) 2012-2015 VMware, Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy * of the License at http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, without * warranties or conditions of any kind, EITHER EXPRESS OR IMPLIED. See the * License for the specific language governing permissions and limitations * under the License. */ package com.vmware.identity.samlservice.impl; import java.security.PrivateKey; import java.security.cert.CertPath; import java.security.cert.Certificate; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.EnumSet; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import org.apache.commons.lang.Validate; import org.w3c.dom.Document; import com.vmware.identity.diagnostics.DiagnosticsLoggerFactory; import com.vmware.identity.diagnostics.IDiagnosticsLogger; import com.vmware.identity.idm.AssertionConsumerService; import com.vmware.identity.idm.AuthnPolicy; import com.vmware.identity.idm.DomainType; import com.vmware.identity.idm.GSSResult; import com.vmware.identity.idm.Group; import com.vmware.identity.idm.IDMSecureIDNewPinException; import com.vmware.identity.idm.IDPConfig; import com.vmware.identity.idm.IIdentityStoreData; import com.vmware.identity.idm.PersonDetail; import com.vmware.identity.idm.PrincipalId; import com.vmware.identity.idm.RSAAMResult; import com.vmware.identity.idm.RelyingParty; import com.vmware.identity.idm.SSOImplicitGroupNames; import com.vmware.identity.idm.ServiceEndpoint; import com.vmware.identity.idm.TokenClaimAttribute; import com.vmware.identity.idm.client.CasIdmClient; import com.vmware.identity.saml.InvalidTokenException; import com.vmware.identity.saml.ServerValidatableSamlToken.Subject; import com.vmware.identity.samlservice.IdmAccessor; import com.vmware.identity.samlservice.Shared; import com.vmware.identity.websso.client.Attribute; /** * IDM Accessor class which talks to real IDM system (by wrapping CasIdmClient) * */ public class CasIdmAccessor implements IdmAccessor { private static final IDiagnosticsLogger logger = DiagnosticsLoggerFactory.getLogger(CasIdmAccessor.class); private final CasIdmClient client; private String tenant; private static final char[] invalidCharsForUserName; static { char[] invalidCharsForUserDetail = "^<>&%`".toCharArray(); char upnSeparator = '@'; char netbiosSeparator = '\\'; invalidCharsForUserName = (String.valueOf(invalidCharsForUserDetail) + upnSeparator + netbiosSeparator) .toCharArray(); } /** * Create IDM Accessor with an instance of the IDM client * * @param idmClient */ public CasIdmAccessor(CasIdmClient idmClient) { logger.debug("CasIdmAccessor constructor called"); Validate.notNull(idmClient); client = idmClient; } /* * (non-Javadoc) * * @see com.vmware.identity.samlservice.IdmAccessor#setDefaultTenant() */ @Override public void setDefaultTenant() { logger.debug("setDefaultTenant called"); try { String defaultTenant = client.getDefaultTenant(); Validate.notNull(defaultTenant); setTenant(defaultTenant); } catch (Exception e) { logger.error("setDefaultTenant: Caught exception {}", e.toString()); throw new IllegalStateException("BadRequest", e); } } /* * (non-Javadoc) * * @see * com.vmware.identity.samlservice.IdmAccessor#setTenant(java.lang.String) */ @Override public void setTenant(String t) { logger.debug("setTenant: {}", t); tenant = t; } /* * (non-Javadoc) * * @see com.vmware.identity.samlservice.IdmAccessor#getTenant() */ @Override public String getTenant() { logger.debug("getTenant: {}", tenant); return tenant; } /* * (non-Javadoc) * * @see com.vmware.identity.samlservice.IdmAccessor#getIdpEntityId() */ @Override public String getIdpEntityId() { logger.debug("getIdpEntityId"); String retval = null; try { retval = client.getEntityID(tenant); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } return retval; } /* * (non-Javadoc) * * @see * com.vmware.identity.samlservice.IdmAccessor#getAcsForRelyingParty(java * .lang.String, int, java.lang.String) */ @Override public String getAcsForRelyingParty(String relyingParty, Integer acsIndex, String acsUrl, String binding, boolean validateWithMetadata) { logger.debug("getAcsForRelyingParty " + relyingParty + ", index " + acsIndex + ", URL " + acsUrl + ", binding " + binding); String retval = null; try { RelyingParty rp = client.getRelyingPartyByUrl(tenant, relyingParty); Validate.notNull(rp); Collection<AssertionConsumerService> assertionServices = rp.getAssertionConsumerServices(); Validate.notNull(assertionServices); Validate.isTrue(assertionServices.size() > 0); AssertionConsumerService[] services = assertionServices .toArray(new AssertionConsumerService[assertionServices.size()]); if (acsIndex != null) { // if index is present, URL must not be if (acsUrl != null) { throw new IllegalStateException("BadRequest.AssertionIndex"); } if (acsIndex < 0 || acsIndex >= assertionServices.size()) { throw new IllegalStateException("BadRequest.AssertionIndex"); } Validate.notNull(services[acsIndex]); retval = services[acsIndex].getEndpoint(); } else if (acsUrl != null) { // we have no index specified, URL is present // find assertion consumer service by URL if (validateWithMetadata) { for (AssertionConsumerService acs : assertionServices) { if (acs != null && acsUrl.equals(acs.getEndpoint())) { // check binding if specified if (binding == null || (binding != null && binding.equals(acs.getBinding()))) { retval = acs.getEndpoint(); } } } } else { // no validation retval = acsUrl; } // by now we should have found something if (retval == null) { throw new IllegalStateException("BadRequest.AssertionMetadata"); } } else if (binding != null) { // we have to index or URL specified, lookup by binding for (AssertionConsumerService acs : assertionServices) { if (acs != null && binding.equals(acs.getBinding())) { retval = acs.getEndpoint(); } } // by now we should have found something if (retval == null) { throw new IllegalStateException("BadRequest.AssertionBinding"); } } else { // just look for the default service if any for (AssertionConsumerService acs : assertionServices) { if (acs != null && acs.getName() != null && acs.getName().equals(rp.getDefaultAssertionConsumerService())) { retval = acs.getEndpoint(); } } // by now we should have found something if (retval == null) { throw new IllegalStateException("BadRequest.AssertionNoDefault"); } } } catch (IllegalStateException e) { logger.error("Caught illegal state exception {}", e.toString()); throw e; } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } Validate.notNull(retval); return retval; } /* * (non-Javadoc) * * @see * com.vmware.identity.samlservice.IdmAccessor#getCertificatesForRelyingParty * (java.lang.String) */ @Override public CertPath getCertificatesForRelyingParty(String relyingParty) { logger.debug("getCertificatesForRelyingParty {}", relyingParty); List<X509Certificate> certificates = new ArrayList<X509Certificate>(); // only query relying party if it's not null // simply return an empty chain for 'null' relying party if (relyingParty != null) { try { // TODO support more than one cert RelyingParty rp = client.getRelyingPartyByUrl(tenant, relyingParty); Validate.notNull(rp); Certificate c = rp.getCertificate(); Validate.notNull(c); certificates.add((X509Certificate) c); } catch (Exception e) { logger.error("Caught exception ", e); } } try { CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); CertPath certPath = certFactory.generateCertPath(certificates); return certPath; } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException(e); } } @Override public List<Certificate> getSAMLAuthorityChain() { logger.debug("getSAMLAuthorityChain"); List<Certificate> certs = Collections.emptyList(); try { certs = client.getTenantCertificate(tenant); } catch (Exception e) { logger.error("Caught exception ", e); } Validate.notEmpty(certs); return certs; } @Override public Collection<List<Certificate>> getSAMLAuthorityChains() { logger.debug("getSAMLAuthorityChains"); Collection<List<Certificate>> allChains = Collections.emptyList(); try { allChains = client.getTenantCertificates(tenant); } catch (Exception e) { logger.error("Caught exception ", e); } Validate.notEmpty(allChains); return allChains; } @Override public PrivateKey getSAMLAuthorityPrivateKey() { logger.debug("getSAMLAuthorityPrivateKey"); try { return client.getTenantPrivateKey(tenant); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException(e); } } @Override public long getClockTolerance() { logger.debug("getClockTolerance"); try { return client.getClockTolerance(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public String getTenantSignatureAlgorithm() { logger.debug("getTenantSignatureAlgorithm"); try { return client.getTenantSignatureAlgorithm(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public long getMaximumBearerTokenLifetime() { logger.debug("getMaximumBearerTokenLifetime"); try { return client.getMaximumBearerTokenLifetime(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public long getMaximumHoKTokenLifetime() { logger.debug("getMaximumHoKTokenLifetime"); try { return client.getMaximumHoKTokenLifetime(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public int getDelegationCount() { logger.debug("getDelegationCount"); try { return client.getDelegationCount(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public int getRenewCount() { logger.debug("getRenewCount"); try { return client.getRenewCount(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public void incrementGeneratedTokens(String tenant) { logger.debug("incrementGeneratedTokens"); try { client.incrementGeneratedTokens(tenant); } catch (Exception e) { logger.error("Caught exception ", e); } } @Override public GSSResult authenticate(String contextId, byte[] decodedAuthData) { logger.debug("kerb authenticate"); try { return client.authenticate(tenant, contextId, decodedAuthData); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException(e); } } @Override public String getDefaultIdpEntityId() { logger.debug("getDefaultIdpEntityId"); String retval = null; try { retval = getIdpEntityId(); if (retval.endsWith(tenant)) { // effectively trim "/{tenant} from the end retval = retval.substring(0, retval.length() - tenant.length() - 1); } } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } return retval; } @Override public CasIdmClient getIdmClient() { return client; } @Override public boolean getAuthnRequestsSignedForRelyingParty(String relyingParty) { logger.debug("getAuthnRequestsSignedForRelyingParty " + relyingParty); boolean retval = false; try { RelyingParty rp = client.getRelyingPartyByUrl(tenant, relyingParty); Validate.notNull(rp); retval = rp.isAuthnRequestsSigned(); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException(e); } return retval; } @Override public PrincipalId authenticate(String username, String password) { logger.debug("password authenticate"); try { return client.authenticate(tenant, username, password); } catch (Exception e) { logger.error("Caught exception. ", e); throw new IllegalStateException(e); } } @Override public PrincipalId authenticate(X509Certificate[] tLSCertChain) { try { return client.authenticate(tenant, tLSCertChain); } catch (Exception e) { logger.error("Caught exception. ", e); throw new IllegalStateException(e); } } @Override public RSAAMResult authenticatebyPasscode(String rsaSessionId, String username, String passcode) throws IDMSecureIDNewPinException { logger.debug("rsa secureID authenticate"); try { return client.authenticateRsaSecurId(tenant, rsaSessionId, username, passcode); } catch (IDMSecureIDNewPinException e) { logger.error("New pin required.", e); throw e; } catch (Exception e) { logger.error("Caught exception. ", e); throw new IllegalStateException(e); } } @Override public String getSloForRelyingParty(String relyingParty, String binding) throws IllegalStateException { logger.debug("getSloForRelyingParty " + relyingParty + ", binding " + binding); String retval = null; Validate.notNull(binding); try { RelyingParty rp = client.getRelyingPartyByUrl(tenant, relyingParty); Validate.notNull(rp); Collection<ServiceEndpoint> sloServices = rp.getSingleLogoutServices(); // SLO service is optional and if it does not exist or binding does not match, return null. if (sloServices != null && sloServices.size() > 0) { // lookup by binding for (ServiceEndpoint slo : sloServices) { if (slo != null && binding.equals(slo.getBinding())) { retval = slo.getResponseEndpoint(); if (retval == null || retval.isEmpty()) { retval = slo.getEndpoint(); } } } // by now we should have found something if (retval == null) { logger.warn(String.format( "SLO service for relying party %s exists, but does not support %s binding.", relyingParty, binding)); } } else { logger.warn(String.format("SLO service for relying party %s does not exist.", relyingParty)); } return retval; } catch (IllegalStateException e) { throw e; } catch (Exception e) { throw new IllegalStateException("BadRequest", e); } } @Override public String exportConfigurationAsString() { logger.debug("export configuration"); try { Document doc = client.getSsoSaml2Metadata(tenant); return Shared.getStringFromDocument(doc); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(e); } } @Override public String getIdpSsoEndpoint() { logger.debug("getIdpSsoEndpoint"); String retval = null; try { retval = client.getEntityID(tenant).replace("/Metadata", "/SSO"); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } return retval; } @Override public String getIdpSloEndpoint() { logger.debug("getIdpSloEndpoint"); String retval = null; try { retval = client.getEntityID(tenant).replace("/Metadata", "/SLO"); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } return retval; } @Override public String getDefaultIdpSsoEndpoint() { logger.debug("getDefaultIdpEntityId"); String retval = null; try { retval = getIdpEntityId(); if (retval.endsWith(tenant)) { // effectively trim "/{tenant} from the end retval = retval.substring(0, retval.length() - tenant.length() - 1); } // change to SSO endpoint retval = retval.replace("/Metadata", "/SSO"); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } return retval; } @Override public Collection<IDPConfig> getExternalIdps() { logger.debug("getExternalIdps"); Collection<IDPConfig> idps = Collections.emptyList(); try { idps = client.getAllExternalIdpConfig(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException("BadRequest", e); } return idps; } @Override public String getBrandName() { logger.debug("getBrandName"); try { return client.getBrandName(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException("Failed to return tenant brand name for: " + tenant, e); } } @Override public String getLogonBannerContent() { logger.debug("getLogonBannerContent"); try { return client.getLogonBannerContent(tenant); } catch (Exception e) { logger.debug("Caught exception " + e.toString()); throw new IllegalStateException("Failed to return tenant logon banner content for: " + tenant, e); } } @Override public String getLogonBannerTitle() { logger.debug("getLogonBannerTitle"); try { return client.getLogonBannerTitle(tenant); } catch (Exception e) { logger.debug("Caught exception " + e.toString()); throw new IllegalStateException("Failed to return tenant logon banner title for: " + tenant, e); } } @Override public boolean getLogonBannerCheckboxFlag() { logger.debug("getLogonBannerCheckboxFlag"); try { return client.getLogonBannerCheckboxFlag(tenant); } catch (Exception e) { logger.debug("Caught exception " + e.toString()); throw new IllegalStateException("Failed to return tenant logon banner checkbox for: " + tenant, e); } } @Override public List<Certificate> getTenantCertificate() { logger.debug("getTenantCertificate"); try { return client.getTenantCertificate(tenant); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException("Failed to return tenant signing cert for: " + tenant, e); } } @Override public IDPConfig getExternalIdpConfigForTenant(String tenant, String providerID) { try { return client.getExternalIdpConfigForTenant(tenant, providerID); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException("Failed to return external IDP configuration. provider: " + providerID, e); } } @Override public RelyingParty getRelyingPartyByUrl(String rpEntityId) { logger.debug("getRelyingPartyByUrl"); try { return client.getRelyingPartyByUrl(tenant, rpEntityId); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException("Failed to return find relying party: " + rpEntityId, e); } } @Override public String getServerSPN() { try { return client.getServerSPN(); } catch (Exception e) { logger.error("Caught exception ", e); throw new IllegalStateException("Failed to get server SPN", e); } } @Override public PrincipalId createUserAccountJustInTime(Subject subject, String tenant, IDPConfig extIdp) throws Exception { final String userNameDelimiter = "-"; if (subject == null) { throw new InvalidTokenException("The subject retrieved from external token is null."); } // retrieve system domain EnumSet<DomainType> domains = EnumSet.of(DomainType.SYSTEM_DOMAIN); Iterator<IIdentityStoreData> iter = client.getProviders(tenant, domains).iterator(); String systemDomain = iter.next().getName(); PrincipalId subjectUpn = null; String userName = null; String upnSuffix = null; String extUserId = null; if (subject.subjectUpn() != null) { subjectUpn = subject.subjectUpn(); upnSuffix = subjectUpn.getDomain(); // add upn suffix to user name for ext users to avoid conflict with local user with the same name userName = subjectUpn.getName() + userNameDelimiter + upnSuffix; extUserId = subjectUpn.getUPN(); } else { // to support non-upn subject format in external token String nameId = subject.subjectNameId().getName(); // compose user name as sanitizedExternalID.GUID userName = sanitizeSubjectNameIdForUserName(nameId) + userNameDelimiter + UUID.randomUUID().toString(); upnSuffix = extIdp.getUpnSuffix(); if (upnSuffix == null || upnSuffix.isEmpty()) { throw new IllegalStateException("UPN suffix is not set for external IDP: " + extIdp.getEntityID()); } subjectUpn = new PrincipalId(userName, upnSuffix); extUserId = nameId; } // register upn suffix to system domain client.registerUpnSuffix(tenant, systemDomain, upnSuffix); logger.info("Creating a temporary user account for the user {} with domain {} " + "in VMware identity store since the user is not found " + "during delegated logon via SAML IDP federation.", subjectUpn.getUPN(), tenant); return client.addJitUser(tenant, userName, new PersonDetail.Builder().userPrincipalName(subjectUpn.getUPN()) .description("A JIT user account created for external IDP.").build(), extIdp.getEntityID(), extUserId); } private String sanitizeSubjectNameIdForUserName(String nameId) { String sanitizedNameId = nameId; int pos = nameId.indexOf("@"); if (pos > 0) { sanitizedNameId = nameId.substring(0, pos); } for (char c : invalidCharsForUserName) { sanitizedNameId.replace(String.valueOf(c), "#"); } return sanitizedNameId; } @Override public void updateJitUserGroups(PrincipalId subjectUpn, String tenant, Map<TokenClaimAttribute, List<String>> mappings, Collection<Attribute> claimAttributes) throws Exception { Set<Group> currentGroups = client.findDirectParentGroups(tenant, subjectUpn); if (currentGroups == null) { currentGroups = new HashSet<>(); } Set<Group> newGroups = new HashSet<>(); if (mappings != null && claimAttributes != null) { for (Attribute attr : claimAttributes) { String attrName = attr.getName(); for (String attrValue : attr.getValues()) { TokenClaimAttribute tokenClaim = new TokenClaimAttribute(attrName, attrValue); List<String> groups = mappings.get(tokenClaim); if (groups == null || groups.isEmpty()) { continue; } for (String groupSid : groups) { try { newGroups.add(client.findGroupByObjectId(tenant, groupSid)); } catch (Exception e) { logger.error("Failed to find group with sid " + groupSid, e); } } } } for (Group g : newGroups) { if (!currentGroups.contains(g)) { try { client.addUserToGroup(tenant, subjectUpn, g.getName()); logger.debug("User {} added to group{}s in tenant {}", subjectUpn.getUPN(), g.getName(), tenant); } catch (Exception e) { logger.error(String.format( "Failed to add user %s to group %s in tenant %s. " + "Continue updating user group membership...", subjectUpn.getUPN(), g.getName(), tenant), e); } } } } for (Group g : currentGroups) { if (!newGroups.contains(g) && !g.getName().equalsIgnoreCase(SSOImplicitGroupNames.getEveryoneGroupName())) { try { client.removeFromLocalGroup(tenant, subjectUpn, g.getName()); } catch (Exception e) { logger.error(String.format( "Failed to remove user %s from group %s in tenant %s. " + "Continue updating user group membership...", subjectUpn.getUPN(), g.getName(), tenant), e); } } } } @Override public boolean isJitEnabledForExternalIdp(String tenantName, String entityId) { try { return client.getExternalIdpConfigForTenant(tenantName, entityId).getJitAttribute(); } catch (Exception e) { logger.debug("Caught exception ", e); throw new IllegalStateException(String .format("Failed to return jit attribute for idp: %s for tenant %s.", entityId, tenantName), e); } } @Override public AuthnPolicy getAuthnPolicy(String tenantName) { try { return client.getAuthnPolicy(tenantName); } catch (Exception e) { throw new IllegalStateException( String.format("Failed to return authentication policy object: for tenant %s.", tenantName), e); } } @Override public boolean getTenantIDPSelectionFlag(String tenantName) { try { return client.isTenantIDPSelectionEnabled(tenantName); } catch (Exception e) { throw new IllegalStateException( String.format("Failed to return idp selection flag: for tenant %s.", tenantName), e); } } @Override public String getIDPAlias(String tenantName, String entityId) { try { if (this.getIdpEntityId().equals(entityId)) { return client.getLocalIDPAlias(tenantName); } return client.getExternalIDPAlias(tenantName, entityId); } catch (Exception e) { throw new IllegalStateException( String.format("Failed to return idp display name: for idp %s tenant %s.", entityId, tenantName), e); } } @Override public Collection<String> getAllTenants() throws Exception { try { return client.getAllTenants(); } catch (Exception e) { throw new IllegalStateException("Failed to return all tenant names.", e); } } @Override public Collection<RelyingParty> getRelyingParties(String tenant) { try { return client.getRelyingParties(tenant); } catch (Exception e) { throw new IllegalStateException("Failed to return relying part configurations for the tenant.", e); } } }