com.hortonworks.registries.auth.client.AuthenticatorTestCase.java Source code

Java tutorial

Introduction

Here is the source code for the class com.hortonworks.registries.auth.client.AuthenticatorTestCase (file AuthenticatorTestCase.java).

Source

/**
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. See accompanying LICENSE file.
 */
package com.hortonworks.registries.auth.client;

import com.hortonworks.registries.auth.server.AuthenticationFilter;
import org.apache.catalina.deploy.FilterDef;
import org.apache.catalina.deploy.FilterMap;
import org.apache.catalina.startup.Tomcat;
import org.apache.http.HttpResponse;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.Credentials;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.params.AuthPolicy;
import org.apache.http.entity.InputStreamEntity;
import org.apache.http.impl.auth.SPNegoSchemeFactory;
import org.apache.http.impl.client.SystemDefaultHttpClient;
import org.apache.http.util.EntityUtils;
import org.junit.Assert;
import org.mortbay.jetty.Server;
import org.mortbay.jetty.servlet.Context;
import org.mortbay.jetty.servlet.FilterHolder;
import org.mortbay.jetty.servlet.ServletHolder;

import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.net.HttpURLConnection;
import java.net.ServerSocket;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.Properties;

/**
 * Base class for authenticator tests.
 *
 * <p>Starts an embedded servlet container that serves {@link TestServlet} at {@code /foo/bar},
 * protected by {@link TestFilter}. The container is Jetty by default, or Tomcat when requested
 * through the constructor. The class also provides helpers that drive an {@link Authenticator}
 * against that endpoint, either through {@link AuthenticatedURL} or through Apache HttpClient.
 */
public class AuthenticatorTestCase {
    /** Request body sent by the POST variants; the echo servlet must return it unchanged. */
    private static final String POST = "test";

    /** Embedded Jetty server; only used when {@link #useTomcat} is {@code false}. */
    private Server server;
    /** Host the embedded container binds to; set when the container is started. */
    private String host = null;
    /** Port the embedded container binds to; set when the container is started. */
    private int port = -1;
    /** Selects Tomcat ({@code true}) or Jetty ({@code false}) as the embedded container. */
    private boolean useTomcat = false;
    /** Embedded Tomcat; only used when {@link #useTomcat} is {@code true}. */
    private Tomcat tomcat = null;
    /** Jetty servlet context mounted at {@code /foo}. */
    Context context;

    /**
     * Configuration returned by every {@link TestFilter}; assigned via
     * {@link #setAuthenticationHandlerConfig(Properties)}.
     */
    private static Properties authenticatorConfig;

    /** Creates a test case that runs against embedded Jetty. */
    public AuthenticatorTestCase() {
    }

    /**
     * Creates a test case.
     *
     * @param useTomcat {@code true} to run against embedded Tomcat, {@code false} for embedded Jetty
     */
    public AuthenticatorTestCase(boolean useTomcat) {
        this.useTomcat = useTomcat;
    }

    /**
     * Sets the properties that {@link TestFilter#getConfiguration} hands to the authentication filter.
     *
     * @param config filter configuration (returned as-is, not copied)
     */
    protected static void setAuthenticationHandlerConfig(Properties config) {
        authenticatorConfig = config;
    }

    /** Authentication filter whose configuration comes from {@link #setAuthenticationHandlerConfig}. */
    public static class TestFilter extends AuthenticationFilter {

        @Override
        protected Properties getConfiguration(String configPrefix, FilterConfig filterConfig)
                throws ServletException {
            return authenticatorConfig;
        }
    }

    /** Servlet that answers GET with 200 and echoes the POST request body back with 200. */
    @SuppressWarnings("serial")
    public static class TestServlet extends HttpServlet {

        @Override
        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
                throws ServletException, IOException {
            resp.setStatus(HttpServletResponse.SC_OK);
        }

        @Override
        protected void doPost(HttpServletRequest req, HttpServletResponse resp)
                throws ServletException, IOException {
            // Set the status before writing: once the body is written/closed the response is
            // committed and a later setStatus() would be silently ignored.
            resp.setStatus(HttpServletResponse.SC_OK);
            try (InputStream is = req.getInputStream(); OutputStream os = resp.getOutputStream()) {
                byte[] buffer = new byte[4096];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
            }
        }
    }

    /**
     * Finds a local TCP port that is free right now by briefly binding an ephemeral server socket.
     * The socket is closed before returning, so the port is only likely (not guaranteed) to still
     * be free when the container binds to it.
     *
     * @return a port number that was free at the time of the call
     * @throws Exception if the probe socket cannot be opened
     */
    protected int getLocalPort() throws Exception {
        try (ServerSocket ss = new ServerSocket(0)) {
            return ss.getLocalPort();
        }
    }

