co.cask.common.security.server.ExternalAuthenticationServerTestBase.java Source code

Java tutorial

Introduction

Here is the source code for co.cask.common.security.server.ExternalAuthenticationServerTestBase.java

Source

/*
 * Copyright  2014 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.common.security.server;

import co.cask.common.Networks;
import co.cask.common.io.Codec;
import co.cask.common.security.Constants;
import co.cask.common.security.authentication.AccessToken;
import co.cask.common.security.authentication.AccessTokenCodec;
import co.cask.common.security.config.SecurityConfiguration;
import co.cask.common.security.guice.ConfigModule;
import co.cask.common.security.guice.DiscoveryRuntimeModule;
import co.cask.common.security.guice.IOModule;
import co.cask.common.security.guice.InMemorySecurityModule;
import com.google.common.base.Throwables;
import com.google.common.collect.Sets;
import com.google.common.io.ByteStreams;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.inject.AbstractModule;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import com.google.inject.name.Names;
import com.google.inject.util.Modules;
import com.unboundid.ldap.listener.InMemoryDirectoryServer;
import com.unboundid.ldap.listener.InMemoryDirectoryServerConfig;
import com.unboundid.ldap.listener.InMemoryListenerConfig;
import com.unboundid.ldap.sdk.Entry;
import org.apache.commons.codec.binary.Base64;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.twill.discovery.Discoverable;
import org.apache.twill.discovery.DiscoveryServiceClient;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.net.SocketAddress;
import java.net.URL;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.contains;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;

/**
 * Base test class for ExternalAuthenticationServer.
 */
public abstract class ExternalAuthenticationServerTestBase {
    private static final Logger LOG = LoggerFactory.getLogger(ExternalAuthenticationServerTestBase.class);
    private static ExternalAuthenticationServer server;
    private static int port;
    private static Codec<AccessToken> tokenCodec;
    private static DiscoveryServiceClient discoveryServiceClient;
    private static InMemoryDirectoryServer ldapServer;
    protected static int ldapPort = Networks.getRandomPort();

    private static final Logger TEST_AUDIT_LOGGER = mock(Logger.class);

    // Needs to be set by derived classes.
    protected static SecurityConfiguration configuration;
    protected static InMemoryListenerConfig ldapListenerConfig;

    protected abstract String getProtocol();

    protected abstract HttpClient getHTTPClient() throws Exception;

    protected static void setup() throws Exception {
        Assert.assertNotNull("SecurityConfiguration needs to be set by derived classes", configuration);

        Module securityModule = Modules.override(new InMemorySecurityModule()).with(new AbstractModule() {
            @Override
            protected void configure() {
                bind(AuditLogHandler.class)
                        .annotatedWith(Names.named(ExternalAuthenticationServer.NAMED_EXTERNAL_AUTH))
                        .toInstance(new AuditLogHandler(TEST_AUDIT_LOGGER));
            }
        });
        Injector injector = Guice.createInjector(new IOModule(), securityModule,
                new ConfigModule(getSecurityConfiguration(configuration)),
                new DiscoveryRuntimeModule().getInMemoryModules());
        server = injector.getInstance(ExternalAuthenticationServer.class);
        tokenCodec = injector.getInstance(AccessTokenCodec.class);
        discoveryServiceClient = injector.getInstance(DiscoveryServiceClient.class);

        if (configuration.getBoolean(Constants.SSL_ENABLED)) {
            port = configuration.getInt(Constants.AuthenticationServer.SSL_PORT);
        } else {
            port = configuration.getInt(Constants.AUTH_SERVER_BIND_PORT);
        }

        try {
            startLDAPServer();
        } catch (Exception e) {
            throw Throwables.propagate(e);
        }

        server.startAndWait();
        LOG.info("Auth server running on port {}", port);
        ldapServer.startListening();
        TimeUnit.SECONDS.sleep(3);
    }

    public int getAuthServerPort() {
        return port;
    }

