Java tutorial
/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.ipc; import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION; import static org.junit.Assert.*; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.security.PrivilegedExceptionAction; import java.util.Set; import junit.framework.Assert; import org.apache.commons.logging.*; import org.apache.commons.logging.impl.Log4JLogger; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.Text; import org.apache.hadoop.ipc.Client.ConnectionId; import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.KerberosInfo; import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenInfo; import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier; import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenSelector; import org.apache.hadoop.security.SaslInputStream; import org.apache.hadoop.security.SaslRpcClient; import org.apache.hadoop.security.SaslRpcServer; import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.TestUserGroupInformation; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; import org.apache.log4j.Level; import org.junit.BeforeClass; import org.junit.Test; /** Unit tests for using Sasl over RPC. */ public class TestSaslRPC { private static final String ADDRESS = "localhost"; public static final Log LOG = LogFactory.getLog(TestSaslRPC.class); static final String ERROR_MESSAGE = "Token is invalid"; static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal"; static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab"; static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR"; static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR"; private static Configuration conf; static { conf = new Configuration(); conf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos"); UserGroupInformation.setConfiguration(conf); } static { ((Log4JLogger) Client.LOG).getLogger().setLevel(Level.ALL); ((Log4JLogger) Server.LOG).getLogger().setLevel(Level.ALL); ((Log4JLogger) SaslRpcClient.LOG).getLogger().setLevel(Level.ALL); ((Log4JLogger) SaslRpcServer.LOG).getLogger().setLevel(Level.ALL); ((Log4JLogger) SaslInputStream.LOG).getLogger().setLevel(Level.ALL); ((Log4JLogger) SecurityUtil.LOG).getLogger().setLevel(Level.ALL); } static Method setUseIpMethod; @BeforeClass public static void exposeUseIp() throws Exception { setUseIpMethod = SecurityUtil.class.getDeclaredMethod("setTokenServiceUseIp", boolean.class); setUseIpMethod.setAccessible(true); } private void setTokenServiceUseIp(boolean flag) throws Exception { setUseIpMethod.invoke(setUseIpMethod.getClass(), flag); } public static class TestTokenIdentifier extends AbstractDelegationTokenIdentifier { private Text tokenid; private Text realUser; final static Text KIND_NAME = new Text("test.token"); public TestTokenIdentifier() { this(new Text(), new Text()); } public TestTokenIdentifier(Text tokenid) { this(tokenid, new Text()); } public TestTokenIdentifier(Text tokenid, Text realUser) { this.tokenid = tokenid == null ? new Text() : tokenid; this.realUser = realUser == null ? new Text() : realUser; } @Override public Text getKind() { return KIND_NAME; } @Override public UserGroupInformation getUser() { if ("".equals(realUser.toString())) { return UserGroupInformation.createRemoteUser(tokenid.toString()); } else { UserGroupInformation realUgi = UserGroupInformation.createRemoteUser(realUser.toString()); return UserGroupInformation.createProxyUser(tokenid.toString(), realUgi); } } @Override public void readFields(DataInput in) throws IOException { tokenid.readFields(in); realUser.readFields(in); } @Override public void write(DataOutput out) throws IOException { tokenid.write(out); realUser.write(out); } } public static class TestTokenSecretManager extends SecretManager<TestTokenIdentifier> { public byte[] createPassword(TestTokenIdentifier id) { return id.getBytes(); } public byte[] retrievePassword(TestTokenIdentifier id) throws InvalidToken { return id.getBytes(); } public TestTokenIdentifier createIdentifier() { return new TestTokenIdentifier(); } } public static class BadTokenSecretManager extends TestTokenSecretManager { public byte[] retrievePassword(TestTokenIdentifier id) throws InvalidToken { throw new InvalidToken(ERROR_MESSAGE); } } public static class TestTokenSelector extends AbstractDelegationTokenSelector<TestTokenIdentifier> { TestTokenSelector() { super(TestTokenIdentifier.KIND_NAME); } } @KerberosInfo(serverPrincipal = SERVER_PRINCIPAL_KEY) @TokenInfo(TestTokenSelector.class) public interface TestSaslProtocol extends TestRPC.TestProtocol { public AuthenticationMethod getAuthMethod() throws IOException; } public static class TestSaslImpl extends TestRPC.TestImpl implements TestSaslProtocol { public AuthenticationMethod getAuthMethod() throws IOException { return UserGroupInformation.getCurrentUser().getAuthenticationMethod(); } } @Test public void testDigestRpc() throws Exception { TestTokenSecretManager sm = new TestTokenSecretManager(); final Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm); doDigestRpc(server, sm); } @Test public void testSecureToInsecureRpc() throws Exception { Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, null); server.disableSecurity(); TestTokenSecretManager sm = new TestTokenSecretManager(); doDigestRpc(server, sm); } @Test public void testErrorMessage() throws Exception { BadTokenSecretManager sm = new BadTokenSecretManager(); final Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm); boolean succeeded = false; try { doDigestRpc(server, sm); } catch (RemoteException e) { LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage()); assertTrue(ERROR_MESSAGE.equals(e.getLocalizedMessage())); assertTrue(e.unwrapRemoteException() instanceof InvalidToken); succeeded = true; } assertTrue(succeeded); } private void doDigestRpc(Server server, TestTokenSecretManager sm) throws Exception { server.start(); final UserGroupInformation current = UserGroupInformation.getCurrentUser(); final InetSocketAddress addr = NetUtils.getConnectAddress(server); TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current.getUserName())); Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm); SecurityUtil.setTokenService(token, addr); LOG.info("Service IP address for token is " + token.getService()); current.addToken(token); TestSaslProtocol proxy = null; try { proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, conf); proxy.ping(); } finally { server.stop(); if (proxy != null) { RPC.stopProxy(proxy); } } } @Test public void testGetRemotePrincipal() throws Exception { try { Configuration newConf = new Configuration(conf); newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1); ConnectionId remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0), TestSaslProtocol.class, null, newConf); assertEquals(SERVER_PRINCIPAL_1, remoteId.getServerPrincipal()); // this following test needs security to be off newConf.set(HADOOP_SECURITY_AUTHENTICATION, "simple"); UserGroupInformation.setConfiguration(newConf); remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0), TestSaslProtocol.class, null, newConf); assertEquals("serverPrincipal should be null when security is turned off", null, remoteId.getServerPrincipal()); } finally { // revert back to security is on UserGroupInformation.setConfiguration(conf); } } @Test public void testPerConnectionConf() throws Exception { TestTokenSecretManager sm = new TestTokenSecretManager(); final Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm); server.start(); final UserGroupInformation current = UserGroupInformation.getCurrentUser(); final InetSocketAddress addr = NetUtils.getConnectAddress(server); TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current.getUserName())); Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm); SecurityUtil.setTokenService(token, addr); LOG.info("Service IP address for token is " + token.getService()); current.addToken(token); Configuration newConf = new Configuration(conf); newConf.set("hadoop.rpc.socket.factory.class.default", ""); newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1); TestSaslProtocol proxy1 = null; TestSaslProtocol proxy2 = null; TestSaslProtocol proxy3 = null; try { proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); Client client = RPC.getClient(conf); Set<ConnectionId> conns = client.getConnectionIds(); assertEquals("number of connections in cache is wrong", 1, conns.size()); // same conf, connection should be re-used proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); assertEquals("number of connections in cache is wrong", 1, conns.size()); // different conf, new connection should be set up newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2); proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]); assertEquals("number of connections in cache is wrong", 2, connsArray.length); String p1 = connsArray[0].getServerPrincipal(); String p2 = connsArray[1].getServerPrincipal(); assertFalse("should have different principals", p1.equals(p2)); assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1) || p1.equals(SERVER_PRINCIPAL_2)); assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1) || p2.equals(SERVER_PRINCIPAL_2)); } finally { server.stop(); RPC.stopProxy(proxy1); RPC.stopProxy(proxy2); RPC.stopProxy(proxy3); } } static void testKerberosRpc(String principal, String keytab) throws Exception { final Configuration newConf = new Configuration(conf); newConf.set(SERVER_PRINCIPAL_KEY, principal); newConf.set(SERVER_KEYTAB_KEY, keytab); SecurityUtil.login(newConf, SERVER_KEYTAB_KEY, SERVER_PRINCIPAL_KEY); TestUserGroupInformation.verifyLoginMetrics(1, 0); UserGroupInformation current = UserGroupInformation.getCurrentUser(); System.out.println("UGI: " + current); Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, newConf, null); TestSaslProtocol proxy = null; server.start(); InetSocketAddress addr = NetUtils.getConnectAddress(server); try { proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, newConf); proxy.ping(); } finally { server.stop(); if (proxy != null) { RPC.stopProxy(proxy); } } System.out.println("Test is successful."); } public void testDigestAuthMethod(boolean useIp) throws Exception { setTokenServiceUseIp(useIp); TestTokenSecretManager sm = new TestTokenSecretManager(); Server server = RPC.getServer(new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm); server.start(); final UserGroupInformation current = UserGroupInformation.getCurrentUser(); final InetSocketAddress addr = NetUtils.getConnectAddress(server); TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current.getUserName())); Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId, sm); SecurityUtil.setTokenService(token, addr); LOG.info("Service IP address for token is " + token.getService()); InetSocketAddress tokenAddr = SecurityUtil.getTokenServiceAddr(token); String expectedHost, gotHost; if (useIp) { expectedHost = addr.getAddress().getHostAddress(); gotHost = tokenAddr.getAddress().getHostAddress(); } else { gotHost = tokenAddr.getHostName(); expectedHost = ADDRESS; } Assert.assertEquals(expectedHost, gotHost); Assert.assertEquals(expectedHost + ":" + addr.getPort(), token.getService().toString()); current.addToken(token); current.doAs(new PrivilegedExceptionAction<Object>() { public Object run() throws IOException { TestSaslProtocol proxy = null; try { proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class, TestSaslProtocol.versionID, addr, conf); Assert.assertEquals(AuthenticationMethod.TOKEN, proxy.getAuthMethod()); } finally { if (proxy != null) { RPC.stopProxy(proxy); } } return null; } }); server.stop(); } @Test public void testDigestAuthMethodIpBasedToken() throws Exception { testDigestAuthMethod(true); } @Test public void testDigestAuthMethodHostBasedToken() throws Exception { testDigestAuthMethod(false); } public static void main(String[] args) throws Exception { System.out.println("Testing Kerberos authentication over RPC"); if (args.length != 2) { System.err.println( "Usage: java <options> org.apache.hadoop.ipc.TestSaslRPC " + " <serverPrincipal> <keytabFile>"); System.exit(-1); } String principal = args[0]; String keytab = args[1]; testKerberosRpc(principal, keytab); } }