org.apache.hadoop.ipc.TestSaslRPC.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.ipc.TestSaslRPC.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.ipc;

import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION;
import static org.junit.Assert.*;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.util.Set;

import junit.framework.Assert;

import org.apache.commons.logging.*;
import org.apache.commons.logging.impl.Log4JLogger;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier;
import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenSelector;
import org.apache.hadoop.security.SaslInputStream;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.TestUserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;

import org.apache.log4j.Level;
import org.junit.BeforeClass;
import org.junit.Test;

/** Unit tests for using Sasl over RPC. */
public class TestSaslRPC {
    /** Bind address handed to {@code RPC.getServer} by every test in this class. */
    private static final String ADDRESS = "localhost";

    /** Logger for this test class. */
    public static final Log LOG = LogFactory.getLog(TestSaslRPC.class);

    /** Message of the {@code InvalidToken} thrown by {@link BadTokenSecretManager}. */
    static final String ERROR_MESSAGE = "Token is invalid";
    /** Configuration key named by the {@code @KerberosInfo} annotation on {@link TestSaslProtocol}. */
    static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
    /** Configuration key under which {@code testKerberosRpc} stores the keytab path. */
    static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
    /** First of two distinct principal values used to tell connections apart. */
    static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
    /** Second of two distinct principal values used to tell connections apart. */
    static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";

    /**
     * Base configuration shared by the tests. Authentication is set to
     * "kerberos" and the configuration is installed into
     * {@link UserGroupInformation} when this class is initialized.
     */
    private static Configuration conf = new Configuration();
    static {
        conf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos");
        UserGroupInformation.setConfiguration(conf);
    }

    // Raise the IPC and security loggers to Level.ALL for the duration of the tests.
    static {
        logEverything(Client.LOG);
        logEverything(Server.LOG);
        logEverything(SaslRpcClient.LOG);
        logEverything(SaslRpcServer.LOG);
        logEverything(SaslInputStream.LOG);
        logEverything(SecurityUtil.LOG);
    }

    /**
     * Sets the log4j logger behind {@code log} to {@link Level#ALL}.
     *
     * @param log a commons-logging handle; it is cast to {@link Log4JLogger}
     */
    private static void logEverything(Log log) {
        ((Log4JLogger) log).getLogger().setLevel(Level.ALL);
    }

    /** Reflective handle on {@code SecurityUtil.setTokenServiceUseIp(boolean)}; set by {@link #exposeUseIp()}. */
    static Method setUseIpMethod;

    /**
     * Looks up {@code SecurityUtil.setTokenServiceUseIp(boolean)} via reflection
     * and makes it accessible so the tests can invoke it.
     *
     * @throws Exception if the method cannot be found or made accessible
     */
    @BeforeClass
    public static void exposeUseIp() throws Exception {
        final Class<?>[] signature = { Boolean.TYPE };
        setUseIpMethod = SecurityUtil.class.getDeclaredMethod("setTokenServiceUseIp", signature);
        setUseIpMethod.setAccessible(true);
    }

    /**
     * Invokes the reflectively obtained {@code SecurityUtil.setTokenServiceUseIp}
     * with the given flag.
     *
     * @param flag value forwarded to {@code setTokenServiceUseIp}
     * @throws Exception if the reflective call fails
     */
    private void setTokenServiceUseIp(boolean flag) throws Exception {
        // The target is invoked without an instance, so the receiver argument is
        // ignored; pass null rather than an unrelated Class object.
        setUseIpMethod.invoke(null, flag);
    }

    /**
     * Token identifier used by these tests. It carries a token id and an
     * optional real user, and serializes exactly those two {@link Text} fields.
     */
    public static class TestTokenIdentifier extends AbstractDelegationTokenIdentifier {
        static final Text KIND_NAME = new Text("test.token");

        private final Text tokenid;
        private final Text realUser;

