com.noctarius.hazelcast.aws.HazelcastAwsDiscoveryStrategy.java Source code

Java tutorial

Introduction

Here is the source code for com.noctarius.hazelcast.aws.HazelcastAwsDiscoveryStrategy.java

Source

/*
 * Copyright (c) 2015, Christoph Engelbert (aka noctarius) and
 * contributors. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.noctarius.hazelcast.aws;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.Protocol;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSCredentialsProviderChain;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.EnvironmentVariableCredentialsProvider;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.auth.SystemPropertiesCredentialsProvider;
import com.amazonaws.internal.StaticCredentialsProvider;
import com.amazonaws.services.ec2.AmazonEC2Client;
import com.amazonaws.services.ec2.model.DescribeInstancesRequest;
import com.amazonaws.services.ec2.model.DescribeInstancesResult;
import com.amazonaws.services.ec2.model.DescribeRegionsResult;
import com.amazonaws.services.ec2.model.Filter;
import com.amazonaws.services.ec2.model.Instance;
import com.amazonaws.services.ec2.model.Region;
import com.amazonaws.services.ec2.model.Reservation;
import com.amazonaws.services.ec2.model.Tag;
import com.hazelcast.config.NetworkConfig;
import com.hazelcast.config.properties.PropertyDefinition;
import com.hazelcast.logging.ILogger;
import com.hazelcast.nio.Address;
import com.hazelcast.spi.discovery.DiscoveryNode;
import com.hazelcast.spi.discovery.DiscoveryStrategy;
import com.hazelcast.spi.discovery.SimpleDiscoveryNode;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Hazelcast {@link DiscoveryStrategy} that finds cluster members by querying the Amazon EC2
 * DescribeInstances API.
 * <p>
 * Only instances in the {@code running} or {@code pending} state are returned. The query can
 * additionally be narrowed by an EC2 tag ({@code AwsProperties.TAG_KEY} / {@code AwsProperties.TAG_VALUE})
 * and by a security group name ({@code AwsProperties.SECURITY_GROUP_NAME}). The EC2 tags of each instance
 * become the properties of the discovered node; a {@code hazelcast-service-port} tag, if present and
 * numeric, overrides the member port (otherwise {@link NetworkConfig#DEFAULT_PORT} is used).
 * <p>
 * The EC2 client is created eagerly in the constructor. For any region other than {@code us-east-1}
 * this issues a DescribeRegions call to resolve that region's endpoint.
 */
public class HazelcastAwsDiscoveryStrategy implements DiscoveryStrategy {

    private static final String EC2_US_EAST_1_REGION = "us-east-1";
    private static final String EC2_US_EAST_1_ENDPOINT = "ec2.us-east-1.amazonaws.com";

    private static final String EC2_INSTANCE_STATE_NAME = "instance-state-name";
    private static final String EC2_SECURITY_GROUP_NAME = "group-name";

    private static final String HAZELCAST_SERVICE_PORT = "hazelcast-service-port";

    private final ILogger logger;
    private final AmazonEC2Client client;
    private final Map<String, Comparable> properties;

    /**
     * Creates the strategy and immediately builds the EC2 client (proxy, credentials, endpoint).
     *
     * @param logger     logger used for non-fatal discovery problems
     * @param properties configuration values keyed by {@code PropertyDefinition.key()}; may be {@code null}
     */
    public HazelcastAwsDiscoveryStrategy(ILogger logger, Map<String, Comparable> properties) {
        this.logger = logger;
        this.properties = properties;
        this.client = buildAmazonEC2Client();
    }

    public void start() {
    }

    /**
     * Queries EC2 for all instances matching the configured filters and maps each to a
     * {@link DiscoveryNode}. Instances without a resolvable private address are skipped.
     *
     * @return the discovered nodes, never {@code null}
     */
    public Iterable<DiscoveryNode> discoverNodes() {
        DescribeInstancesRequest instancesRequest = buildInstanceRequest();
        DescribeInstancesResult instances = client.describeInstances(instancesRequest);

        List<DiscoveryNode> discoveryNodes = new ArrayList<DiscoveryNode>();
        for (Reservation reservation : instances.getReservations()) {
            for (Instance instance : reservation.getInstances()) {
                DiscoveryNode discoveryNode = buildDiscoveryNode(instance);
                if (discoveryNode != null) {
                    discoveryNodes.add(discoveryNode);
                }
            }
        }

        return discoveryNodes;
    }

    /**
     * Releases the resources held by the EC2 client.
     */
    public void destroy() {
        client.shutdown();
    }