    /**
     * LDAP server and related handler configurations.
     */
    private static SecurityConfiguration getSecurityConfiguration(SecurityConfiguration cConf) {
        String configBase = Constants.AUTH_HANDLER_CONFIG_BASE;

        // Use random port for testing
        cConf.setInt(Constants.AUTH_SERVER_BIND_PORT, Networks.getRandomPort());
        cConf.setInt(Constants.AuthenticationServer.SSL_PORT, Networks.getRandomPort());

        cConf.set(Constants.AUTH_HANDLER_CLASS, LDAPAuthenticationHandler.class.getName());
        cConf.set(Constants.LOGIN_MODULE_CLASS_NAME, LDAPLoginModule.class.getName());
        cConf.set(configBase.concat("debug"), "true");
        cConf.set(configBase.concat("hostname"), "localhost");
        cConf.set(configBase.concat("port"), Integer.toString(ldapPort));
        cConf.set(configBase.concat("userBaseDn"), "dc=example,dc=com");
        cConf.set(configBase.concat("userRdnAttribute"), "cn");
        cConf.set(configBase.concat("userObjectClass"), "inetorgperson");

        URL keytabUrl = ExternalAuthenticationServerTestBase.class.getClassLoader().getResource("test.keytab");
        Assert.assertNotNull(keytabUrl);
        cConf.set(Constants.CFG_CDAP_MASTER_KRB_KEYTAB_PATH, keytabUrl.getPath());
        cConf.set(Constants.CFG_CDAP_MASTER_KRB_PRINCIPAL, "test_principal");
        return cConf;
    }

    private static void startLDAPServer() throws Exception {
        InMemoryDirectoryServerConfig config = new InMemoryDirectoryServerConfig("dc=example,dc=com");
        config.setListenerConfigs(ldapListenerConfig);

        Entry defaultEntry = new Entry("dn: dc=example,dc=com", "objectClass: top", "objectClass: domain",
                "dc: example");
        Entry userEntry = new Entry("dn: uid=user,dc=example,dc=com", "objectClass: inetorgperson", "cn: admin",
                "sn: User", "uid: user", "userPassword: realtime");

        ldapServer = new InMemoryDirectoryServer(config);
        ldapServer.addEntries(defaultEntry, userEntry);
    }

    @AfterClass
    public static void afterClass() throws Exception {
        ldapServer.shutDown(true);
        server.stopAndWait();
        // Clear any security properties for zookeeper.
        System.clearProperty(Constants.External.Zookeeper.ENV_AUTH_PROVIDER_1);
    }

    /**
     * Test an authorized request to server.
     * @throws Exception
     */
    @Test
    public void testValidAuthentication() throws Exception {
        HttpClient client = getHTTPClient();
        String uri = String.format("%s://%s:%d/%s", getProtocol(),
                server.getSocketAddress().getAddress().getHostAddress(), server.getSocketAddress().getPort(),
                GrantAccessToken.Paths.GET_TOKEN);

        HttpGet request = new HttpGet(uri);
        request.addHeader("Authorization", "Basic YWRtaW46cmVhbHRpbWU=");
        HttpResponse response = client.execute(request);

        assertEquals(response.getStatusLine().getStatusCode(), 200);
        verify(TEST_AUDIT_LOGGER, timeout(10000).atLeastOnce()).trace(contains("admin"));

        // Test correct headers being returned
        String cacheControlHeader = response.getFirstHeader("Cache-Control").getValue();
        String pragmaHeader = response.getFirstHeader("Pragma").getValue();
        String contentType = response.getFirstHeader("Content-Type").getValue();

        assertEquals("no-store", cacheControlHeader);
        assertEquals("no-cache", pragmaHeader);
        assertEquals("application/json;charset=UTF-8", contentType);

        // Test correct response body
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        ByteStreams.copy(response.getEntity().getContent(), bos);
        String responseBody = bos.toString("UTF-8");
        bos.close();

        JsonParser parser = new JsonParser();
        JsonObject responseJson = (JsonObject) parser.parse(responseBody);
        String tokenType = responseJson.get(ExternalAuthenticationServer.ResponseFields.TOKEN_TYPE).toString();
        long expiration = responseJson.get(ExternalAuthenticationServer.ResponseFields.EXPIRES_IN).getAsLong();

        assertEquals(String.format("\"%s\"", ExternalAuthenticationServer.ResponseFields.TOKEN_TYPE_BODY),
                tokenType);

        long expectedExpiration = configuration.getInt(Constants.TOKEN_EXPIRATION);
        // Test expiration time in seconds
        assertEquals(expectedExpiration / 1000, expiration);

        // Test that the server passes back an AccessToken object which can be decoded correctly.
        String encodedToken = responseJson.get(ExternalAuthenticationServer.ResponseFields.ACCESS_TOKEN)
                .getAsString();
        AccessToken token = tokenCodec.decode(Base64.decodeBase64(encodedToken));
        assertEquals("admin", token.getIdentifier().getUsername());
        LOG.info("AccessToken got from ExternalAuthenticationServer is: " + encodedToken);
    }