    /**
     * Starts the embedded container selected at construction time.
     *
     * @throws Exception if the container fails to start
     */
    protected void start() throws Exception {
        if (useTomcat) {
            startTomcat();
        } else {
            startJetty();
        }
    }

    /**
     * Starts embedded Jetty with {@link TestFilter} on {@code /foo/*} and {@link TestServlet} on
     * {@code /foo/bar}, bound to {@code localhost} on a freshly probed port.
     *
     * @throws Exception if Jetty fails to start
     */
    protected void startJetty() throws Exception {
        server = new Server(0);
        context = new Context();
        context.setContextPath("/foo");
        server.setHandler(context);
        context.addFilter(new FilterHolder(TestFilter.class), "/*", 0);
        context.addServlet(new ServletHolder(TestServlet.class), "/bar");
        host = "localhost";
        port = getLocalPort();
        server.getConnectors()[0].setHost(host);
        server.getConnectors()[0].setPort(port);
        server.start();
        System.out.println("Running embedded servlet container at: http://" + host + ":" + port);
    }

    /**
     * Starts embedded Tomcat with context {@code /foo} (docBase = {@code java.io.tmpdir}), mapping
     * {@link TestFilter} on {@code /*} and {@link TestServlet} on {@code /bar}, bound to
     * {@code localhost} on a freshly probed port.
     *
     * @throws Exception if Tomcat fails to start
     */
    protected void startTomcat() throws Exception {
        tomcat = new Tomcat();
        File base = new File(System.getProperty("java.io.tmpdir"));
        org.apache.catalina.Context ctx = tomcat.addContext("/foo", base.getAbsolutePath());
        FilterDef fd = new FilterDef();
        fd.setFilterClass(TestFilter.class.getName());
        fd.setFilterName("TestFilter");
        FilterMap fm = new FilterMap();
        fm.setFilterName("TestFilter");
        fm.addURLPattern("/*");
        fm.addServletName("/bar");
        ctx.addFilterDef(fd);
        ctx.addFilterMap(fm);
        tomcat.addServlet(ctx, "/bar", TestServlet.class.getName());
        ctx.addServletMapping("/bar", "/bar");
        host = "localhost";
        port = getLocalPort();
        tomcat.setHostname(host);
        tomcat.setPort(port);
        tomcat.start();
    }

    /**
     * Stops the embedded container selected at construction time (best-effort).
     *
     * @throws Exception declared for subclasses; the built-in implementations do not throw
     */
    protected void stop() throws Exception {
        if (useTomcat) {
            stopTomcat();
        } else {
            stopJetty();
        }
    }

    /**
     * Stops and destroys the embedded Jetty server if one was created. Shutdown is best-effort:
     * failures are ignored so they cannot mask the outcome of the test itself.
     *
     * @throws Exception never in practice; kept for signature compatibility
     */
    protected void stopJetty() throws Exception {
        if (server == null) {
            return; // never started (or start failed before creation): nothing to stop
        }
        try {
            server.stop();
        } catch (Exception ignored) {
            // best-effort shutdown; see method doc
        }
        try {
            server.destroy();
        } catch (Exception ignored) {
            // best-effort shutdown; see method doc
        }
    }

    /**
     * Stops and destroys the embedded Tomcat if one was created. Shutdown is best-effort:
     * failures are ignored so they cannot mask the outcome of the test itself.
     *
     * @throws Exception never in practice; kept for signature compatibility
     */
    protected void stopTomcat() throws Exception {
        if (tomcat == null) {
            return; // never started (or start failed before creation): nothing to stop
        }
        try {
            tomcat.stop();
        } catch (Exception ignored) {
            // best-effort shutdown; see method doc
        }
        try {
            tomcat.destroy();
        } catch (Exception ignored) {
            // best-effort shutdown; see method doc
        }
    }

    /**
     * Returns the URL of the test servlet.
     *
     * @return {@code http://<host>:<port>/foo/bar}; only meaningful after {@link #start()}
     */
    protected String getBaseURL() {
        return "http://" + host + ":" + port + "/foo/bar";
    }

    /** Connection configurator that only records that it was invoked. */
    private static class TestConnectionConfigurator implements ConnectionConfigurator {
        boolean invoked;

        @Override
        public HttpURLConnection configure(HttpURLConnection conn) throws IOException {
            invoked = true;
            return conn;
        }
    }