        /** Creates an identifier with an empty token id and an empty real user. */
        public TestTokenIdentifier() {
            this(new Text(), new Text());
        }

        /**
         * Creates an identifier with the given token id and an empty real user.
         *
         * @param tokenid token id; {@code null} is replaced by an empty {@link Text}
         */
        public TestTokenIdentifier(Text tokenid) {
            this(tokenid, new Text());
        }

        /**
         * Creates an identifier from a token id and a real user.
         *
         * @param tokenid  token id; {@code null} is replaced by an empty {@link Text}
         * @param realUser real user; {@code null} is replaced by an empty {@link Text}
         */
        public TestTokenIdentifier(Text tokenid, Text realUser) {
            this.tokenid = (tokenid != null) ? tokenid : new Text();
            this.realUser = (realUser != null) ? realUser : new Text();
        }

        /** @return the fixed kind {@link #KIND_NAME} */
        @Override
        public Text getKind() {
            return KIND_NAME;
        }

        /**
         * Builds the UGI for this identifier: a remote user named after the token
         * id when no real user is set, otherwise a proxy user named after the
         * token id on top of a remote user named after the real user.
         */
        @Override
        public UserGroupInformation getUser() {
            final String realName = realUser.toString();
            if (realName.isEmpty()) {
                return UserGroupInformation.createRemoteUser(tokenid.toString());
            }
            final UserGroupInformation realUgi = UserGroupInformation.createRemoteUser(realName);
            return UserGroupInformation.createProxyUser(tokenid.toString(), realUgi);
        }

        /** Reads the token id followed by the real user. */
        @Override
        public void readFields(DataInput in) throws IOException {
            tokenid.readFields(in);
            realUser.readFields(in);
        }

        /** Writes the token id followed by the real user. */
        @Override
        public void write(DataOutput out) throws IOException {
            tokenid.write(out);
            realUser.write(out);
        }
    }

    public static class TestTokenSecretManager extends SecretManager<TestTokenIdentifier> {
        public byte[] createPassword(TestTokenIdentifier id) {
            return id.getBytes();
        }

        public byte[] retrievePassword(TestTokenIdentifier id) throws InvalidToken {
            return id.getBytes();
        }

        public TestTokenIdentifier createIdentifier() {
            return new TestTokenIdentifier();
        }
    }

    public static class BadTokenSecretManager extends TestTokenSecretManager {

        public byte[] retrievePassword(TestTokenIdentifier id) throws InvalidToken {
            throw new InvalidToken(ERROR_MESSAGE);
        }
    }

    /**
     * Token selector for {@link TestTokenIdentifier} tokens; it is constructed
     * with the kind {@link TestTokenIdentifier#KIND_NAME}. Referenced by the
     * {@code @TokenInfo} annotation on {@link TestSaslProtocol}.
     */
    public static class TestTokenSelector extends AbstractDelegationTokenSelector<TestTokenIdentifier> {
        /** Passes {@link TestTokenIdentifier#KIND_NAME} to the superclass constructor. */
        TestTokenSelector() {
            super(TestTokenIdentifier.KIND_NAME);
        }
    }

    /**
     * RPC protocol used by the SASL tests. It extends the basic test protocol
     * with a call reporting an authentication method, and is annotated with the
     * server principal key and the token selector used by these tests.
     */
    @KerberosInfo(serverPrincipal = SERVER_PRINCIPAL_KEY)
    @TokenInfo(TestTokenSelector.class)
    public interface TestSaslProtocol extends TestRPC.TestProtocol {
        /**
         * @return the authentication method reported by the implementation
         * @throws IOException if the call fails
         */
        AuthenticationMethod getAuthMethod() throws IOException;
    }

    public static class TestSaslImpl extends TestRPC.TestImpl implements TestSaslProtocol {
        public AuthenticationMethod getAuthMethod() throws IOException {
            return UserGroupInformation.getCurrentUser().getAuthenticationMethod();
        }
    }

