com.netflix.amazoncomponents.security.AmazonClientProvider.java Source code

Java tutorial

Introduction

Here is the source code for com.netflix.amazoncomponents.security.AmazonClientProvider.java

Source

/*
 * Copyright 2014 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.netflix.amazoncomponents.security;

import com.amazonaws.AmazonServiceException;
import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProviderChain;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.autoscaling.AmazonAutoScaling;
import com.amazonaws.services.autoscaling.AmazonAutoScalingClient;
import com.amazonaws.services.autoscaling.model.*;
import com.amazonaws.services.cloudwatch.AmazonCloudWatch;
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient;
import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.AmazonEC2Client;
import com.amazonaws.services.ec2.model.*;
import com.amazonaws.services.ec2.model.Instance;
import com.amazonaws.services.elasticloadbalancing.AmazonElasticLoadBalancing;
import com.amazonaws.services.elasticloadbalancing.AmazonElasticLoadBalancingClient;
import com.amazonaws.services.elasticloadbalancing.model.DescribeLoadBalancersRequest;
import com.amazonaws.services.elasticloadbalancing.model.DescribeLoadBalancersResult;
import com.amazonaws.services.elasticloadbalancing.model.LoadBalancerDescription;
import com.amazonaws.services.route53.AmazonRoute53;
import com.amazonaws.services.route53.AmazonRoute53Client;
import com.amazonaws.services.simpleworkflow.AmazonSimpleWorkflow;
import com.amazonaws.services.simpleworkflow.AmazonSimpleWorkflowClient;
import com.amazonaws.services.sns.AmazonSNS;
import com.amazonaws.services.sns.AmazonSNSClient;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.netflix.amazoncomponents.data.AmazonObjectMapper;
import com.netflix.amazoncomponents.model.RetryCallback;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.entity.ContentType;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.util.EntityUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Provider of Amazon Clients.
 *
 * @author Dan Woods
 */
public class AmazonClientProvider {

    private HttpClient httpClient;
    private ObjectMapper objectMapper;
    private RetryCallback retryCallback = new RetryCallback() {
        @Override
        public boolean doCall(Throwable t, int attempts) {
            if (attempts < 5) {
                try {
                    Thread.sleep(150);
                } catch (InterruptedException IGNORE) {
                    //
                }
                return true;
            } else {
                return false;
            }
        }
    };

    public AmazonClientProvider() {
        this((HttpClient) null);
    }

    public AmazonClientProvider(HttpClient httpClient) {
        this(httpClient, new AmazonObjectMapper());
    }

    public AmazonClientProvider(ObjectMapper objectMapper) {
        this(null, objectMapper);
    }

    public AmazonClientProvider(HttpClient httpClient, ObjectMapper objectMapper) {
        this.httpClient = httpClient;
        this.objectMapper = objectMapper;
    }

    public void setRetryCallback(RetryCallback retryCallback) {
        this.retryCallback = retryCallback;
    }

