// Java tutorial
/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * <p> * http://www.apache.org/licenses/LICENSE-2.0 * <p> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.drill.exec.rpc.security; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.protobuf.ByteString; import com.google.protobuf.Internal.EnumLite; import com.google.protobuf.InvalidProtocolBufferException; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; import org.apache.drill.exec.proto.UserBitShared.SaslMessage; import org.apache.drill.exec.proto.UserBitShared.SaslStatus; import org.apache.drill.exec.rpc.RequestHandler; import org.apache.drill.exec.rpc.Response; import org.apache.drill.exec.rpc.ResponseSender; import org.apache.drill.exec.rpc.RpcException; import org.apache.drill.exec.rpc.ServerConnection; import org.apache.hadoop.security.UserGroupInformation; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import java.io.IOException; import java.lang.reflect.UndeclaredThrowableException; import java.security.PrivilegedExceptionAction; import java.util.EnumMap; import java.util.Map; import static com.google.common.base.Preconditions.checkNotNull; /** * Handles SASL exchange, on the server-side. 
* * @param <S> Server connection type * @param <T> RPC type */ public class ServerAuthenticationHandler<S extends ServerConnection<S>, T extends EnumLite> implements RequestHandler<S> { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory .getLogger(ServerAuthenticationHandler.class); private static final ImmutableMap<SaslStatus, SaslResponseProcessor> RESPONSE_PROCESSORS; static { final Map<SaslStatus, SaslResponseProcessor> map = new EnumMap<>(SaslStatus.class); map.put(SaslStatus.SASL_START, new SaslStartProcessor()); map.put(SaslStatus.SASL_IN_PROGRESS, new SaslInProgressProcessor()); map.put(SaslStatus.SASL_SUCCESS, new SaslSuccessProcessor()); map.put(SaslStatus.SASL_FAILED, new SaslFailedProcessor()); RESPONSE_PROCESSORS = Maps.immutableEnumMap(map); } private final RequestHandler<S> requestHandler; private final int saslRequestTypeValue; private final T saslResponseType; public ServerAuthenticationHandler(final RequestHandler<S> requestHandler, final int saslRequestTypeValue, final T saslResponseType) { this.requestHandler = requestHandler; this.saslRequestTypeValue = saslRequestTypeValue; this.saslResponseType = saslResponseType; } @Override public void handle(S connection, int rpcType, ByteBuf pBody, ByteBuf dBody, ResponseSender sender) throws RpcException { final String remoteAddress = connection.getRemoteAddress().toString(); // exchange involves server "challenges" and client "responses" (initiated by client) if (saslRequestTypeValue == rpcType) { final SaslMessage saslResponse; try { saslResponse = SaslMessage.PARSER.parseFrom(new ByteBufInputStream(pBody)); } catch (final InvalidProtocolBufferException e) { handleAuthFailure(remoteAddress, sender, e, saslResponseType); return; } logger.trace("Received SASL message {} from {}", saslResponse.getStatus(), remoteAddress); final SaslResponseProcessor processor = RESPONSE_PROCESSORS.get(saslResponse.getStatus()); if (processor == null) { logger.info("Unknown message type from client from 
{}. Will stop authentication.", remoteAddress); handleAuthFailure(remoteAddress, sender, new SaslException("Received unexpected message"), saslResponseType); return; } final SaslResponseContext<S, T> context = new SaslResponseContext<>(saslResponse, connection, remoteAddress, sender, requestHandler, saslResponseType); try { processor.process(context); } catch (final Exception e) { handleAuthFailure(remoteAddress, sender, e, saslResponseType); } } else { // this handler only handles messages of SASL_MESSAGE_VALUE type // the response type for this request type is likely known from UserRpcConfig, // but the client should not be making any requests before authenticating. // drop connection throw new RpcException(String.format( "Request of type %d is not allowed without authentication. " + "Client on %s must authenticate before making requests. Connection dropped.", rpcType, remoteAddress)); } } private static class SaslResponseContext<S extends ServerConnection<S>, T extends EnumLite> { final SaslMessage saslResponse; final S connection; final String remoteAddress; final ResponseSender sender; final RequestHandler<S> requestHandler; final T saslResponseType; SaslResponseContext(SaslMessage saslResponse, S connection, String remoteAddress, ResponseSender sender, RequestHandler<S> requestHandler, T saslResponseType) { this.saslResponse = checkNotNull(saslResponse); this.connection = checkNotNull(connection); this.remoteAddress = checkNotNull(remoteAddress); this.sender = checkNotNull(sender); this.requestHandler = checkNotNull(requestHandler); this.saslResponseType = checkNotNull(saslResponseType); } } private interface SaslResponseProcessor { /** * Process response from client, and if there are no exceptions, send response using * {@link SaslResponseContext#sender}. Otherwise, throw the exception. 
* * @param context response context */ <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context) throws Exception; } private static class SaslStartProcessor implements SaslResponseProcessor { @Override public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context) throws Exception { context.connection.initSaslServer(context.saslResponse.getMechanism()); // assume #evaluateResponse must be called at least once RESPONSE_PROCESSORS.get(SaslStatus.SASL_IN_PROGRESS).process(context); } } private static class SaslInProgressProcessor implements SaslResponseProcessor { @Override public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context) throws Exception { final SaslMessage.Builder challenge = SaslMessage.newBuilder(); final SaslServer saslServer = context.connection.getSaslServer(); final byte[] challengeBytes = evaluateResponse(saslServer, context.saslResponse.getData().toByteArray()); if (saslServer.isComplete()) { challenge.setStatus(SaslStatus.SASL_SUCCESS); if (challengeBytes != null) { challenge.setData(ByteString.copyFrom(challengeBytes)); } handleSuccess(context, challenge, saslServer); } else { challenge.setStatus(SaslStatus.SASL_IN_PROGRESS).setData(ByteString.copyFrom(challengeBytes)); context.sender.send(new Response(context.saslResponseType, challenge.build())); } } } // only when client succeeds first private static class SaslSuccessProcessor implements SaslResponseProcessor { @Override public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context) throws Exception { // at this point, #isComplete must be false; so try once, fail otherwise final SaslServer saslServer = context.connection.getSaslServer(); evaluateResponse(saslServer, context.saslResponse.getData().toByteArray()); // discard challenge if (saslServer.isComplete()) { final SaslMessage.Builder challenge = 
SaslMessage.newBuilder(); challenge.setStatus(SaslStatus.SASL_SUCCESS); handleSuccess(context, challenge, saslServer); } else { logger.info("Failed to authenticate client from {}", context.remoteAddress); throw new SaslException( "Client allegedly succeeded authentication, but server did not. Suspicious?"); } } } private static class SaslFailedProcessor implements SaslResponseProcessor { @Override public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context) throws Exception { logger.info("Client from {} failed authentication graciously, and does not want to continue.", context.remoteAddress); throw new SaslException("Client graciously failed authentication"); } } private static byte[] evaluateResponse(final SaslServer saslServer, final byte[] responseBytes) throws SaslException { try { return UserGroupInformation.getLoginUser().doAs(new PrivilegedExceptionAction<byte[]>() { @Override public byte[] run() throws Exception { return saslServer.evaluateResponse(responseBytes); } }); } catch (final UndeclaredThrowableException e) { throw new SaslException(String.format("Unexpected failure trying to authenticate using %s", saslServer.getMechanismName()), e.getCause()); } catch (final IOException | InterruptedException e) { if (e instanceof SaslException) { throw (SaslException) e; } else { throw new SaslException(String.format("Unexpected failure trying to authenticate using %s", saslServer.getMechanismName()), e); } } } private static <S extends ServerConnection<S>, T extends EnumLite> void handleSuccess( final SaslResponseContext<S, T> context, final SaslMessage.Builder challenge, final SaslServer saslServer) throws IOException { context.connection.changeHandlerTo(context.requestHandler); context.connection.finalizeSaslSession(); context.sender.send(new Response(context.saslResponseType, challenge.build())); // setup security layers here.. 
if (logger.isTraceEnabled()) { logger.trace("Authenticated {} successfully using {} from {}", saslServer.getAuthorizationID(), saslServer.getMechanismName(), context.remoteAddress); } } private static final SaslMessage SASL_FAILED_MESSAGE = SaslMessage.newBuilder() .setStatus(SaslStatus.SASL_FAILED).build(); private static <T extends EnumLite> void handleAuthFailure(final String remoteAddress, final ResponseSender sender, final Exception e, final T saslResponseType) throws RpcException { logger.debug("Authentication failed from client {} due to {}", remoteAddress, e); // inform the client that authentication failed, and no more sender.send(new Response(saslResponseType, SASL_FAILED_MESSAGE)); // drop connection throw new RpcException(e); } }