co.cask.cdap.security.tools.AccessTokenClient.java Source code

Java tutorial

Introduction

Here is the source code for co.cask.cdap.security.tools.AccessTokenClient.java

Source

/*
 * Copyright  2014 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.security.tools;

import co.cask.cdap.common.utils.UsageException;
import co.cask.cdap.security.server.ExternalAuthenticationServer;
import com.google.common.io.ByteStreams;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.codec.binary.Base64;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.conn.ClientConnectionManager;
import org.apache.http.conn.scheme.Scheme;
import org.apache.http.conn.scheme.SchemeRegistry;
import org.apache.http.conn.ssl.SSLSocketFactory;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.conn.BasicClientConnectionManager;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import java.io.ByteArrayOutputStream;
import java.io.Console;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.net.URI;
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

/**
 * Client to get an AccessToken using username:password authentication.
 */
public class AccessTokenClient {
    static {
        // this turns off all logging but we don't need that for a cmdline tool
        Logger.getRootLogger().setLevel(Level.OFF);
    }

    /**
     * for debugging. should only be set to true in unit tests.
     * when true, program will print the stack trace after the usage.
     */
    public static boolean debug = false;

    private boolean help = false;
    // Default por numbers
    private static final int SSL_PORT = 10443;
    private static final int NO_SSL_PORT = 10000;

    private String host;
    // default will be set in parsing arguments
    private int port = NO_SSL_PORT;
    private boolean useSsl = false;
    private boolean disableCertCheck = false;

    private String username;
    private String password;
    private String filePath;
    private Options options;

    private static final class ConfigurableOptions {
        private static final String HOST = "host";
        private static final String PORT = "port";
        private static final String USER_NAME = "username";
        private static final String PASSWORD = "password";
        private static final String FILE = "file";
        private static final String HELP = "help";
        private static final String SSL = "ssl";
        private static final String DISABLE_CERT_CHECK = "disable-cert-check";
    }

    /**
     * Print the usage statement and return null (or empty string if this is not
     * an error case). See getValue() for an explanation of the return type.
     *
     * @param error indicates whether this was invoked as the result of an error
     * @throws co.cask.cdap.common.utils.UsageException
     *          in case of error
     */
    void usage(boolean error) {
        PrintStream out = (error ? System.err : System.out);
        String name = "accesstoken-client";
        if (System.getProperty("script") != null) {
            name = System.getProperty("script").replaceAll("[./]", "");
        }
        out.println("Usage: ");
        out.println("  " + name + " [ --host <host> ] [ --username <username> ] [ --file <outputfile> ]");
        out.println();
        printOptions(error);
    }

    private void printOptions(boolean error) {
        PrintWriter pw = error ? new PrintWriter(System.err) : new PrintWriter(System.out);
        pw.println("Options:\n");
        HelpFormatter formatter = new HelpFormatter();
        formatter.printOptions(pw, 100, options, 0, 10);
        pw.flush();
        pw.close();
        if (error) {
            throw new UsageException();
        }
    }

    private void buildOptions() {
        options = new Options();
        options.addOption(null, ConfigurableOptions.HOST, true, "To specify the host of gateway");
        options.addOption(null, ConfigurableOptions.PORT, true, "To specify the port of gateway. " + String
                .format("Defaults to %d if router is not SSL enabled and %d if it is.", NO_SSL_PORT, SSL_PORT));
        options.addOption(null, ConfigurableOptions.USER_NAME, true, "To specify the user to login as");
        options.addOption(null, ConfigurableOptions.PASSWORD, true, "To specify the user password");
        options.addOption(null, ConfigurableOptions.FILE, true, "To specify the access token file");
        options.addOption(null, ConfigurableOptions.HELP, false, "To print this message");
        options.addOption(null, ConfigurableOptions.SSL, false, "To specify that SSL is enabled");
        options.addOption(null, ConfigurableOptions.DISABLE_CERT_CHECK, false,
                "To specify whether to check for properly signed certificates");

    }

