org.apache.thrift.transport.TSaslServerTransport.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.thrift.transport.TSaslServerTransport.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.thrift.transport;

import java.lang.ref.WeakReference;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

import org.apache.thrift.transport.sasl.NegotiationStatus;
import org.apache.thrift.transport.sasl.TSaslServerDefinition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wraps another Thrift <code>TTransport</code>, but performs SASL server
 * negotiation on the call to <code>open()</code>. This class will wrap ensuing
 * communication over it, if a SASL QOP is negotiated with the other party.
 *
 * <p>One or more SASL mechanisms can be registered through
 * {@link #addServerDefinition}; the mechanism actually used is the one named
 * by the client in its start message.
 */
public class TSaslServerTransport extends TSaslTransport {

    private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);

    /**
     * Mapping from SASL mechanism name -> all the parameters required to
     * instantiate a SASL server.
     */
    private final Map<String, TSaslServerDefinition> serverDefinitionMap =
            new HashMap<String, TSaslServerDefinition>();

    /**
     * Uses the given underlying transport. Assumes that addServerDefinition is
     * called later.
     *
     * @param transport
     *          Transport underlying this one.
     */
    public TSaslServerTransport(TTransport transport) {
        super(transport);
    }

    /**
     * Creates a transport that supports a single SASL mechanism, described by
     * the given SASL-specific parameters. The parameters are stored and later
     * handed to <code>Sasl.createSaslServer</code>; see the Java documentation
     * of that method for their details.
     *
     * @param mechanism
     *          Name of the SASL mechanism; also the key under which the
     *          definition is registered.
     * @param protocol
     *          Protocol argument for <code>Sasl.createSaslServer</code>.
     * @param serverName
     *          Server name argument for <code>Sasl.createSaslServer</code>.
     * @param props
     *          Properties argument for <code>Sasl.createSaslServer</code>.
     * @param cbh
     *          Callback handler argument for <code>Sasl.createSaslServer</code>.
     * @param transport
     *          The underlying Thrift transport.
     */
    public TSaslServerTransport(String mechanism, String protocol, String serverName, Map<String, String> props,
            CallbackHandler cbh, TTransport transport) {
        super(transport);
        addServerDefinition(mechanism, protocol, serverName, props, cbh);
    }

    /**
     * Creates a transport whose supported mechanisms are a copy of the given
     * definitions. Used by {@link Factory} so that later changes to the
     * factory's map do not affect transports that already exist.
     *
     * @param serverDefinitionMap
     *          Mechanism name -> server definition entries to copy.
     * @param transport
     *          The underlying Thrift transport.
     */
    private TSaslServerTransport(Map<String, TSaslServerDefinition> serverDefinitionMap, TTransport transport) {
        super(transport);
        this.serverDefinitionMap.putAll(serverDefinitionMap);
    }

    /**
     * Add a supported server definition to this transport. See the Java
     * documentation for <code>Sasl.createSaslServer</code> for the details of the
     * parameters. A definition registered earlier under the same mechanism name
     * is replaced.
     *
     * @param mechanism
     *          Name of the SASL mechanism; used as the lookup key.
     * @param protocol
     *          Protocol argument for <code>Sasl.createSaslServer</code>.
     * @param serverName
     *          Server name argument for <code>Sasl.createSaslServer</code>.
     * @param props
     *          Properties argument for <code>Sasl.createSaslServer</code>.
     * @param cbh
     *          Callback handler argument for <code>Sasl.createSaslServer</code>.
     */
    public void addServerDefinition(String mechanism, String protocol, String serverName, Map<String, String> props,
            CallbackHandler cbh) {
        serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, props, cbh));
    }

    /**
     * Identifies this transport as the server end of the SASL exchange.
     *
     * @return always {@link SaslRole#SERVER}
     */
    @Override
    protected SaslRole getRole() {
        return SaslRole.SERVER;
    }

    /**
     * Performs the server side of the initial portion of the Thrift SASL protocol.
     * Receives the initial response from the client, creates a SASL server using
     * the mechanism requested by the client (if this server supports it), and
     * installs it via <code>setSaslServer</code>.
     *
     * <p>If the first message does not carry the <code>START</code> status, an
     * <code>ERROR</code> status is sent back; if the requested mechanism has no
     * registered definition, a <code>BAD</code> status is sent back. In both
     * cases the exception produced by <code>sendAndThrowMessage</code> is thrown.
     *
     * @throws TTransportException
     *           declared by the overridden method; raised by the underlying
     *           message exchange
     * @throws SaslException
     *           if <code>Sasl.createSaslServer</code> fails
     */
    @Override
    protected void handleSaslStartMessage() throws TTransportException, SaslException {
        SaslResponse message = receiveSaslMessage();

        LOGGER.debug("Received start message with status {}", message.status);
        if (message.status != NegotiationStatus.START) {
            throw sendAndThrowMessage(NegotiationStatus.ERROR,
                    "Expecting START status, received " + message.status);
        }

        // The payload of the start message is the UTF-8 encoded mechanism name.
        String mechanismName = new String(message.payload, StandardCharsets.UTF_8);
        TSaslServerDefinition serverDefinition = serverDefinitionMap.get(mechanismName);
        LOGGER.debug("Received mechanism name '{}'", mechanismName);

        if (serverDefinition == null) {
            throw sendAndThrowMessage(NegotiationStatus.BAD, "Unsupported mechanism type " + mechanismName);
        }
        SaslServer saslServer = Sasl.createSaslServer(serverDefinition.mechanism, serverDefinition.protocol,
                serverDefinition.serverName, serverDefinition.props, serverDefinition.cbh);
        setSaslServer(saslServer);
    }

    /**
     * <code>TTransportFactory</code> to create
     * <code>TSaslServerTransports</code>. Ensures that a given
     * underlying <code>TTransport</code> instance receives the same
     * <code>TSaslServerTransport</code>. This is kind of an awful hack to work
     * around the fact that Thrift is designed assuming that
     * <code>TTransport</code> instances are stateless, and thus the existing
     * <code>TServers</code> use different <code>TTransport</code> instances for
     * input and output.
     */
    public static class Factory extends TTransportFactory {

        /**
         * This is the implementation of the awful hack described above.
         * <code>WeakHashMap</code> is used to ensure that we don't leak memory.
         * The values are wrapped in <code>WeakReference</code> as well, because a
         * <code>WeakHashMap</code> holds its values strongly (presumably the
         * wrapper keeps a reference to its underlying transport, which would
         * otherwise pin the key forever).
         */
        private static final Map<TTransport, WeakReference<TSaslServerTransport>> transportMap = Collections
                .synchronizedMap(new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>());

        /**
         * Mapping from SASL mechanism name -> all the parameters required to
         * instantiate a SASL server.
         */
        private final Map<String, TSaslServerDefinition> serverDefinitionMap =
                new HashMap<String, TSaslServerDefinition>();

        /**
         * Create a new Factory. Assumes that <code>addServerDefinition</code> will
         * be called later.
         */
        public Factory() {
            super();
        }

        /**
         * Create a new <code>Factory</code>, initially with the single server
         * definition given. You may still call <code>addServerDefinition</code>
         * later. See the Java documentation for <code>Sasl.createSaslServer</code>
         * for the details of the parameters.
         *
         * @param mechanism
         *          Name of the SASL mechanism; used as the lookup key.
         * @param protocol
         *          Protocol argument for <code>Sasl.createSaslServer</code>.
         * @param serverName
         *          Server name argument for <code>Sasl.createSaslServer</code>.
         * @param props
         *          Properties argument for <code>Sasl.createSaslServer</code>.
         * @param cbh
         *          Callback handler argument for <code>Sasl.createSaslServer</code>.
         */
        public Factory(String mechanism, String protocol, String serverName, Map<String, String> props,
                CallbackHandler cbh) {
            super();
            addServerDefinition(mechanism, protocol, serverName, props, cbh);
        }

        /**
         * Add a supported server definition to the transports created by this
         * factory. See the Java documentation for
         * <code>Sasl.createSaslServer</code> for the details of the parameters.
         * Transports created before this call keep the definitions they were
         * created with, since each transport copies the map at construction.
         *
         * @param mechanism
         *          Name of the SASL mechanism; used as the lookup key.
         * @param protocol
         *          Protocol argument for <code>Sasl.createSaslServer</code>.
         * @param serverName
         *          Server name argument for <code>Sasl.createSaslServer</code>.
         * @param props
         *          Properties argument for <code>Sasl.createSaslServer</code>.
         * @param cbh
         *          Callback handler argument for <code>Sasl.createSaslServer</code>.
         */
        public void addServerDefinition(String mechanism, String protocol, String serverName,
                Map<String, String> props, CallbackHandler cbh) {
            serverDefinitionMap.put(mechanism,
                    new TSaslServerDefinition(mechanism, protocol, serverName, props, cbh));
        }

        /**
         * Get a new <code>TSaslServerTransport</code> instance, or reuse the
         * existing one if a <code>TSaslServerTransport</code> has already been
         * created before using the given <code>TTransport</code> as an underlying
         * transport. This ensures that a given underlying transport instance
         * receives the same <code>TSaslServerTransport</code>.
         *
         * <p>A newly created transport is opened before it is cached and
         * returned.
         *
         * @param base
         *          The underlying transport to wrap.
         * @return the cached or newly created (and opened) wrapper; never
         *         <code>null</code>
         * @throws RuntimeException
         *           wrapping the <code>TTransportException</code> if opening a
         *           newly created transport fails
         */
        @Override
        public TTransport getTransport(TTransport base) {
            WeakReference<TSaslServerTransport> cached = transportMap.get(base);
            // Dereference exactly once and keep a strong reference from here on:
            // a weakly reachable referent may be cleared by the garbage collector
            // between any two calls to get().
            TSaslServerTransport transport = (cached == null) ? null : cached.get();
            if (transport != null) {
                LOGGER.debug("transport map does contain key {}", base);
                return transport;
            }

            LOGGER.debug("transport map does not contain key {}", base);
            transport = new TSaslServerTransport(serverDefinitionMap, base);
            try {
                transport.open();
            } catch (TTransportException e) {
                LOGGER.debug("failed to open server transport", e);
                throw new RuntimeException(e);
            }
            // No need for putIfAbsent().
            // Concurrent calls to getTransport() will pass in different TTransports.
            transportMap.put(base, new WeakReference<TSaslServerTransport>(transport));
            return transport;
        }
    }
}