    /** Runs the digest RPC round trip against a server created with a {@link TestTokenSecretManager}. */
    @Test
    public void testDigestRpc() throws Exception {
        final TestTokenSecretManager secretManager = new TestTokenSecretManager();
        final Server rpcServer =
                RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, secretManager);
        doDigestRpc(rpcServer, secretManager);
    }

    /**
     * Runs the digest RPC round trip against a server that was created without
     * a secret manager and then had security disabled.
     */
    @Test
    public void testSecureToInsecureRpc() throws Exception {
        final Server insecureServer =
                RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, null);
        insecureServer.disableSecurity();
        final TestTokenSecretManager secretManager = new TestTokenSecretManager();
        doDigestRpc(insecureServer, secretManager);
    }

    /**
     * Verifies that when the server's secret manager rejects the token, the
     * client receives a {@link RemoteException} whose message is
     * {@link #ERROR_MESSAGE} and which unwraps to an {@link InvalidToken}.
     */
    @Test
    public void testErrorMessage() throws Exception {
        BadTokenSecretManager sm = new BadTokenSecretManager();
        final Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);

        try {
            doDigestRpc(server, sm);
            // Reaching this point means the bad secret manager did not cause a failure.
            fail("Expected a RemoteException carrying: " + ERROR_MESSAGE);
        } catch (RemoteException e) {
            LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
            // assertEquals reports the actual message when the check fails.
            assertEquals(ERROR_MESSAGE, e.getLocalizedMessage());
            assertTrue(e.unwrapRemoteException() instanceof InvalidToken);
        }
    }

    /**
     * Starts {@code server}, adds a {@link TestTokenIdentifier} token for the
     * current user whose service is set from the server's connect address, and
     * pings the server through a {@link TestSaslProtocol} proxy. The server and
     * the proxy (if one was created) are always stopped before returning.
     *
     * @param server the RPC server to start and call
     * @param sm     secret manager used to build the client-side token
     * @throws Exception if token setup or the RPC call fails
     */
    private void doDigestRpc(Server server, TestTokenSecretManager sm) throws Exception {
        server.start();

        TestSaslProtocol proxy = null;
        try {
            // Token setup sits inside the try so that a failure here still
            // stops the already-running server.
            final UserGroupInformation current = UserGroupInformation.getCurrentUser();
            final InetSocketAddress addr = NetUtils.getConnectAddress(server);
            TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current.getUserName()));
            Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm);
            SecurityUtil.setTokenService(token, addr);
            LOG.info("Service IP address for token is " + token.getService());
            current.addToken(token);

            proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, conf);
            proxy.ping();
        } finally {
            server.stop();
            if (proxy != null) {
                RPC.stopProxy(proxy);
            }
        }
    }

    /**
     * Checks the server principal exposed by a {@link ConnectionId}: it equals
     * the configured principal under the class-level configuration, and is
     * {@code null} once authentication is switched to "simple". The class-level
     * configuration is re-installed into {@link UserGroupInformation} afterwards.
     */
    @Test
    public void testGetRemotePrincipal() throws Exception {
        try {
            final Configuration localConf = new Configuration(conf);
            localConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
            ConnectionId connectionId = ConnectionId.getConnectionId(
                    new InetSocketAddress(0), TestSaslProtocol.class, null, localConf);
            assertEquals(SERVER_PRINCIPAL_1, connectionId.getServerPrincipal());

            // the second half of the test requires security to be switched off
            localConf.set(HADOOP_SECURITY_AUTHENTICATION, "simple");
            UserGroupInformation.setConfiguration(localConf);
            connectionId = ConnectionId.getConnectionId(
                    new InetSocketAddress(0), TestSaslProtocol.class, null, localConf);
            final String principalWithoutSecurity = connectionId.getServerPrincipal();
            assertEquals("serverPrincipal should be null when security is turned off", null,
                    principalWithoutSecurity);
        } finally {
            // put the secure configuration back for the remaining tests
            UserGroupInformation.setConfiguration(conf);
        }
    }

    /**
     * Verifies the connection ids reported by the RPC client: two proxies
     * created from the same configuration leave a single connection id, while a
     * proxy created after changing {@link #SERVER_PRINCIPAL_KEY} adds a second
     * one with a different server principal. The server and any proxies that
     * were created are always stopped before returning.
     */
    @Test
    public void testPerConnectionConf() throws Exception {
        TestTokenSecretManager sm = new TestTokenSecretManager();
        final Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
        server.start();

        TestSaslProtocol proxy1 = null;
        TestSaslProtocol proxy2 = null;
        TestSaslProtocol proxy3 = null;
        try {
            // Setup sits inside the try so that a failure here still stops the
            // already-running server.
            final UserGroupInformation current = UserGroupInformation.getCurrentUser();
            final InetSocketAddress addr = NetUtils.getConnectAddress(server);
            TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current.getUserName()));
            Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm);
            SecurityUtil.setTokenService(token, addr);
            LOG.info("Service IP address for token is " + token.getService());
            current.addToken(token);

            Configuration newConf = new Configuration(conf);
            newConf.set("hadoop.rpc.socket.factory.class.default", "");
            newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);

            proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr,
                    newConf);
            Client client = RPC.getClient(conf);
            Set<ConnectionId> conns = client.getConnectionIds();
            assertEquals("number of connections in cache is wrong", 1, conns.size());
            // same conf, connection should be re-used
            proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr,
                    newConf);
            assertEquals("number of connections in cache is wrong", 1, conns.size());
            // different conf, new connection should be set up
            newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
            proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr,
                    newConf);
            ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
            assertEquals("number of connections in cache is wrong", 2, connsArray.length);
            String p1 = connsArray[0].getServerPrincipal();
            String p2 = connsArray[1].getServerPrincipal();
            assertFalse("should have different principals", p1.equals(p2));
            assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1) || p1.equals(SERVER_PRINCIPAL_2));
            assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1) || p2.equals(SERVER_PRINCIPAL_2));
        } finally {
            server.stop();
            // Only stop proxies that were actually created, matching the other
            // tests in this class.
            if (proxy1 != null) {
                RPC.stopProxy(proxy1);
            }
            if (proxy2 != null) {
                RPC.stopProxy(proxy2);
            }
            if (proxy3 != null) {
                RPC.stopProxy(proxy3);
            }
        }
    }

    /**
     * Logs in from the given keytab and principal, starts a server using that
     * configuration and pings it through a {@link TestSaslProtocol} proxy.
     * Called from {@link #main(String[])}; it is not annotated as a JUnit test.
     *
     * @param principal value stored under {@link #SERVER_PRINCIPAL_KEY}
     * @param keytab    value stored under {@link #SERVER_KEYTAB_KEY}
     * @throws Exception if login, server setup or the RPC call fails
     */
    static void testKerberosRpc(String principal, String keytab) throws Exception {
        final Configuration kerberosConf = new Configuration(conf);
        kerberosConf.set(SERVER_PRINCIPAL_KEY, principal);
        kerberosConf.set(SERVER_KEYTAB_KEY, keytab);
        SecurityUtil.login(kerberosConf, SERVER_KEYTAB_KEY, SERVER_PRINCIPAL_KEY);
        TestUserGroupInformation.verifyLoginMetrics(1, 0);
        final UserGroupInformation loggedIn = UserGroupInformation.getCurrentUser();
        System.out.println("UGI: " + loggedIn);

        final Server rpcServer =
                RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, kerberosConf, null);
        rpcServer.start();
        final InetSocketAddress serverAddr = NetUtils.getConnectAddress(rpcServer);

        TestSaslProtocol client = null;
        try {
            client = (TestSaslProtocol) RPC.getProxy(
                    TestSaslProtocol.class, TestSaslProtocol.versionID, serverAddr, kerberosConf);
            client.ping();
        } finally {
            rpcServer.stop();
            if (client != null) {
                RPC.stopProxy(client);
            }
        }
        System.out.println("Test is successful.");
    }

    /**
     * Checks token-based authentication for one setting of the
     * "token service uses IP" flag. After setting the flag, it verifies that
     * the token's service host and {@code host:port} string match the expected
     * form (IP address when {@code useIp} is true, {@link #ADDRESS} otherwise)
     * and that an RPC made as the current user reports
     * {@link AuthenticationMethod#TOKEN}. The server is always stopped before
     * returning.
     *
     * @param useIp value passed to {@code setTokenServiceUseIp}
     * @throws Exception if setup, an assertion or the RPC call fails
     */
    public void testDigestAuthMethod(boolean useIp) throws Exception {
        setTokenServiceUseIp(useIp);

        TestTokenSecretManager sm = new TestTokenSecretManager();
        final Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
        server.start();
        try {
            final UserGroupInformation current = UserGroupInformation.getCurrentUser();
            final InetSocketAddress addr = NetUtils.getConnectAddress(server);
            TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current.getUserName()));
            Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm);
            SecurityUtil.setTokenService(token, addr);
            LOG.info("Service IP address for token is " + token.getService());

            InetSocketAddress tokenAddr = SecurityUtil.getTokenServiceAddr(token);
            String expectedHost, gotHost;
            if (useIp) {
                expectedHost = addr.getAddress().getHostAddress();
                gotHost = tokenAddr.getAddress().getHostAddress();
            } else {
                gotHost = tokenAddr.getHostName();
                expectedHost = ADDRESS;
            }
            Assert.assertEquals(expectedHost, gotHost);
            Assert.assertEquals(expectedHost + ":" + addr.getPort(), token.getService().toString());

            current.addToken(token);

            current.doAs(new PrivilegedExceptionAction<Object>() {
                public Object run() throws IOException {
                    TestSaslProtocol proxy = null;
                    try {
                        proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID,
                                addr, conf);
                        Assert.assertEquals(AuthenticationMethod.TOKEN, proxy.getAuthMethod());
                    } finally {
                        if (proxy != null) {
                            RPC.stopProxy(proxy);
                        }
                    }
                    return null;
                }
            });
        } finally {
            // Stop the server even when an assertion or the RPC call fails, so a
            // failing run does not leave it running for later tests.
            server.stop();
        }
    }

    /** Runs {@link #testDigestAuthMethod(boolean)} with the use-IP flag set to {@code true}. */
    @Test
    public void testDigestAuthMethodIpBasedToken() throws Exception {
        final boolean useIp = true;
        testDigestAuthMethod(useIp);
    }

    /** Runs {@link #testDigestAuthMethod(boolean)} with the use-IP flag set to {@code false}. */
    @Test
    public void testDigestAuthMethodHostBasedToken() throws Exception {
        final boolean useIp = false;
        testDigestAuthMethod(useIp);
    }

    /**
     * Command-line entry point for the Kerberos RPC check. Expects exactly two
     * arguments; otherwise prints a usage line to standard error and exits with
     * status -1.
     *
     * @param args {@code <serverPrincipal> <keytabFile>}
     * @throws Exception if {@link #testKerberosRpc(String, String)} fails
     */
    public static void main(String[] args) throws Exception {
        System.out.println("Testing Kerberos authentication over RPC");
        final boolean hasBothArguments = (args.length == 2);
        if (!hasBothArguments) {
            System.err.println("Usage: java <options> "
                    + "org.apache.hadoop.ipc.TestSaslRPC  <serverPrincipal> <keytabFile>");
            System.exit(-1);
        }
        testKerberosRpc(args[0], args[1]);
    }

}