    /**
     * Builds the discovery node for a single EC2 instance.
     * <p>
     * An instance without a public IP (common inside a VPC) advertises its private address as its
     * public address too, instead of failing on a {@code null} address.
     *
     * @param instance the EC2 instance
     * @return the node, or {@code null} if the instance has no resolvable private address
     */
    private DiscoveryNode buildDiscoveryNode(Instance instance) {
        InetAddress privateAddress = mapAddress(instance.getPrivateIpAddress());
        if (privateAddress == null) {
            logger.warning("Skipping EC2 instance '" + instance.getInstanceId()
                    + "' without a resolvable private address");
            return null;
        }
        InetAddress publicAddress = mapAddress(instance.getPublicIpAddress());

        Map<String, Object> nodeProperties = mapTagsToProperties(instance);
        int port = getServicePort(nodeProperties);

        Address privateAddressInstance = new Address(privateAddress, port);
        // Address does not accept a null InetAddress, so fall back to the private address
        Address publicAddressInstance = publicAddress == null
                ? privateAddressInstance
                : new Address(publicAddress, port);

        return new SimpleDiscoveryNode(privateAddressInstance, publicAddressInstance, nodeProperties);
    }

    /**
     * Resolves a textual IP address.
     *
     * @return the resolved address, or {@code null} if {@code address} is {@code null} or unresolvable
     */
    private InetAddress mapAddress(String address) {
        if (address == null) {
            return null;
        }
        try {
            return InetAddress.getByName(address);
        } catch (UnknownHostException e) {
            logger.warning("Address '" + address + "' could not be resolved");
        }
        return null;
    }

    /**
     * Copies the EC2 tags of an instance into a mutable key/value map.
     */
    private Map<String, Object> mapTagsToProperties(Instance instance) {
        List<Tag> tags = instance.getTags();
        Map<String, Object> tagProperties = new HashMap<String, Object>(tags.size());
        for (Tag tag : tags) {
            tagProperties.put(tag.getKey(), tag.getValue());
        }
        return tagProperties;
    }

    /**
     * Reads the member port from the {@code hazelcast-service-port} tag.
     * <p>
     * A missing or non-numeric tag falls back to {@link NetworkConfig#DEFAULT_PORT}, so a single
     * mis-tagged instance does not break discovery of the whole cluster.
     */
    private int getServicePort(Map<String, Object> tagProperties) {
        Object servicePort = tagProperties.get(HAZELCAST_SERVICE_PORT);
        if (servicePort == null) {
            return NetworkConfig.DEFAULT_PORT;
        }
        try {
            return Integer.parseInt(servicePort.toString().trim());
        } catch (NumberFormatException e) {
            logger.warning("Invalid value '" + servicePort + "' for tag '" + HAZELCAST_SERVICE_PORT
                    + "', using default port " + NetworkConfig.DEFAULT_PORT);
            return NetworkConfig.DEFAULT_PORT;
        }
    }

    /**
     * Builds the DescribeInstances request: running/pending instances, optionally narrowed by tag
     * and security group.
     */
    private DescribeInstancesRequest buildInstanceRequest() {
        DescribeInstancesRequest instancesRequest = new DescribeInstancesRequest();

        configureFilter(instancesRequest, EC2_INSTANCE_STATE_NAME, "running", "pending");
        configureTag(instancesRequest);
        configureSecurityGroupName(instancesRequest);

        return instancesRequest;
    }

    private void configureSecurityGroupName(DescribeInstancesRequest instancesRequest) {
        String securityGroupName = getOrNull(AwsProperties.SECURITY_GROUP_NAME);
        if (securityGroupName != null) {
            configureFilter(instancesRequest, EC2_SECURITY_GROUP_NAME, securityGroupName);
        }
    }

    /**
     * Adds a {@code tag:<key>} filter, but only when both tag key and tag value are configured.
     */
    private void configureTag(DescribeInstancesRequest instancesRequest) {
        String tagKey = getOrNull(AwsProperties.TAG_KEY);
        String tagValue = getOrNull(AwsProperties.TAG_VALUE);
        if (tagKey != null && tagValue != null) {
            configureFilter(instancesRequest, "tag:" + tagKey, tagValue);
        }
    }

    /**
     * Adds one filter to the request. The request's {@code withFilters} is relied on to append the
     * filter to any filters already present.
     */
    private void configureFilter(DescribeInstancesRequest instancesRequest, String key, String... values) {
        instancesRequest.withFilters(new Filter(key).withValues(values));
    }