    /**
     * Print an error message followed by the usage statement.
     *
     * @param errorMessage the error message
     */
    void usage(String errorMessage) {
        if (errorMessage != null) {
            System.err.println("Error: " + errorMessage);
        }
        usage(true);
    }

    /**
     * Parse the command line arguments.
     */
    void parseArguments(String[] args) {
        CommandLineParser parser = new BasicParser();
        CommandLine commandLine = null;
        try {
            commandLine = parser.parse(options, args);
        } catch (org.apache.commons.cli.ParseException e) {
            System.err.println("Could not parse arguments correctly.");
            usage(true);
        }

        if (commandLine.hasOption(ConfigurableOptions.HELP)) {
            usage(false);
            help = true;
            return;
        }

        useSsl = commandLine.hasOption(ConfigurableOptions.SSL);
        disableCertCheck = commandLine.hasOption(ConfigurableOptions.DISABLE_CERT_CHECK);

        host = commandLine.getOptionValue(ConfigurableOptions.HOST, "localhost");

        if (commandLine.hasOption(ConfigurableOptions.PORT)) {
            try {
                port = Integer.parseInt(commandLine.getOptionValue(ConfigurableOptions.PORT));
            } catch (NumberFormatException e) {
                usage("--port must be an integer value");
            }
        } else {
            port = (useSsl) ? SSL_PORT : NO_SSL_PORT;
        }

        username = commandLine.getOptionValue(ConfigurableOptions.USER_NAME, System.getProperty("user.name"));

        if (username == null) {
            usage("Specify --username to login as a user.");
        }

        password = commandLine.getOptionValue(ConfigurableOptions.PASSWORD);

        if (password == null) {
            Console console = System.console();
            password = String.valueOf(console.readPassword(String.format("Password for %s: ", username)));
        }

        if (commandLine.hasOption(ConfigurableOptions.FILE)) {
            filePath = commandLine.getOptionValue(ConfigurableOptions.FILE);
        } else {
            usage("Specify --file to save to file");
        }

        if (commandLine.getArgs().length > 0) {
            usage(true);
        }
    }

    private String getAuthenticationServerAddress() throws IOException {
        HttpClient client = new DefaultHttpClient();
        String baseUrl = "http";
        // ssl settings
        if (useSsl) {
            baseUrl = "https";
            if (disableCertCheck) {
                try {
                    client = getHTTPClient();
                } catch (Exception e) {
                    errorDebugExit("Could not create HTTP Client with SSL enabled", e);
                    System.exit(1);
                }
            }
        }
        HttpGet get = new HttpGet(String.format("%s://%s:%d", baseUrl, host, port));
        HttpResponse response = client.execute(get);

        if (response.getStatusLine().getStatusCode() == 200) {
            System.out.println("Security is not enabled. No Access Token may be acquired");
            System.exit(0);
        }
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        ByteStreams.copy(response.getEntity().getContent(), bos);
        String responseBody = bos.toString("UTF-8");
        bos.close();
        JsonParser parser = new JsonParser();
        JsonObject responseJson = (JsonObject) parser.parse(responseBody);
        JsonArray addresses = responseJson.get("auth_uri").getAsJsonArray();
        ArrayList<String> list = new ArrayList<String>();
        for (JsonElement e : addresses) {
            list.add(e.getAsString());
        }
        return list.get(new Random().nextInt(list.size()));
    }

