org.apache.drill.exec.rpc.security.ServerAuthenticationHandler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.drill.exec.rpc.security.ServerAuthenticationHandler.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.rpc.security;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.protobuf.ByteString;
import com.google.protobuf.Internal.EnumLite;
import com.google.protobuf.InvalidProtocolBufferException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import org.apache.drill.exec.proto.UserBitShared.SaslMessage;
import org.apache.drill.exec.proto.UserBitShared.SaslStatus;
import org.apache.drill.exec.rpc.RequestHandler;
import org.apache.drill.exec.rpc.Response;
import org.apache.drill.exec.rpc.ResponseSender;
import org.apache.drill.exec.rpc.RpcException;
import org.apache.drill.exec.rpc.ServerConnection;
import org.apache.hadoop.security.UserGroupInformation;

import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction;
import java.util.EnumMap;
import java.util.Map;

import static com.google.common.base.Preconditions.checkNotNull;

/**
 * Handles SASL exchange, on the server-side.
 * <p>
 * Before authentication completes, only requests whose RPC type equals the configured SASL request type are
 * accepted; every other request results in an {@link RpcException}. Each SASL message is dispatched on its
 * {@link SaslStatus} to a {@link SaslResponseProcessor}. When the {@link SaslServer} reports completion, the
 * connection's handler is switched to the wrapped request handler and a {@code SASL_SUCCESS} message is sent.
 * On any failure a {@code SASL_FAILED} message is sent and an {@link RpcException} is thrown.
 *
 * @param <S> Server connection type
 * @param <T> RPC type
 */