    /**
     * Authenticates against the test servlet through {@link AuthenticatedURL}.
     *
     * <p>The method first checks that the connection configurator is invoked and that the request
     * succeeds. With {@code doPost}, it also checks that the body is echoed back. It then performs
     * a second GET through a no-arg {@link AuthenticatedURL}, reusing the same token, and checks
     * that this request also succeeds and leaves the token unchanged.
     *
     * @param authenticator authenticator under test
     * @param doPost        {@code true} to send {@link #POST} as a POST body on the first request
     * @throws Exception on container start-up or I/O failure
     */
    protected void _testAuthentication(Authenticator authenticator, boolean doPost) throws Exception {
        start();
        try {
            URL url = new URL(getBaseURL());
            AuthenticatedURL.Token token = new AuthenticatedURL.Token();
            Assert.assertFalse(token.isSet());
            TestConnectionConfigurator connConf = new TestConnectionConfigurator();
            AuthenticatedURL aUrl = new AuthenticatedURL(authenticator, connConf);
            HttpURLConnection conn = aUrl.openConnection(url, token);
            Assert.assertTrue(connConf.invoked);
            String tokenStr = token.toString();
            if (doPost) {
                conn.setRequestMethod("POST");
                conn.setDoOutput(true);
            }
            conn.connect();
            if (doPost) {
                // Explicit charset on both sides so the echoed bytes decode to the same string
                // regardless of the platform default encoding.
                try (Writer writer = new OutputStreamWriter(conn.getOutputStream(), StandardCharsets.UTF_8)) {
                    writer.write(POST);
                }
            }
            Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode());
            if (doPost) {
                try (BufferedReader reader = new BufferedReader(
                        new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
                    Assert.assertEquals(POST, reader.readLine());
                    Assert.assertNull(reader.readLine());
                }
            }
            // Second round trip with a no-arg AuthenticatedURL and the same token: must succeed
            // and must not change the token.
            aUrl = new AuthenticatedURL();
            conn = aUrl.openConnection(url, token);
            conn.connect();
            Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode());
            Assert.assertEquals(tokenStr, token.toString());
        } finally {
            stop();
        }
    }

    /**
     * Builds an HttpClient with the SPNEGO auth scheme registered and placeholder credentials
     * (null principal and password) for any scope. The placeholder credentials are presumably
     * there so the SPNEGO scheme picks up the real identity from JAAS, as the original variable
     * name suggested — confirm.
     *
     * @return a configured client; callers should shut down its connection manager when done
     */
    private SystemDefaultHttpClient getHttpClient() {
        final SystemDefaultHttpClient httpClient = new SystemDefaultHttpClient();
        // true = stripPort: leave the port out of the service principal host name.
        httpClient.getAuthSchemes().register(AuthPolicy.SPNEGO, new SPNegoSchemeFactory(true));
        Credentials useJaasCreds = new Credentials() {
            @Override
            public String getPassword() {
                return null;
            }

            @Override
            public Principal getUserPrincipal() {
                return null;
            }
        };

        httpClient.getCredentialsProvider().setCredentials(AuthScope.ANY, useJaasCreds);
        return httpClient;
    }

    /**
     * Executes {@code request}, asserts an HTTP 200 status, and always consumes the response entity
     * so the connection can be reused.
     *
     * @param httpClient client to execute with
     * @param request    request to execute
     * @throws Exception on I/O failure
     */
    private void doHttpClientRequest(HttpClient httpClient, HttpUriRequest request) throws Exception {
        HttpResponse response = null;
        try {
            response = httpClient.execute(request);
            final int httpStatus = response.getStatusLine().getStatusCode();
            Assert.assertEquals(HttpURLConnection.HTTP_OK, httpStatus);
        } finally {
            if (response != null) {
                EntityUtils.consumeQuietly(response.getEntity());
            }
        }
    }

    /**
     * Authenticates against the test servlet through Apache HttpClient with SPNEGO: a GET, then
     * (if {@code doPost}) a POST with a non-repeatable body.
     *
     * @param authenticator authenticator under test (not used by this HttpClient-based variant)
     * @param doPost        {@code true} to follow the GET with a POST of {@link #POST}
     * @throws Exception on container start-up or I/O failure
     */
    protected void _testAuthenticationHttpClient(Authenticator authenticator, boolean doPost) throws Exception {
        start();
        try {
            SystemDefaultHttpClient httpClient = getHttpClient();
            try {
                doHttpClientRequest(httpClient, new HttpGet(getBaseURL()));

                // Always do a GET before POST to trigger the SPNego negotiation
                if (doPost) {
                    HttpPost post = new HttpPost(getBaseURL());
                    byte[] postBytes = POST.getBytes(StandardCharsets.UTF_8);
                    ByteArrayInputStream bis = new ByteArrayInputStream(postBytes);
                    InputStreamEntity entity = new InputStreamEntity(bis, postBytes.length);

                    // Important that the entity is not repeatable -- this means if
                    // we have to renegotiate (e.g. b/c the cookie wasn't handled properly)
                    // the test will fail.
                    Assert.assertFalse(entity.isRepeatable());
                    post.setEntity(entity);
                    doHttpClientRequest(httpClient, post);
                }
            } finally {
                // Release pooled connections before the server is torn down.
                httpClient.getConnectionManager().shutdown();
            }
        } finally {
            stop();
        }
    }
}