    public AmazonEC2 getAmazonEC2(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonEC2.class, AmazonEC2Client.class, amazonCredentials, region);

    }

    public AmazonAutoScaling getAutoScaling(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonAutoScaling.class, AmazonAutoScalingClient.class, amazonCredentials,
                region);
    }

    public AmazonRoute53 getAmazonRoute53(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonRoute53.class, AmazonRoute53Client.class, amazonCredentials, region);
    }

    public AmazonElasticLoadBalancing getAmazonElasticLoadBalancing(AmazonCredentials amazonCredentials,
            String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonElasticLoadBalancing.class, AmazonElasticLoadBalancingClient.class,
                amazonCredentials, region);
    }

    public AmazonSimpleWorkflow getAmazonSimpleWorkflow(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonSimpleWorkflow.class, AmazonSimpleWorkflowClient.class, amazonCredentials,
                region);
    }

    public AmazonSimpleWorkflow getAmazonSimpleWorkflow(AWSCredentialsProviderChain providerChain, String region) {
        if (providerChain == null) {
            throw new IllegalArgumentException("Provider chain cannot be null");
        }
        AmazonSimpleWorkflowClient client = new AmazonSimpleWorkflowClient(providerChain);
        if (region != null && region.length() > 0) {
            client.setRegion(Region.getRegion(Regions.fromName(region)));
        }
        return client;
    }

    public AmazonCloudWatch getAmazonCloudWatch(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonCloudWatch.class, AmazonCloudWatchClient.class, amazonCredentials,
                region);
    }

    public AmazonSNS getAmazonSNS(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonSNS.class, AmazonSNSClient.class, amazonCredentials, region);
    }

    public AmazonCloudWatch getCloudWatch(AmazonCredentials amazonCredentials, String region) {
        checkCredentials(amazonCredentials);
        return getThrottlingHandler(AmazonCloudWatch.class, AmazonCloudWatchClient.class, amazonCredentials,
                region);
    }

    protected <T extends AmazonWebServiceClient, U> U getThrottlingHandler(Class<U> interfaceKlazz, Class<T> impl,
            AmazonCredentials amazonCredentials, String region) {
        try {
            T delegate = getClient(impl, amazonCredentials, region);
            U client = (U) Proxy.newProxyInstance(getClass().getClassLoader(), new Class[] { interfaceKlazz },
                    new ThrottledAmazonClientInvocationHandler(delegate, retryCallback));
            if (!amazonCredentials.isEddaConfigured()) {
                return client;
            } else {
                return (U) Proxy.newProxyInstance(getClass().getClassLoader(), new Class[] { interfaceKlazz },
                        getInvocationHandler(client, delegate.getServiceName(), region, amazonCredentials));
            }
        } catch (Exception e) {
            throw new RuntimeException("Instantiation of client implementation failed!", e);
        }
    }

    protected <T extends AmazonWebServiceClient> T getClient(Class<T> impl, AmazonCredentials amazonCredentials,
            String region) throws IllegalAccessException, InvocationTargetException, InstantiationException,
            NoSuchMethodException {
        Constructor<T> constructor = impl.getConstructor(AWSCredentials.class);
        T delegate = constructor.newInstance(amazonCredentials.getCredentials());
        if (region != null && region.length() > 0) {
            delegate.setRegion(Region.getRegion(Regions.fromName(region)));
        }
        return delegate;
    }

    protected GeneralAmazonClientInvocationHandler getInvocationHandler(Object client, String serviceName,
            String region, AmazonCredentials amazonCredentials) {
        return new GeneralAmazonClientInvocationHandler(client, serviceName,
                String.format(amazonCredentials.getEdda(), region),
                this.httpClient == null ? new DefaultHttpClient() : this.httpClient, objectMapper);
    }

    private static void checkCredentials(AmazonCredentials amazonCredentials) {
        if (amazonCredentials == null) {
            throw new IllegalArgumentException("Credentials cannot be null");
        }
    }

    public static class ThrottledAmazonClientInvocationHandler implements InvocationHandler {

        private AmazonWebServiceClient delegate;
        private RetryCallback retryCallback;

        ThrottledAmazonClientInvocationHandler(AmazonWebServiceClient delegate, RetryCallback retryCallback) {
            this.delegate = delegate;
            this.retryCallback = retryCallback;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            return invoke(proxy, method, args, 1);
        }

        private Object invoke(Object proxy, Method method, Object[] args, int attempts) throws Throwable {
            try {
                return method.invoke(delegate, args);
            } catch (InvocationTargetException ex) {
                Throwable t = ex.getTargetException();
                if (t instanceof AmazonServiceException) {
                    AmazonServiceException e = (AmazonServiceException) t;
                    if ("RequestLimitExceeded".equals(e.getErrorCode())) {
                        boolean tryAgain = retryCallback.doCall(e, attempts);
                        if (tryAgain) {
                            return invoke(proxy, method, args, ++attempts);
                        }
                    }
                }
                throw t;
            }
        }
    }

    public static class GeneralAmazonClientInvocationHandler implements InvocationHandler {
        private final String edda;
        private final HttpClient httpClient;
        private final Object delegate;
        private final String serviceName;
        private final ObjectMapper objectMapper;

        public GeneralAmazonClientInvocationHandler(Object delegate, String serviceName, String edda,
                HttpClient httpClient, ObjectMapper objectMapper) {
            this.edda = edda;
            this.httpClient = httpClient;
            this.objectMapper = objectMapper;
            this.delegate = delegate;
            this.serviceName = serviceName;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            try {
                Method thisMethod = this.getClass().getMethod(method.getName(),
                        args != null && args.length > 0 ? getClassArgs(args) : new Class[0]);
                return thisMethod.invoke(this, args);
            } catch (NoSuchMethodException e) {
                return method.invoke(delegate, args);
            }
        }

        static Class[] getClassArgs(Object[] args) {
            List<Class> classes = new ArrayList<>();
            for (Object object : args) {
                classes.add(object.getClass());
            }
            return classes.toArray(new Class[classes.size()]);
        }

        public DescribeLaunchConfigurationsResult describeLaunchConfigurations(
                DescribeLaunchConfigurationsRequest request) {
            return new DescribeLaunchConfigurationsResult()
                    .withLaunchConfigurations(describe(request, "launchConfigurationNames", "launchConfigurations",
                            LaunchConfiguration.class, new TypeReference<List<LaunchConfiguration>>() {
                            }));
        }

        public DescribeLaunchConfigurationsResult describeLaunchConfigurations() {
            return describeLaunchConfigurations(null);
        }

        public DescribeSecurityGroupsResult describeSecurityGroups(DescribeSecurityGroupsRequest request) {
            return new DescribeSecurityGroupsResult().withSecurityGroups(describe(request, "groupIds",
                    "securityGroups", SecurityGroup.class, new TypeReference<List<SecurityGroup>>() {
                    }));
        }

        public DescribeSecurityGroupsResult describeSecurityGroups() {
            return describeSecurityGroups(null);
        }

        public DescribeSubnetsResult describeSubnets(DescribeSubnetsRequest request) {
            return new DescribeSubnetsResult().withSubnets(
                    describe(request, "subnetIds", "subnets", Subnet.class, new TypeReference<List<Subnet>>() {
                    }));
        }

        public DescribeSubnetsResult describeSubnets() {
            return describeSubnets(null);
        }

        public DescribeAutoScalingGroupsResult describeAutoScalingGroups(DescribeAutoScalingGroupsRequest request) {
            return new DescribeAutoScalingGroupsResult()
                    .withAutoScalingGroups(describe(request, "autoScalingGroupNames", "autoScalingGroups",
                            AutoScalingGroup.class, new TypeReference<List<AutoScalingGroup>>() {
                            }));
        }

        public DescribeAutoScalingGroupsResult describeAutoScalingGroups() {
            return describeAutoScalingGroups(null);
        }

        public DescribeInstancesResult describeInstances() {
            return describeInstances(null);
        }

        public DescribeImagesResult describeImages() {
            return describeImages(null);
        }

        public DescribeImagesResult describeImages(DescribeImagesRequest request) {
            return new DescribeImagesResult().withImages(
                    describe(request, "imageIds", "images", Image.class, new TypeReference<List<Image>>() {
                    }));
        }

        public DescribeInstancesResult describeInstances(DescribeInstancesRequest request) {
            List<Instance> instances = describe(request, "instanceIds", "../view/instances", Instance.class,
                    new TypeReference<List<Instance>>() {
                    });
            Reservation reservation = new Reservation().withReservationId("1234").withInstances(instances);
            return new DescribeInstancesResult().withReservations(reservation);
        }

        public DescribeLoadBalancersResult describeLoadBalancers() {
            return describeLoadBalancers(null);
        }

        public DescribeLoadBalancersResult describeLoadBalancers(DescribeLoadBalancersRequest request) {
            List<LoadBalancerDescription> loadBalancerDescriptions = describe(request, "loadBalancerNames",
                    "loadBalancers", LoadBalancerDescription.class,
                    new TypeReference<List<LoadBalancerDescription>>() {
                    });
            return new DescribeLoadBalancersResult().withLoadBalancerDescriptions(loadBalancerDescriptions);
        }

        public <T> T describe(Object request, String idKey, final String object, final Class singleType,
                TypeReference<T> collectionType) {
            try {
                if (request != null) {
                    try {
                        Field field = request.getClass().getDeclaredField(idKey);
                        field.setAccessible(true);
                        List collection = (List) field.get(request);
                        List results = new ArrayList<>();
                        if (collection != null && collection.size() > 0) {
                            for (Object id : collection) {
                                if (id instanceof String) {
                                    results.add(objectMapper.readValue(getJson(object, (String) id), singleType));
                                }
                            }
                        } else {
                            results.addAll(
                                    (Collection) objectMapper.readValue(getJson(object, null), collectionType));
                        }
                        return (T) results;
                    } catch (NoSuchFieldException | IllegalAccessException | IOException e) {
                        throw new RuntimeException(e);
                    }
                } else {
                    try {
                        String json = getJson(object, null);
                        if (json == null) {
                            throw new RuntimeException(String.format(
                                    "Could not retrieve JSON Data from Edda host (%s) for object (%s).", edda,
                                    object));
                        }
                        return (T) objectMapper.readValue(json, collectionType);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                }
            } catch (Exception e) {
                AmazonServiceException ex = new AmazonServiceException(
                        "400 Bad Request -- Edda could not find one of the managed objects requested.", e);
                ex.setStatusCode(400);
                ex.setServiceName(serviceName);
                ex.setErrorType(AmazonServiceException.ErrorType.Unknown);
                throw ex;
            }

        }

        private String getJson(String objectName, String key) throws IOException {
            String url = key != null ? edda + "/REST/v2/aws/" + objectName + "/" + key
                    : edda + "/REST/v2/aws/" + objectName + ";_expand";
            HttpGet get = new HttpGet(url);
            HttpResponse response = httpClient.execute(get);
            HttpEntity entity = response.getEntity();
            if (response.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
                return null;
            } else {
                String result = getStringFromInputStream(entity.getContent());
                EntityUtils.consume(entity);
                return result;
            }
        }

        private static String getStringFromInputStream(InputStream is) throws IOException {
            StringBuilder sb = new StringBuilder();
            String line;
            try (BufferedReader br = new BufferedReader(new InputStreamReader(is))) {
                while ((line = br.readLine()) != null) {
                    sb.append(line);
                }
            }
            return sb.toString();
        }

    }
}