public class ServerAuthenticationHandler<S extends ServerConnection<S>, T extends EnumLite>
        implements RequestHandler<S> {
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory
            .getLogger(ServerAuthenticationHandler.class);

    /** Dispatch table: status carried by a client SASL message -> processor for that status. */
    private static final ImmutableMap<SaslStatus, SaslResponseProcessor> RESPONSE_PROCESSORS;

    static {
        final Map<SaslStatus, SaslResponseProcessor> map = new EnumMap<>(SaslStatus.class);
        map.put(SaslStatus.SASL_START, new SaslStartProcessor());
        map.put(SaslStatus.SASL_IN_PROGRESS, new SaslInProgressProcessor());
        map.put(SaslStatus.SASL_SUCCESS, new SaslSuccessProcessor());
        map.put(SaslStatus.SASL_FAILED, new SaslFailedProcessor());
        RESPONSE_PROCESSORS = Maps.immutableEnumMap(map);
    }

    /** Handler installed on the connection once authentication succeeds. */
    private final RequestHandler<S> requestHandler;
    /** RPC type value identifying SASL messages from the client; any other type is rejected. */
    private final int saslRequestTypeValue;
    /** RPC type used for the SASL messages this handler sends back to the client. */
    private final T saslResponseType;

    /**
     * @param requestHandler       handler the connection is switched to after successful authentication
     * @param saslRequestTypeValue RPC type value of client SASL messages
     * @param saslResponseType     RPC type used when responding to client SASL messages
     */
    public ServerAuthenticationHandler(final RequestHandler<S> requestHandler, final int saslRequestTypeValue,
            final T saslResponseType) {
        this.requestHandler = requestHandler;
        this.saslRequestTypeValue = saslRequestTypeValue;
        this.saslResponseType = saslResponseType;
    }

    /**
     * Handles one inbound request on a connection that has not yet authenticated.
     *
     * @throws RpcException if the request is not a SASL message, or if the SASL exchange fails (in which case
     *                      a {@code SASL_FAILED} message is sent to the client before throwing)
     */
    @Override
    public void handle(S connection, int rpcType, ByteBuf pBody, ByteBuf dBody, ResponseSender sender)
            throws RpcException {
        final String remoteAddress = connection.getRemoteAddress().toString();

        // exchange involves server "challenges" and client "responses" (initiated by client)
        if (saslRequestTypeValue == rpcType) {
            final SaslMessage saslResponse;
            try {
                saslResponse = SaslMessage.PARSER.parseFrom(new ByteBufInputStream(pBody));
            } catch (final InvalidProtocolBufferException e) {
                handleAuthFailure(remoteAddress, sender, e, saslResponseType);
                return;
            }

            logger.trace("Received SASL message {} from {}", saslResponse.getStatus(), remoteAddress);
            final SaslResponseProcessor processor = RESPONSE_PROCESSORS.get(saslResponse.getStatus());
            if (processor == null) {
                // status with no registered processor
                logger.info("Unknown message type from client from {}. Will stop authentication.", remoteAddress);
                handleAuthFailure(remoteAddress, sender, new SaslException("Received unexpected message"),
                        saslResponseType);
                return;
            }

            final SaslResponseContext<S, T> context = new SaslResponseContext<>(saslResponse, connection,
                    remoteAddress, sender, requestHandler, saslResponseType);
            try {
                processor.process(context);
            } catch (final Exception e) {
                // any processor failure (including unexpected runtime exceptions) fails authentication
                handleAuthFailure(remoteAddress, sender, e, saslResponseType);
            }
        } else {

            // this handler only handles messages of SASL_MESSAGE_VALUE type

            // the response type for this request type is likely known from UserRpcConfig,
            // but the client should not be making any requests before authenticating.
            // drop connection
            throw new RpcException(String.format(
                    "Request of type %d is not allowed without authentication. "
                            + "Client on %s must authenticate before making requests. Connection dropped.",
                    rpcType, remoteAddress));
        }
    }

    /** Immutable bundle of everything a {@link SaslResponseProcessor} needs to handle one client message. */
    private static class SaslResponseContext<S extends ServerConnection<S>, T extends EnumLite> {

        final SaslMessage saslResponse;
        final S connection;
        final String remoteAddress;
        final ResponseSender sender;
        final RequestHandler<S> requestHandler;
        final T saslResponseType;

        SaslResponseContext(SaslMessage saslResponse, S connection, String remoteAddress, ResponseSender sender,
                RequestHandler<S> requestHandler, T saslResponseType) {
            this.saslResponse = checkNotNull(saslResponse);
            this.connection = checkNotNull(connection);
            this.remoteAddress = checkNotNull(remoteAddress);
            this.sender = checkNotNull(sender);
            this.requestHandler = checkNotNull(requestHandler);
            this.saslResponseType = checkNotNull(saslResponseType);
        }
    }

    private interface SaslResponseProcessor {

        /**
         * Process response from client, and if there are no exceptions, send response using
         * {@link SaslResponseContext#sender}. Otherwise, throw the exception.
         *
         * @param context response context
         */
        <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context)
                throws Exception;

    }

    /** Handles {@code SASL_START}: initializes the connection's SASL server, then evaluates the first response. */
    private static class SaslStartProcessor implements SaslResponseProcessor {

        @Override
        public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context)
                throws Exception {
            context.connection.initSaslServer(context.saslResponse.getMechanism());

            // assume #evaluateResponse must be called at least once
            RESPONSE_PROCESSORS.get(SaslStatus.SASL_IN_PROGRESS).process(context);
        }
    }

    /**
     * Handles {@code SASL_IN_PROGRESS}: evaluates the client's response and either finishes authentication
     * ({@code SASL_SUCCESS}, with optional final challenge data) or sends the next challenge.
     */
    private static class SaslInProgressProcessor implements SaslResponseProcessor {

        @Override
        public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context)
                throws Exception {
            final SaslMessage.Builder challenge = SaslMessage.newBuilder();
            final SaslServer saslServer = context.connection.getSaslServer();

            final byte[] challengeBytes = evaluateResponse(saslServer,
                    context.saslResponse.getData().toByteArray());

            if (saslServer.isComplete()) {
                challenge.setStatus(SaslStatus.SASL_SUCCESS);
                // on completion the final challenge is optional
                if (challengeBytes != null) {
                    challenge.setData(ByteString.copyFrom(challengeBytes));
                }

                handleSuccess(context, challenge, saslServer);
            } else {
                if (challengeBytes == null) {
                    // SaslServer#evaluateResponse contract: the challenge is non-null whenever authentication
                    // must continue. Fail with a descriptive error instead of an NPE from ByteString#copyFrom.
                    throw new SaslException(String.format(
                            "SASL mechanism %s returned no challenge although authentication is not complete",
                            saslServer.getMechanismName()));
                }
                challenge.setStatus(SaslStatus.SASL_IN_PROGRESS).setData(ByteString.copyFrom(challengeBytes));
                context.sender.send(new Response(context.saslResponseType, challenge.build()));
            }
        }
    }

    /**
     * Handles {@code SASL_SUCCESS} sent by the client, i.e. the client finished before the server did; the server
     * evaluates the final response once and must then be complete, otherwise authentication fails.
     */
    private static class SaslSuccessProcessor implements SaslResponseProcessor {

        @Override
        public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context)
                throws Exception {
            // at this point, #isComplete must be false; so try once, fail otherwise
            final SaslServer saslServer = context.connection.getSaslServer();

            evaluateResponse(saslServer, context.saslResponse.getData().toByteArray()); // discard challenge

            if (saslServer.isComplete()) {
                final SaslMessage.Builder challenge = SaslMessage.newBuilder();
                challenge.setStatus(SaslStatus.SASL_SUCCESS);

                handleSuccess(context, challenge, saslServer);
            } else {
                logger.info("Failed to authenticate client from {}", context.remoteAddress);
                throw new SaslException(
                        "Client allegedly succeeded authentication, but server did not. Suspicious?");
            }
        }
    }

    /** Handles {@code SASL_FAILED} sent by the client: always fails authentication. */
    private static class SaslFailedProcessor implements SaslResponseProcessor {

        @Override
        public <S extends ServerConnection<S>, T extends EnumLite> void process(SaslResponseContext<S, T> context)
                throws Exception {
            logger.info("Client from {} failed authentication graciously, and does not want to continue.",
                    context.remoteAddress);
            throw new SaslException("Client graciously failed authentication");
        }
    }

    /**
     * Evaluates a client response with the given SASL server, running as the Hadoop login user.
     *
     * @param saslServer    the connection's SASL server
     * @param responseBytes raw response data sent by the client
     * @return the (possibly null) challenge returned by {@link SaslServer#evaluateResponse(byte[])}
     * @throws SaslException if evaluation fails; non-SASL failures are wrapped with the mechanism name
     */
    private static byte[] evaluateResponse(final SaslServer saslServer, final byte[] responseBytes)
            throws SaslException {
        try {
            return UserGroupInformation.getLoginUser().doAs(new PrivilegedExceptionAction<byte[]>() {
                @Override
                public byte[] run() throws Exception {
                    return saslServer.evaluateResponse(responseBytes);
                }
            });
        } catch (final UndeclaredThrowableException e) {
            throw unexpectedFailure(saslServer, e.getCause());
        } catch (final SaslException e) {
            // already the right type; propagate unchanged
            throw e;
        } catch (final InterruptedException e) {
            // catching InterruptedException clears the interrupt flag; restore it for callers up the stack
            Thread.currentThread().interrupt();
            throw unexpectedFailure(saslServer, e);
        } catch (final IOException e) {
            throw unexpectedFailure(saslServer, e);
        }
    }

    /** Builds the exception used for non-SASL failures while evaluating a client response. */
    private static SaslException unexpectedFailure(final SaslServer saslServer, final Throwable cause) {
        return new SaslException(String.format("Unexpected failure trying to authenticate using %s",
                saslServer.getMechanismName()), cause);
    }

    /**
     * Completes a successful exchange: switches the connection to the real request handler, finalizes the SASL
     * session, and sends the given {@code SASL_SUCCESS} message to the client.
     */
    private static <S extends ServerConnection<S>, T extends EnumLite> void handleSuccess(
            final SaslResponseContext<S, T> context, final SaslMessage.Builder challenge,
            final SaslServer saslServer) throws IOException {
        context.connection.changeHandlerTo(context.requestHandler);
        context.connection.finalizeSaslSession();
        context.sender.send(new Response(context.saslResponseType, challenge.build()));

        // setup security layers here..

        if (logger.isTraceEnabled()) {
            logger.trace("Authenticated {} successfully using {} from {}", saslServer.getAuthorizationID(),
                    saslServer.getMechanismName(), context.remoteAddress);
        }
    }

    /** Shared, immutable {@code SASL_FAILED} message sent to clients whose authentication fails. */
    private static final SaslMessage SASL_FAILED_MESSAGE = SaslMessage.newBuilder()
            .setStatus(SaslStatus.SASL_FAILED).build();

    /**
     * Reports an authentication failure to the client and aborts the request.
     *
     * @throws RpcException always, wrapping {@code e}
     */
    private static <T extends EnumLite> void handleAuthFailure(final String remoteAddress,
            final ResponseSender sender, final Exception e, final T saslResponseType) throws RpcException {
        logger.debug("Authentication failed from client {} due to {}", remoteAddress, e);

        // inform the client that authentication failed, and no more
        sender.send(new Response(saslResponseType, SASL_FAILED_MESSAGE));

        // drop connection
        throw new RpcException(e);
    }
}