    protected DefaultHttpClient getHTTPClient() throws Exception {
        SSLContext sslContext = SSLContext.getInstance("SSL");

        // set up a TrustManager that trusts everything
        sslContext.init(null, new TrustManager[] { new X509TrustManager() {
            @Override
            public java.security.cert.X509Certificate[] getAcceptedIssuers() {
                return null;
            }

            @Override
            public void checkClientTrusted(java.security.cert.X509Certificate[] x509Certificates, String s)
                    throws CertificateException {
                //
            }

            @Override
            public void checkServerTrusted(java.security.cert.X509Certificate[] x509Certificates, String s)
                    throws CertificateException {
                //
            }

        } }, new SecureRandom());

        SSLSocketFactory sf = new SSLSocketFactory(sslContext);
        Scheme httpsScheme = new Scheme("https", 10101, sf);
        SchemeRegistry schemeRegistry = new SchemeRegistry();
        schemeRegistry.register(httpsScheme);

        // apache HttpClient version >4.2 should use BasicClientConnectionManager
        ClientConnectionManager cm = new BasicClientConnectionManager(schemeRegistry);
        return new DefaultHttpClient(cm);
    }

    public String execute0(String[] args) {
        buildOptions();
        parseArguments(args);
        if (help) {
            return "";
        }

        String baseUrl;
        try {
            baseUrl = getAuthenticationServerAddress();
        } catch (IOException e) {
            errorDebugExit("Could not find Authentication service to connect to.", e);
            return null;
        }

        System.out.println(String.format("Authentication server address is: %s", baseUrl));
        System.out.println(String.format("Authenticating as: %s", username));

        HttpClient client = new DefaultHttpClient();
        if (useSsl && disableCertCheck) {
            try {
                client = getHTTPClient();
            } catch (Exception e) {
                errorDebugExit("Could not create HTTP Client with SSL enabled", e);
                return null;
            }
        }

        // construct the full URL and verify its well-formedness
        try {
            URI.create(baseUrl);
        } catch (IllegalArgumentException e) {
            System.err.println(
                    "Invalid base URL '" + baseUrl + "'. Check the validity of --host or --port arguments.");
            return null;
        }

        HttpGet get = new HttpGet(baseUrl);
        String auth = Base64.encodeBase64String(String.format("%s:%s", username, password).getBytes());
        auth = auth.replaceAll("(\r|\n)", "");
        get.addHeader("Authorization", String.format("Basic %s", auth));
        HttpResponse response;
        try {
            response = client.execute(get);
        } catch (IOException e) {
            errorDebugExit("Error sending HTTP request: " + e.getMessage(), e);
            return null;
        }
        if (response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
            System.out.println(
                    "Authentication failed. Please ensure that the username and password provided are correct.");
            return null;
        } else {
            try {
                ByteArrayOutputStream bos = new ByteArrayOutputStream();
                ByteStreams.copy(response.getEntity().getContent(), bos);
                String responseBody = bos.toString("UTF-8");
                bos.close();
                JsonParser parser = new JsonParser();
                JsonObject responseJson = (JsonObject) parser.parse(responseBody);
                String token = responseJson.get(ExternalAuthenticationServer.ResponseFields.ACCESS_TOKEN)
                        .getAsString();

                PrintWriter writer = new PrintWriter(filePath, "UTF-8");
                writer.write(token);
                writer.close();
                System.out.println("Your Access Token is:" + token);
                System.out.println("Access Token saved to file " + filePath);
            } catch (Exception e) {
                System.err.println("Could not parse response contents.");
                e.printStackTrace(System.err);
                return null;
            }
        }
        client.getConnectionManager().shutdown();
        return "OK.";
    }

    /**
     * Prints the message, if debug is enabled, prints the error message stacktrace, then exits
     * @param message
     * @param e
     */
    private void errorDebugExit(String message, Exception e) {
        System.err.println(message);
        if (debug) {
            e.printStackTrace();
        }
        System.exit(1);
    }

    public String execute(String[] args) {
        try {
            return execute0(args);
        } catch (UsageException e) {
            if (debug) {
                System.err.println("Exception for arguments: " + Arrays.toString(args) + ". Exception: " + e);
                e.printStackTrace(System.err);
            }
        }
        return null;
    }

    public static void main(String[] args) {
        AccessTokenClient accessTokenClient = new AccessTokenClient();
        String value = accessTokenClient.execute(args);
        if (value == null) {
            System.exit(1);
        }
    }
}