Java tutorial
/* * Copyright 2016 Yoshio Terada * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.yoshio3.modules; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.jaxrs.json.JacksonJaxbJsonProvider; import com.microsoft.aad.adal4j.AuthenticationContext; import com.microsoft.aad.adal4j.AuthenticationResult; import com.microsoft.aad.adal4j.ClientCredential; import com.nimbusds.oauth2.sdk.AuthorizationCode; import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse; import com.nimbusds.openid.connect.sdk.AuthenticationResponse; import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser; import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse; import com.yoshio3.modules.entities.ADUserMemberOfGroups; import java.io.IOException; import java.io.StringWriter; import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URLEncoder; import java.security.Principal; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.logging.Level; import java.util.logging.Logger; import javax.json.Json; import javax.json.JsonObject; import javax.json.JsonWriter; import javax.naming.ServiceUnavailableException; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSession; import javax.security.auth.Subject; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.UnsupportedCallbackException; import javax.security.auth.login.LoginContext; import javax.security.auth.login.LoginException; import javax.security.auth.message.AuthException; import javax.security.auth.message.AuthStatus; import javax.security.auth.message.MessageInfo; import javax.security.auth.message.MessagePolicy; import javax.security.auth.message.callback.CallerPrincipalCallback; import javax.security.auth.message.callback.GroupPrincipalCallback; import javax.security.auth.message.module.ServerAuthModule; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.client.Client; import javax.ws.rs.client.ClientBuilder; import javax.ws.rs.client.Entity; import javax.ws.rs.core.Response; import org.glassfish.jersey.SslConfigurator; import org.glassfish.jersey.jackson.JacksonFeature; /** * * @author Yoshio Terada */ public class AzureADServerAuthModule implements ServerAuthModule { private static final Logger LOGGER = Logger.getLogger(AzureADServerAuthModule.class.getName()); public final static String ERROR = "error"; public final static String ERROR_DESCRIPTION = "error_description"; public final static String ERROR_URI = "error_uri"; public final static String ID_TOKEN = "id_token"; public final static String CODE = "code"; private static final String SAVED_SUBJECT = "saved_subject"; public static final String PRINCIPAL_SESSION_NAME = "principal"; private String authority = ""; private String tenant = ""; private String clientId = ""; private String secretKey = ""; private String graphServer = ""; private final static String LOGIN_CONTEXT_NAME = "AzureAD-Login"; //login.conf ????? //login.conf ????????? // the name mentioned in login.conf static final String AUTHORIZATION_HEADER = "authorization"; private static final Class[] SUPPORTED_MESSAGE_TYPE = new Class[] { HttpServletRequest.class, HttpServletResponse.class }; private MessagePolicy requestPolicy; private MessagePolicy responsePolicy; private CallbackHandler handler; private Map<String, String> options; private LoginContext loginContext = null; public MessagePolicy getRequestPolicy() { return requestPolicy; } public MessagePolicy getResponsePolicy() { return responsePolicy; } public Map<String, String> getOptions() { return options; } @Override public void initialize(MessagePolicy requestPolicy, MessagePolicy responsePolicy, CallbackHandler handler, Map options) throws AuthException { this.requestPolicy = requestPolicy; this.responsePolicy = responsePolicy; this.handler = handler; this.options = options; if (options == null) { return; } if (options.containsKey("authority")) { authority = (String) options.get("authority"); } if (options.containsKey("tenant")) { tenant = (String) options.get("tenant"); } if (options.containsKey("client_id")) { clientId = (String) options.get("client_id"); } if (options.containsKey("secret_key")) { secretKey = (String) options.get("secret_key"); } if (options.containsKey("graph_server")) { graphServer = (String) options.get("graph_server"); } } @Override public Class[] getSupportedMessageTypes() { return SUPPORTED_MESSAGE_TYPE; } @Override public AuthStatus validateRequest(MessageInfo messageInfo, Subject clientSubject, Subject serviceSubject) throws AuthException { HttpServletRequest httpRequest = (HttpServletRequest) messageInfo.getRequestMessage(); HttpServletResponse httpResponse = (HttpServletResponse) messageInfo.getResponseMessage(); Callback[] callbacks; //Azure AD ?????????? // if returning as a redirect after authenticating on Azure AD //??????????????? //?????????????????????????? // as there is no principal information, if authentication was successful add info to the principal Map<String, String> params = new HashMap<>(); httpRequest.getParameterMap().keySet().stream().forEach(key -> { params.put(key, httpRequest.getParameterMap().get(key)[0]); }); String currentUri = getCurrentUri(httpRequest); //????????? // if the authentication result is not included in the session if (!getSessionPrincipal(httpRequest)) { if (!isRedirectedRequestFromAuthServer(httpRequest, params)) { try { // Azure AD ? Redirect // redirect to Azure ID return redirectOpenIDServer(httpResponse, currentUri); } catch (IOException ex) { LOGGER.log(Level.SEVERE, "Invalid redirect URL", ex); return AuthStatus.SEND_FAILURE; } } else { // Azure AD ???????? // if it's a request returning from Azure AD messageInfo.getMap().put("javax.servlet.http.registerSession", Boolean.TRUE.toString()); messageInfo.getMap().put("javax.servlet.http.authType", "AzureADServerAuthModule"); return getAuthResultFromServerAndSetSession(clientSubject, httpRequest, params, currentUri); } } else { try { //??????? // if the authentication result is included in the session AzureADUserPrincipal sessionPrincipal = (AzureADUserPrincipal) httpRequest.getUserPrincipal(); AuthenticationResult authenticationResult = sessionPrincipal.getAuthenticationResult(); if (authenticationResult.getExpiresOnDate().before(new Date())) { //???????? // if the authentication date is old - get an access token from the refresh token AuthenticationResult authResult = getAccessTokenFromRefreshToken( authenticationResult.getRefreshToken(), currentUri); setSessionPrincipal(httpRequest, new AzureADUserPrincipal(authResult)); } CallerPrincipalCallback callerCallBack = new CallerPrincipalCallback(clientSubject, sessionPrincipal); String[] groups = getGroupList(sessionPrincipal); GroupPrincipalCallback groupPrincipalCallback = new GroupPrincipalCallback(clientSubject, groups); callbacks = new Callback[] { callerCallBack, groupPrincipalCallback }; handler.handle(callbacks); return AuthStatus.SUCCESS; } catch (Throwable ex) { LOGGER.log(Level.SEVERE, "Invalid Session Info", ex); return AuthStatus.SEND_FAILURE; } } } @Override public AuthStatus secureResponse(MessageInfo messageInfo, Subject serviceSubject) throws AuthException { return AuthStatus.SEND_SUCCESS; } @Override public void cleanSubject(MessageInfo messageInfo, Subject subject) throws AuthException { try { if (subject != null) { subject.getPrincipals().clear(); } loginContext.logout(); } catch (LoginException ex) { LOGGER.log(Level.SEVERE, null, ex); } } /* ????????? */ /* Step 1: when it's the initial unauthenticated request */ private AuthStatus redirectOpenIDServer(HttpServletResponse httpResponse, String currentUri) throws UnsupportedEncodingException, IOException { //?????????????? // if not authenticated, without any authentication data // ???????? Azure AD ???? // if it's not authenticated, redirect to Azure AD authentication screen String redirectUrl = getRedirectUrl(currentUri); httpResponse.setStatus(302); httpResponse.sendRedirect(getRedirectUrl(currentUri)); return AuthStatus.SEND_CONTINUE; } /* ? code, id_token ?????? ??????? */ /* Step 2: when there is a code, id_token etc after being redirected at this point, save the authentication result into the session */ private AuthStatus getAuthResultFromServerAndSetSession(Subject clientSubject, HttpServletRequest httpRequest, Map<String, String> params, String currentUri) { try { String fullUrl = currentUri + (httpRequest.getQueryString() != null ? "?" + httpRequest.getQueryString() : ""); AuthenticationResponse authResponse = AuthenticationResponseParser.parse(new URI(fullUrl), params); //params ?? error ???????AuthenticationErrorResponse // if there is an error key in params, return AuthenticationErrorResponse //??? AuthenticationSuccessResponse ? // if it was successful, return AuthenticationSuccessResponse //?????? // if authentication was successful if (authResponse instanceof AuthenticationSuccessResponse) { //??????? // obtain the result from the response and save it in the session AuthenticationSuccessResponse authSuccessResponse = (AuthenticationSuccessResponse) authResponse; AuthenticationResult result = getAccessToken(authSuccessResponse.getAuthorizationCode(), currentUri); AzureADUserPrincipal userPrincipal = new AzureADUserPrincipal(result); setSessionPrincipal(httpRequest, userPrincipal); //? // set the user principal String[] groups = getGroupList(userPrincipal); System.out.println(": " + Arrays.toString(groups)); AzureADCallbackHandler azureCallBackHandler = new AzureADCallbackHandler(clientSubject, httpRequest, userPrincipal); loginContext = new LoginContext(LOGIN_CONTEXT_NAME, azureCallBackHandler); loginContext.login(); Subject subject = loginContext.getSubject(); CallerPrincipalCallback callerCallBack = new CallerPrincipalCallback(clientSubject, userPrincipal); GroupPrincipalCallback groupPrincipalCallback = new GroupPrincipalCallback(clientSubject, groups); Callback[] callbacks = new Callback[] { callerCallBack, groupPrincipalCallback }; handler.handle(callbacks); return AuthStatus.SUCCESS; } else { // ????? // if authentication failed AuthenticationErrorResponse authErrorResponse = (AuthenticationErrorResponse) authResponse; CallerPrincipalCallback callerCallBack = new CallerPrincipalCallback(clientSubject, (Principal) null); GroupPrincipalCallback groupPrincipalCallback = new GroupPrincipalCallback(clientSubject, null); Callback[] callbacks = new Callback[] { callerCallBack, groupPrincipalCallback }; handler.handle(callbacks); return AuthStatus.FAILURE; } } catch (Throwable ex) { CallerPrincipalCallback callerCallBack = new CallerPrincipalCallback(clientSubject, (Principal) null); GroupPrincipalCallback groupPrincipalCallback = new GroupPrincipalCallback(clientSubject, null); Callback[] callbacks = new Callback[] { callerCallBack, groupPrincipalCallback }; try { handler.handle(callbacks); } catch (IOException | UnsupportedCallbackException ex1) { LOGGER.log(Level.SEVERE, null, ex1); } LOGGER.log(Level.SEVERE, null, ex); return AuthStatus.FAILURE; } } private static Client jaxrsClient; private Client getConnectionFactory() { if (jaxrsClient == null) { jaxrsClient = ClientBuilder.newClient().register( (new JacksonJaxbJsonProvider(new ObjectMapper(), JacksonJaxbJsonProvider.DEFAULT_ANNOTATIONS))) .register(JacksonFeature.class); return jaxrsClient; } else { return jaxrsClient; } } private String[] getGroupList(AzureADUserPrincipal userPrincipal) { String authString = "Bearer " + userPrincipal.getAuthenticationResult().getAccessToken(); System.setProperty("sun.net.http.allowRestrictedHeaders", "true"); String graphURL = String.format("https://%s/%s/users/%s/getMemberGroups", graphServer, tenant, userPrincipal.getName()); JsonObject model = Json.createObjectBuilder().add("securityEnabledOnly", "false").build(); StringWriter stWriter = new StringWriter(); try (JsonWriter jsonWriter = Json.createWriter(stWriter)) { jsonWriter.writeObject(model); } String jsonData = stWriter.toString(); Future<Response> response = getConnectionFactory().target(graphURL).request().header("Host", graphServer) .header("Accept", "application/json, text/plain, */*").header("Content-Type", "application/json") .header("api-version", "1.6").header("Authorization", authString).async() .post(Entity.json(jsonData)); try { ADUserMemberOfGroups memberOfGrups; memberOfGrups = response.get().readEntity(ADUserMemberOfGroups.class); LOGGER.log(Level.INFO, memberOfGrups.toString()); return memberOfGrups.getValue(); } catch (InterruptedException | ExecutionException ex) { Logger.getLogger(AzureADServerAuthModule.class.getName()).log(Level.SEVERE, null, ex); return null; } } /* ??? */ /* get the access token from the refresh token */ private AuthenticationResult getAccessTokenFromRefreshToken(String refreshToken, String currentUri) throws Throwable { AuthenticationContext context; AuthenticationResult result; ExecutorService service = null; try { service = Executors.newFixedThreadPool(1); context = new AuthenticationContext(authority + tenant + "/", true, service); Future<AuthenticationResult> future = context.acquireTokenByRefreshToken(refreshToken, new ClientCredential(clientId, secretKey), null, null); result = future.get(); } catch (ExecutionException e) { throw e.getCause(); } finally { if (service != null) { service.shutdown(); } } if (result == null) { throw new ServiceUnavailableException("authentication result was null"); } return result; } /* ??*/ /* get the access token */ private AuthenticationResult getAccessToken(AuthorizationCode authorizationCode, String currentUri) throws Throwable { String authCode = authorizationCode.getValue(); ClientCredential credential = new ClientCredential(clientId, secretKey); AuthenticationContext context; AuthenticationResult result; ExecutorService service = null; try { service = Executors.newFixedThreadPool(1); context = new AuthenticationContext(authority + tenant + "/", true, service); Future<AuthenticationResult> future = context.acquireTokenByAuthorizationCode(authCode, new URI(currentUri), credential, null); result = future.get(); } catch (ExecutionException e) { throw e.getCause(); } finally { if (service != null) { service.shutdown(); } } if (result == null) { throw new ServiceUnavailableException("authentication result was null"); } return result; } /* HTTP ?? */ /* set authentication information in the session */ private void setSessionPrincipal(HttpServletRequest httpRequest, AzureADUserPrincipal principal) throws Exception { httpRequest.getSession().setAttribute(PRINCIPAL_SESSION_NAME, principal); } /* HTTP ????? */ /* get the authentication result from the HTTP session */ public boolean getSessionPrincipal(HttpServletRequest request) { return request.getUserPrincipal() != null; // return (AzureADUserPrincipal) request.getSession().getAttribute(PRINCIPAL_SESSION_NAME); } private void setSessionSubject(HttpServletRequest httpRequest, final Subject clientSubject) { if (clientSubject == null) { return; } httpRequest.getSession().setAttribute(SAVED_SUBJECT, clientSubject); LOGGER.log(Level.FINE, "Saved subject {0}", clientSubject); } private Subject getSessionSubject(HttpServletRequest httpRequest) { return (Subject) httpRequest.getSession().getAttribute(SAVED_SUBJECT); } /* URL ?? */ /* get the redirect URL */ private String getRedirectUrl(String currentUri) throws UnsupportedEncodingException { String redirectUrl = authority + this.tenant + "/oauth2/authorize?response_type=code%20id_token&scope=openid&response_mode=form_post&redirect_uri=" + URLEncoder.encode(currentUri, "UTF-8") + "&client_id=" + clientId + "&resource=https%3a%2f%2fgraph.windows.net" + "&nonce=" + UUID.randomUUID() + "&site_id=500879"; return redirectUrl; } /* HTTP ???????? */ /* check whether the HTTP session is authenticated or not */ // TODO does not seem to match method public boolean isRedirectedRequestFromAuthServer(HttpServletRequest httpRequest, Map<String, String> params) { return httpRequest.getMethod().equalsIgnoreCase("POST") && (httpRequest.getParameterMap().containsKey(ERROR) || httpRequest.getParameterMap().containsKey(ID_TOKEN) || httpRequest.getParameterMap().containsKey(CODE)); } /* ????????? */ /* check whether authentication data is included or not */ public boolean containsAuthenticationData(HttpServletRequest httpRequest) { // System.out.println("containsAuthenticationData ??" + httpRequest.getUserPrincipal().getName()); Map<String, String[]> map = httpRequest.getParameterMap(); return httpRequest.getMethod().equalsIgnoreCase("POST") && (httpRequest.getParameterMap().containsKey(ERROR) || httpRequest.getParameterMap().containsKey(ID_TOKEN) || httpRequest.getParameterMap().containsKey(CODE)); } /* ? URI ? */ /* get the request URI */ private String getCurrentUri(HttpServletRequest request) { String scheme = request.getScheme(); int serverPort = request.getServerPort(); String portNumberString = ""; if (!((scheme.equals("http") && serverPort == 80) || (scheme.equals("https") && serverPort == 443))) { portNumberString = ":" + String.valueOf(serverPort); } String uri = scheme + "://" + request.getServerName() + portNumberString + request.getRequestURI(); return uri; } }