    /**
     * Creates the EC2 client using HTTPS, the optional proxy, the resolved credentials and the
     * configured region's endpoint.
     */
    private AmazonEC2Client buildAmazonEC2Client() {
        ClientConfiguration configuration = new ClientConfiguration();

        // Always set HTTPS as protocol, security first
        configuration.setProtocol(Protocol.HTTPS);

        // Configure proxy configuration
        configureProxy(configuration);

        // Configure authentication
        AWSCredentialsProvider credentialsProvider = buildCredentialsProvider();

        // Create WS client
        AmazonEC2Client ec2Client = new AmazonEC2Client(credentialsProvider, configuration);

        // Configure Amazon EC2 WS endpoint
        configureEndpoint(ec2Client);

        return ec2Client;
    }

    /**
     * Applies proxy settings when a proxy host is configured. The port defaults to 80. Username
     * and password are passed through as-is, possibly {@code null}.
     */
    private void configureProxy(ClientConfiguration configuration) {
        String proxyHost = getOrNull(AwsProperties.PROXY_HOST);
        if (proxyHost == null) {
            return;
        }

        int proxyPort = getOrDefault(AwsProperties.PROXY_PORT, 80);
        String proxyUsername = getOrNull(AwsProperties.PROXY_USERNAME);
        String proxyPassword = getOrNull(AwsProperties.PROXY_PASSWORD);

        configuration.withProxyHost(proxyHost).setProxyPort(proxyPort);
        configuration.withProxyUsername(proxyUsername);
        configuration.withProxyPassword(proxyPassword);
    }

    /**
     * Chooses the credentials source.
     * <p>
     * If neither key is configured, the environment, system properties and instance profile are
     * tried in that order. If both are configured, they are used as static credentials.
     *
     * @throws IllegalArgumentException if only one of access key / secret key is configured
     */
    private AWSCredentialsProvider buildCredentialsProvider() {
        String accessKey = getOrNull(AwsProperties.ACCESS_KEY);
        String secretKey = getOrNull(AwsProperties.SECRET_KEY);

        if (accessKey == null && secretKey == null) {
            return new AWSCredentialsProviderChain(new EnvironmentVariableCredentialsProvider(),
                    new SystemPropertiesCredentialsProvider(), new InstanceProfileCredentialsProvider());
        }

        if (accessKey == null || secretKey == null) {
            throw new IllegalArgumentException(
                    "Both AWS access key and secret key must be configured, or neither of them");
        }

        return new AWSCredentialsProviderChain(
                new StaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)));
    }

    /**
     * Retrieves all EC2 regions visible to the client.
     *
     * @return an unmodifiable map from region name to endpoint host
     */
    private Map<String, String> buildRegionLookup(AmazonEC2Client ec2Client) {
        Map<String, String> regionLookup = new HashMap<String, String>();

        DescribeRegionsResult regionsResult = ec2Client.describeRegions();
        for (Region region : regionsResult.getRegions()) {
            regionLookup.put(region.getRegionName(), region.getEndpoint());
        }
        return Collections.unmodifiableMap(regionLookup);
    }

    /**
     * Points the client at the configured region's endpoint (default {@code us-east-1}).
     * <p>
     * The {@code us-east-1} endpoint is set first because it is needed for the DescribeRegions call
     * that resolves any other region's endpoint.
     *
     * @throws IllegalStateException if the configured region is not returned by DescribeRegions
     */
    private void configureEndpoint(AmazonEC2Client ec2Client) {
        String region = getOrDefault(AwsProperties.REGION, EC2_US_EAST_1_REGION);

        // Set default endpoint for first request
        ec2Client.setEndpoint(EC2_US_EAST_1_ENDPOINT);

        // Only non-default regions need a lookup; us-east-1 is already configured above
        if (!EC2_US_EAST_1_REGION.equals(region)) {
            Map<String, String> regionLookup = buildRegionLookup(ec2Client);
            String endpoint = regionLookup.get(region);
            if (endpoint == null) {
                throw new IllegalStateException(
                        "Amazon EC2 region '" + region + "' couldn't be resolved to an endpoint");
            }
            ec2Client.setEndpoint(endpoint);
        }
    }

    /**
     * Returns the configured value of {@code property}, or {@code null} if it is not set.
     */
    private <T extends Comparable> T getOrNull(PropertyDefinition property) {
        return getOrDefault(property, null);
    }

    /**
     * Returns the configured value of {@code property}, or {@code defaultValue} if the property,
     * the property map or the value is {@code null}.
     * <p>
     * The value is cast unchecked to the caller's expected type. A misconfigured value type will
     * surface as a {@link ClassCastException} at the call site.
     */
    @SuppressWarnings("unchecked")
    private <T extends Comparable> T getOrDefault(PropertyDefinition property, T defaultValue) {

        if (properties == null || property == null) {
            return defaultValue;
        }

        Comparable value = properties.get(property.key());
        if (value == null) {
            return defaultValue;
        }

        return (T) value;
    }
}