    /**
     * Test an unauthorized request to server.
     * @throws Exception
     */
    @Test
    public void testInvalidAuthentication() throws Exception {
        HttpClient client = getHTTPClient();
        String uri = String.format("%s://%s:%d/%s", getProtocol(),
                server.getSocketAddress().getAddress().getHostAddress(), server.getSocketAddress().getPort(),
                GrantAccessToken.Paths.GET_TOKEN);
        HttpGet request = new HttpGet(uri);
        request.addHeader("Authorization", "xxxxx");

        HttpResponse response = client.execute(request);

        // Request is Unauthorized
        assertEquals(401, response.getStatusLine().getStatusCode());
        verify(TEST_AUDIT_LOGGER, timeout(10000).atLeastOnce()).trace(contains("401"));
    }

    /**
     * Test an unauthorized status request to server.
     * @throws Exception
     */
    @Test
    public void testStatusResponse() throws Exception {
        HttpClient client = getHTTPClient();
        String uri = String.format("%s://%s:%d/%s", getProtocol(),
                server.getSocketAddress().getAddress().getHostAddress(), server.getSocketAddress().getPort(),
                Constants.EndPoints.STATUS);
        HttpGet request = new HttpGet(uri);

        HttpResponse response = client.execute(request);

        // Status request is authorized without any extra headers
        assertEquals(200, response.getStatusLine().getStatusCode());

    }

    /**
     * Test getting a long lasting Access Token.
     * @throws Exception
     */
    @Test
    public void testExtendedToken() throws Exception {
        HttpClient client = getHTTPClient();
        String uri = String.format("%s://%s:%d/%s", getProtocol(),
                server.getSocketAddress().getAddress().getHostAddress(), server.getSocketAddress().getPort(),
                GrantAccessToken.Paths.GET_EXTENDED_TOKEN);
        HttpGet request = new HttpGet(uri);
        request.addHeader("Authorization", "Basic YWRtaW46cmVhbHRpbWU=");
        HttpResponse response = client.execute(request);

        assertEquals(200, response.getStatusLine().getStatusCode());

        // Test correct response body
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        ByteStreams.copy(response.getEntity().getContent(), bos);
        String responseBody = bos.toString("UTF-8");
        bos.close();

        JsonParser parser = new JsonParser();
        JsonObject responseJson = (JsonObject) parser.parse(responseBody);
        long expiration = responseJson.get(ExternalAuthenticationServer.ResponseFields.EXPIRES_IN).getAsLong();

        long expectedExpiration = configuration.getInt(Constants.EXTENDED_TOKEN_EXPIRATION);
        // Test expiration time in seconds
        assertEquals(expectedExpiration / 1000, expiration);

        // Test that the server passes back an AccessToken object which can be decoded correctly.
        String encodedToken = responseJson.get(ExternalAuthenticationServer.ResponseFields.ACCESS_TOKEN)
                .getAsString();
        AccessToken token = tokenCodec.decode(Base64.decodeBase64(encodedToken));
        assertEquals("admin", token.getIdentifier().getUsername());
        LOG.info("AccessToken got from ExternalAuthenticationServer is: " + encodedToken);
    }

    /**
     * Test that invalid paths return a 404 Not Found.
     * @throws Exception
     */
    @Test
    public void testInvalidPath() throws Exception {
        HttpClient client = getHTTPClient();
        String uri = String.format("%s://%s:%d/%s", getProtocol(),
                server.getSocketAddress().getAddress().getHostAddress(), server.getSocketAddress().getPort(),
                "invalid");
        HttpGet request = new HttpGet(uri);
        request.addHeader("Authorization", "Basic YWRtaW46cmVhbHRpbWU=");
        HttpResponse response = client.execute(request);

        assertEquals(404, response.getStatusLine().getStatusCode());

    }

    /**
     * Test that the service is discoverable.
     * @throws Exception
     */
    @Test
    public void testServiceRegistration() throws Exception {
        Iterable<Discoverable> discoverables = discoveryServiceClient
                .discover(Constants.Service.EXTERNAL_AUTHENTICATION);

        Set<SocketAddress> addresses = Sets.newHashSet();
        for (Discoverable discoverable : discoverables) {
            addresses.add(discoverable.getSocketAddress());
        }

        Assert.assertTrue(addresses.contains(server.getSocketAddress()));
    }
}