Java tutorial
/* * Copyright 2014-2017 Groupon, Inc * Copyright 2014-2017 The Billing Project, LLC * * The Billing Project licenses this file to you under the Apache License, version 2.0 * (the "License"); you may not use this file except in compliance with the * License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package org.killbill.billing.util.security.shiro.realm; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.apache.shiro.authc.AuthenticationException; import org.apache.shiro.authc.AuthenticationInfo; import org.apache.shiro.authc.AuthenticationToken; import org.apache.shiro.authc.SimpleAuthenticationInfo; import org.apache.shiro.authc.UsernamePasswordToken; import org.apache.shiro.authz.AuthorizationException; import org.apache.shiro.authz.AuthorizationInfo; import org.apache.shiro.authz.SimpleAuthorizationInfo; import org.apache.shiro.config.Ini; import org.apache.shiro.config.Ini.Section; import org.apache.shiro.realm.AuthorizingRealm; import org.apache.shiro.subject.PrincipalCollection; import org.killbill.billing.util.config.definition.SecurityConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.ning.http.client.AsyncCompletionHandler; import com.ning.http.client.AsyncHttpClient; import com.ning.http.client.AsyncHttpClient.BoundRequestBuilder; import com.ning.http.client.AsyncHttpClientConfig; import com.ning.http.client.ListenableFuture; import com.ning.http.client.Response; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.inject.Inject; public class KillBillOktaRealm extends AuthorizingRealm { private static final Logger log = LoggerFactory.getLogger(KillBillOktaRealm.class); private static final ObjectMapper mapper = new ObjectMapper(); private static final int DEFAULT_TIMEOUT_SECS = 15; private static final Splitter SPLITTER = Splitter.on(',').omitEmptyStrings().trimResults(); private final Map<String, Collection<String>> permissionsByGroup = Maps.newLinkedHashMap(); private final SecurityConfig securityConfig; private final AsyncHttpClient httpClient; @Inject public KillBillOktaRealm(final SecurityConfig securityConfig) { this.securityConfig = securityConfig; this.httpClient = new AsyncHttpClient( new AsyncHttpClientConfig.Builder().setRequestTimeout(DEFAULT_TIMEOUT_SECS * 1000).build()); if (securityConfig.getShiroOktaPermissionsByGroup() != null) { final Ini ini = new Ini(); // When passing properties on the command line, \n can be escaped ini.load(securityConfig.getShiroOktaPermissionsByGroup().replace("\\n", "\n")); for (final Section section : ini.getSections()) { for (final String role : section.keySet()) { final Collection<String> permissions = ImmutableList .<String>copyOf(SPLITTER.split(section.get(role))); permissionsByGroup.put(role, permissions); } } } } @Override protected AuthorizationInfo doGetAuthorizationInfo(final PrincipalCollection principals) { final String username = (String) getAvailablePrincipal(principals); final String userId = findOktaUserId(username); final Set<String> userGroups = findOktaGroupsForUser(userId); final SimpleAuthorizationInfo simpleAuthorizationInfo = new SimpleAuthorizationInfo(userGroups); final Set<String> stringPermissions = groupsPermissions(userGroups); simpleAuthorizationInfo.setStringPermissions(stringPermissions); return simpleAuthorizationInfo; } @Override protected AuthenticationInfo doGetAuthenticationInfo(final AuthenticationToken token) throws AuthenticationException { final UsernamePasswordToken upToken = (UsernamePasswordToken) token; if (doAuthenticate(upToken)) { // Credentials are valid return new SimpleAuthenticationInfo(token.getPrincipal(), token.getCredentials(), getName()); } else { throw new AuthenticationException("Okta authentication failed"); } } private boolean doAuthenticate(final UsernamePasswordToken upToken) { final BoundRequestBuilder builder = httpClient .preparePost(securityConfig.getShiroOktaUrl() + "/api/v1/authn"); try { final ImmutableMap<String, String> body = ImmutableMap.<String, String>of("username", upToken.getUsername(), "password", String.valueOf(upToken.getPassword())); builder.setBody(mapper.writeValueAsString(body)); } catch (final JsonProcessingException e) { log.warn("Error while generating Okta payload"); throw new AuthenticationException(e); } builder.addHeader("Authorization", "SSWS " + securityConfig.getShiroOktaAPIToken()); builder.addHeader("Content-Type", "application/json; charset=UTF-8"); final Response response; try { final ListenableFuture<Response> futureStatus = builder.execute(new AsyncCompletionHandler<Response>() { @Override public Response onCompleted(final Response response) throws Exception { return response; } }); response = futureStatus.get(DEFAULT_TIMEOUT_SECS, TimeUnit.SECONDS); } catch (final TimeoutException toe) { log.warn("Timeout while connecting to Okta"); throw new AuthenticationException(toe); } catch (final Exception e) { log.warn("Error while connecting to Okta"); throw new AuthenticationException(e); } return isAuthenticated(response); } private boolean isAuthenticated(final Response oktaRawResponse) { try { final Map oktaResponse = mapper.readValue(oktaRawResponse.getResponseBodyAsStream(), Map.class); if ("SUCCESS".equals(oktaResponse.get("status"))) { return true; } else { log.warn("Okta authentication failed: " + oktaResponse); return false; } } catch (final IOException e) { log.warn("Unable to read response from Okta"); throw new AuthenticationException(e); } } private String findOktaUserId(final String login) { final String path; try { path = "/api/v1/users/" + URLEncoder.encode(login, "UTF-8"); } catch (final UnsupportedEncodingException e) { // Should never happen throw new IllegalStateException(e); } final Response oktaRawResponse = doGetRequest(path); try { final Map oktaResponse = mapper.readValue(oktaRawResponse.getResponseBodyAsStream(), Map.class); return (String) oktaResponse.get("id"); } catch (final IOException e) { log.warn("Unable to read response from Okta"); throw new AuthorizationException(e); } } private Set<String> findOktaGroupsForUser(final String userId) { final String path = "/api/v1/users/" + userId + "/groups"; final Response response = doGetRequest(path); return getGroups(response); } private Response doGetRequest(final String path) { final BoundRequestBuilder builder = httpClient.prepareGet(securityConfig.getShiroOktaUrl() + path); builder.addHeader("Authorization", "SSWS " + securityConfig.getShiroOktaAPIToken()); builder.addHeader("Content-Type", "application/json; charset=UTF-8"); final Response response; try { final ListenableFuture<Response> futureStatus = builder.execute(new AsyncCompletionHandler<Response>() { @Override public Response onCompleted(final Response response) throws Exception { return response; } }); response = futureStatus.get(DEFAULT_TIMEOUT_SECS, TimeUnit.SECONDS); } catch (final TimeoutException toe) { log.warn("Timeout while connecting to Okta"); throw new AuthorizationException(toe); } catch (final Exception e) { log.warn("Error while connecting to Okta"); throw new AuthorizationException(e); } return response; } private Set<String> getGroups(final Response oktaRawResponse) { try { final List<Map> oktaResponse = mapper.readValue(oktaRawResponse.getResponseBodyAsStream(), new TypeReference<List<Map>>() { }); final Set<String> groups = new HashSet<String>(); for (final Map group : oktaResponse) { final Object groupProfile = group.get("profile"); if (groupProfile != null && groupProfile instanceof Map) { groups.add((String) ((Map) groupProfile).get("name")); } } return groups; } catch (final IOException e) { log.warn("Unable to read response from Okta"); throw new AuthorizationException(e); } } private Set<String> groupsPermissions(final Iterable<String> groups) { final Set<String> permissions = new HashSet<String>(); for (final String group : groups) { final Collection<String> permissionsForGroup = permissionsByGroup.get(group); if (permissionsForGroup != null) { permissions.addAll(permissionsForGroup); } } return permissions; } }