com.remediatetheflag.global.utils.AWSHelper.java Source code

Java tutorial

Introduction

Here is the source code for com.remediatetheflag.global.utils.AWSHelper.java

Source

/*
 *  
 * REMEDIATE THE FLAG
 * Copyright 2018 - Andrea Scaduto 
 * remediatetheflag@gmail.com
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
 */
package com.remediatetheflag.global.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.cloudwatch.AmazonCloudWatch;
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder;
import com.amazonaws.services.cloudwatch.model.Datapoint;
import com.amazonaws.services.cloudwatch.model.Dimension;
import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsRequest;
import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsResult;
import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
import com.amazonaws.services.ec2.model.DescribeInstancesRequest;
import com.amazonaws.services.ec2.model.DescribeInstancesResult;
import com.amazonaws.services.ecs.AmazonECS;
import com.amazonaws.services.ecs.AmazonECSClientBuilder;
import com.amazonaws.services.ecs.model.ContainerDefinition;
import com.amazonaws.services.ecs.model.ContainerOverride;
import com.amazonaws.services.ecs.model.DeregisterTaskDefinitionRequest;
import com.amazonaws.services.ecs.model.DeregisterTaskDefinitionResult;
import com.amazonaws.services.ecs.model.DescribeClustersRequest;
import com.amazonaws.services.ecs.model.DescribeClustersResult;
import com.amazonaws.services.ecs.model.DescribeContainerInstancesRequest;
import com.amazonaws.services.ecs.model.DescribeContainerInstancesResult;
import com.amazonaws.services.ecs.model.DescribeTasksRequest;
import com.amazonaws.services.ecs.model.DescribeTasksResult;
import com.amazonaws.services.ecs.model.Failure;
import com.amazonaws.services.ecs.model.KernelCapabilities;
import com.amazonaws.services.ecs.model.KeyValuePair;
import com.amazonaws.services.ecs.model.LinuxParameters;
import com.amazonaws.services.ecs.model.ListTasksRequest;
import com.amazonaws.services.ecs.model.ListTasksResult;
import com.amazonaws.services.ecs.model.LogConfiguration;
import com.amazonaws.services.ecs.model.NetworkBinding;
import com.amazonaws.services.ecs.model.PortMapping;
import com.amazonaws.services.ecs.model.RegisterTaskDefinitionRequest;
import com.amazonaws.services.ecs.model.RegisterTaskDefinitionResult;
import com.amazonaws.services.ecs.model.RunTaskRequest;
import com.amazonaws.services.ecs.model.RunTaskResult;
import com.amazonaws.services.ecs.model.StopTaskRequest;
import com.amazonaws.services.ecs.model.Task;
import com.amazonaws.services.ecs.model.TaskOverride;
import com.amazonaws.services.logs.AWSLogs;
import com.amazonaws.services.logs.AWSLogsClientBuilder;
import com.amazonaws.services.logs.model.CreateLogGroupRequest;
import com.amazonaws.services.logs.model.ResourceAlreadyExistsException;
import com.remediatetheflag.global.model.RTFECSContainerTask;
import com.remediatetheflag.global.model.RTFECSTaskDefinition;
import com.remediatetheflag.global.model.RTFInstanceReservation;
import com.remediatetheflag.global.model.User;

public class AWSHelper {

    private static Logger logger = LoggerFactory.getLogger(AWSHelper.class);

    /**
     * Launches an exercise instance by delegating to the supplied strategy.
     *
     * @param strategy the launch strategy to execute
     * @return the reservation returned by {@code strategy.launch()}
     */
    public RTFInstanceReservation createRTFInstance(LaunchStrategy strategy) {
        final RTFInstanceReservation launched = strategy.launch();
        return launched;
    }

    /**
     * Stops the ECS task identified by the given ARN on the exercises cluster.
     * <p>
     * The region is read from the fourth colon-separated field of the ARN
     * ({@code arn:partition:service:region:...}). This is a best-effort
     * operation: a null or malformed ARN and any failure of the stop request
     * are logged and not propagated to the caller.
     *
     * @param taskArn ARN of the task to stop
     */
    public void terminateTask(String taskArn) {
        String[] arnFields = (taskArn == null) ? new String[0] : taskArn.split(":");
        if (arnFields.length < 4 || arnFields[3].isEmpty()) {
            // Without a region we cannot build a client; report instead of throwing.
            logger.error("Error terminating task " + taskArn + ": cannot extract region from ARN");
            return;
        }
        AmazonECS client = null;
        try {
            client = AmazonECSClientBuilder.standard().withRegion(arnFields[3])
                    .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
            StopTaskRequest stopTaskRequest = new StopTaskRequest().withCluster(RTFConfig.getExercisesCluster())
                    .withTask(taskArn);
            client.stopTask(stopTaskRequest);
        } catch (Exception e) {
            // Pass the exception itself so the stack trace is not lost.
            logger.error("Error terminating task " + e.getMessage(), e);
        } finally {
            // A client is created per call, so release the resources it holds once done.
            if (client != null) {
                client.shutdown();
            }
        }
    }

    /**
     * Stops the ECS task referenced by the given container-task record on the
     * exercises cluster.
     * <p>
     * The client is built for the region stored on {@code ecsInstance}. A
     * failure of the stop request is logged (with stack trace) and not
     * propagated to the caller.
     *
     * @param ecsInstance record providing the region and the ARN of the task to stop
     */
    public void terminateTask(RTFECSContainerTask ecsInstance) {
        AmazonECS client = AmazonECSClientBuilder.standard().withRegion(ecsInstance.getRegion())
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
        try {
            StopTaskRequest stopTaskRequest = new StopTaskRequest().withCluster(RTFConfig.getExercisesCluster())
                    .withTask(ecsInstance.getTaskArn());
            client.stopTask(stopTaskRequest);
        } catch (Exception e) {
            // Pass the exception itself so the stack trace is not lost.
            logger.error("Error terminating task " + e.getMessage(), e);
        } finally {
            // A client is created per call, so release the resources it holds once done.
            client.shutdown();
        }
    }

    /**
     * Returns the number of container instances registered to the exercises
     * cluster in the given region.
     *
     * @param region region whose cluster is queried
     * @return the registered container instance count reported by ECS, or
     *         {@code 0} when the cluster is not returned or the request fails
     */
    public Integer getClusterContainerInstances(Region region) {
        AmazonECS client = AmazonECSClientBuilder.standard().withRegion(region.getName())
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
        DescribeClustersRequest request = new DescribeClustersRequest()
                .withClusters(RTFConfig.getExercisesCluster());
        logger.debug("Requesting number of cluster running instances for region " + region.getName() + " cluster "
                + RTFConfig.getExercisesCluster());
        try {
            DescribeClustersResult response = client.describeClusters(request);
            if (response.getClusters().isEmpty()) {
                // An unknown cluster yields an empty list rather than an exception.
                logger.error("Error getClusterContainerInstances for region " + region.getName() + " due to:\n"
                        + "cluster " + RTFConfig.getExercisesCluster() + " not found");
                return 0;
            }
            return response.getClusters().get(0).getRegisteredContainerInstancesCount();
        } catch (Exception e) {
            logger.error("Error getClusterContainerInstances for region " + region.getName() + " due to:\n"
                    + e.getMessage(), e);
            return 0;
        } finally {
            // A client is created per call, so release the resources it holds once done.
            client.shutdown();
        }
    }

    /**
     * Returns the most recent one-minute average of the {@code MemoryReservation}
     * CloudWatch metric ({@code AWS/ECS} namespace) for the exercises cluster
     * in the given region, looking at the last five minutes.
     *
     * @param region region whose cluster metric is queried
     * @return the average of the latest datapoint, or {@code 0.0} when no
     *         datapoint is available or the request fails
     */
    public Double getClusterMemoryReservation(Region region) {
        AmazonCloudWatch client = AmazonCloudWatchClientBuilder.standard().withRegion(region.getName())
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();

        Dimension dimension = new Dimension().withName("ClusterName").withValue(RTFConfig.getExercisesCluster());
        Date end = new Date();
        Calendar cal = Calendar.getInstance();
        cal.setTime(end);
        cal.add(Calendar.MINUTE, -5);
        GetMetricStatisticsRequest request = new GetMetricStatisticsRequest().withMetricName("MemoryReservation")
                .withDimensions(dimension).withPeriod(60).withStartTime(cal.getTime()).withEndTime(end)
                .withStatistics("Average").withNamespace("AWS/ECS");
        try {
            logger.debug("Requesting memory reservation for region " + region.getName() + " cluster "
                    + RTFConfig.getExercisesCluster());

            GetMetricStatisticsResult response = client.getMetricStatistics(request);
            // CloudWatch does not return datapoints in chronological order, so
            // pick the newest one explicitly instead of taking the first element.
            Datapoint latest = null;
            for (Datapoint candidate : response.getDatapoints()) {
                if (latest == null || candidate.getTimestamp().after(latest.getTimestamp())) {
                    latest = candidate;
                }
            }
            if (latest == null) {
                return 0.0;
            }
            return latest.getAverage();
        } catch (Exception e) {
            logger.error("Error getClusterMemoryReservation for region " + region.getName() + " due to:\n"
                    + e.getMessage(), e);
            return 0.0;
        } finally {
            // A client is created per call, so release the resources it holds once done.
            client.shutdown();
        }
    }

    /**
     * Registers an ECS task definition with a single essential container
     * described by {@code taskDef}.
     * <p>
     * The container exposes container ports 8080 and 3389 over TCP with host
     * port 0, gets the {@code NET_ADMIN} capability, uses bridge networking and
     * logs through the {@code awslogs} driver to
     * {@link Constants#EXERCISE_LOG_GROUP}, which is created first if needed.
     *
     * @param taskDef description of the task definition to register
     * @param user    user on whose behalf the definition is created (used for logging)
     * @return the ARN of the registered task definition, or {@code null} when
     *         registration fails
     */
    public String createECSTaskDefinition(RTFECSTaskDefinition taskDef, User user) {

        createExerciseLogGroup(taskDef.getRegion(), Constants.EXERCISE_LOG_GROUP);

        AmazonECS client = AmazonECSClientBuilder.standard().withRegion(taskDef.getRegion())
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();

        List<PortMapping> portMappings = new ArrayList<PortMapping>();
        portMappings.add(new PortMapping().withContainerPort(8080).withHostPort(0).withProtocol("tcp"));
        portMappings.add(new PortMapping().withContainerPort(3389).withHostPort(0).withProtocol("tcp"));

        KernelCapabilities capabilities = new KernelCapabilities();
        capabilities.setAdd(Arrays.asList("NET_ADMIN"));
        LinuxParameters linuxParameters = new LinuxParameters();
        linuxParameters.setCapabilities(capabilities);

        Map<String, String> options = new HashMap<String, String>();
        options.put("awslogs-group", Constants.EXERCISE_LOG_GROUP);
        options.put("awslogs-region", taskDef.getRegion().getName().toLowerCase());
        LogConfiguration logConfiguration = new LogConfiguration();
        logConfiguration.setLogDriver("awslogs");
        logConfiguration.setOptions(options);

        final ContainerDefinition def = new ContainerDefinition().withName(taskDef.getContainerName())
                .withImage(taskDef.getRepositoryImageUrl()).withMemoryReservation(taskDef.getSoftMemoryLimit())
                .withMemory(taskDef.getHardMemoryLimit()).withPortMappings(portMappings)
                .withLinuxParameters(linuxParameters).withLogConfiguration(logConfiguration).withEssential(true);

        RegisterTaskDefinitionRequest request = new RegisterTaskDefinitionRequest();
        request.setContainerDefinitions(Arrays.asList(def));
        request.setFamily(taskDef.getTaskDefinitionName());
        request.setNetworkMode("bridge");

        try {
            RegisterTaskDefinitionResult response = client.registerTaskDefinition(request);
            logger.debug("# ECS Task Definition " + taskDef.getTaskDefinitionName() + " created for user "
                    + user.getIdUser() + " in region " + taskDef.getRegion());
            return response.getTaskDefinition().getTaskDefinitionArn();
        } catch (Exception e) {
            // A failed registration is an error, not debug noise; keep the stack trace.
            logger.error(
                    "# ECS Task Definition " + taskDef.getTaskDefinitionName() + " COULD NOT BE created for user "
                            + user.getIdUser() + " in region " + taskDef.getRegion() + "\n" + e.getMessage(),
                    e);
            return null;
        } finally {
            // A client is created per call, so release the resources it holds once done.
            client.shutdown();
        }
    }

    /**
     * Deregisters the given ECS task definition in the given region.
     *
     * @param taskDefinitionArn ARN of the task definition to deregister
     * @param region            region in which the task definition is registered
     * @return {@code true} when the deregistration call answered with HTTP 200,
     *         {@code false} for any other status or when the call fails
     */
    public Boolean removeTaskDefinitionInRegion(String taskDefinitionArn, Regions region) {
        AmazonECS client = AmazonECSClientBuilder.standard().withRegion(region)
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
        try {
            DeregisterTaskDefinitionRequest request = new DeregisterTaskDefinitionRequest()
                    .withTaskDefinition(taskDefinitionArn);
            DeregisterTaskDefinitionResult result = client.deregisterTaskDefinition(request);
            return result.getSdkHttpMetadata().getHttpStatusCode() == 200;
        } catch (Exception e) {
            // Pass the exception itself so the stack trace is not lost.
            logger.warn(
                    "# ECS TaskDefinition " + taskDefinitionArn + " could not be deregistered " + e.getMessage(),
                    e);
            return false;
        } finally {
            // A client is created per call, so release the resources it holds once done.
            client.shutdown();
        }
    }

    /**
     * Runs one ECS task from the given task definition on the given cluster
     * and returns a record describing it.
     * <p>
     * The password is handed to the container through the
     * {@link Constants#ENV_USR_PWD} environment variable as a container
     * override; it is never written to the log.
     *
     * @param clusterName  cluster on which the task is started
     * @param instanceName name stored on the returned record and used in log messages
     * @param password     value of the {@code ENV_USR_PWD} environment override
     * @param taskDef      task definition (ARN, container name, region) to run
     * @param duration     not used by this method
     * @param user         user the task is created for
     * @return a record in {@link Constants#STATUS_PENDING} state for the started
     *         task, or {@code null} when ECS reports no started task or the
     *         request fails
     */
    protected RTFECSContainerTask createInstance(String clusterName, String instanceName, String password,
            RTFECSTaskDefinition taskDef, Integer duration, User user) {

        AmazonECS client = AmazonECSClientBuilder.standard().withRegion(taskDef.getRegion())
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();

        KeyValuePair passwordVariable = new KeyValuePair();
        passwordVariable.setName(Constants.ENV_USR_PWD);
        passwordVariable.setValue(password);
        List<KeyValuePair> environment = new LinkedList<KeyValuePair>();
        environment.add(passwordVariable);

        ContainerOverride containerOverride = new ContainerOverride();
        containerOverride.setEnvironment(environment);
        containerOverride.setName(taskDef.getContainerName());
        List<ContainerOverride> containerOverrides = new LinkedList<ContainerOverride>();
        containerOverrides.add(containerOverride);

        TaskOverride overrides = new TaskOverride();
        overrides.setContainerOverrides(containerOverrides);

        RunTaskRequest request = new RunTaskRequest().withCluster(clusterName)
                .withTaskDefinition(taskDef.getTaskDefinitionArn()).withOverrides(overrides);
        logger.debug("# ECS Requesting Task " + instanceName + " for user " + user.getIdUser()
                + " with task definition: " + taskDef.getTaskDefinitionArn() + " on cluster: " + clusterName
                + " on region " + taskDef.getRegion().getName());
        try {
            RunTaskResult response = client.runTask(request);
            if (response.getTasks().isEmpty()) {
                // No task was started: collect every reason ECS gave us.
                StringBuilder failureReason = new StringBuilder();
                for (Failure failure : response.getFailures()) {
                    failureReason.append("\n").append(failure.getReason());
                }
                logger.error("Task creation failed due to: \n" + failureReason);
                return null;
            }
            Task task = response.getTasks().get(0);
            RTFECSContainerTask rtfInstance = new RTFECSContainerTask();
            rtfInstance.setCluster(task.getClusterArn());
            rtfInstance.setTaskArn(task.getTaskArn());
            rtfInstance.setIdContainerInstance(task.getContainerInstanceArn());
            rtfInstance.setName(instanceName);
            rtfInstance.setRegion(taskDef.getRegion());
            rtfInstance.setUser(user);
            rtfInstance.setCreateTime(task.getCreatedAt());
            rtfInstance.setStatus(Constants.STATUS_PENDING);
            logger.debug("# ECS Task " + instanceName + " created for user " + user.getIdUser()
                    + " with task definition: " + taskDef.getTaskDefinitionName() + " start: "
                    + rtfInstance.getCreateTime());
            return rtfInstance;
        } catch (Exception e) {
            // Pass the exception itself so the stack trace is not lost.
            logger.warn("# ECS Task " + instanceName + " could not be created for user " + user.getIdUser() + " "
                    + e.getMessage(), e);
            return null;
        } finally {
            // A client is created per call, so release the resources it holds once done.
            client.shutdown();
        }
    }

    /**
     * Collects the ARNs of all tasks listed on the exercises cluster across
     * the given regions.
     * <p>
     * {@code ListTasks} is a paginated call, so every page is followed through
     * its {@code nextToken}. A failure in one region is logged and does not
     * stop the enumeration of the remaining regions; ARNs gathered before the
     * failure are kept.
     *
     * @param activeRegions regions to enumerate
     * @return the task ARNs found, possibly empty, never {@code null}
     */
    public List<String> getRunningECSTasks(List<Region> activeRegions) {
        List<String> taskArns = new LinkedList<String>();
        for (Region region : activeRegions) {
            logger.debug("Enumerating running tasks on cluster " + RTFConfig.getExercisesCluster() + " for region "
                    + region.getName());
            AmazonECS client = AmazonECSClientBuilder.standard().withRegion(Regions.fromName(region.getName()))
                    .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
            ListTasksRequest request = new ListTasksRequest().withCluster(RTFConfig.getExercisesCluster());
            try {
                String nextToken = null;
                do {
                    // A null token requests the first page.
                    ListTasksResult response = client.listTasks(request.withNextToken(nextToken));
                    taskArns.addAll(response.getTaskArns());
                    nextToken = response.getNextToken();
                } while (nextToken != null);
            } catch (Exception e) {
                logger.error("Error getRunningECSTasks for region " + region + " due to:\n" + e.getMessage(), e);
            } finally {
                // A client is created per region, so release the resources it holds once done.
                client.shutdown();
            }
        }
        return taskArns;
    }

    /**
     * Returns the creation time of the task with the given ARN on the
     * exercises cluster.
     * <p>
     * The region is read from the fourth colon-separated field of the ARN
     * ({@code arn:partition:service:region:...}).
     *
     * @param taskArn ARN of the task to look up
     * @return the task's {@code createdAt} timestamp, or {@code null} when the
     *         ARN is null or malformed, the task is not returned by ECS, or the
     *         request fails
     */
    public Date getRunningECSTaskStartTime(String taskArn) {
        String[] arnFields = (taskArn == null) ? new String[0] : taskArn.split(":");
        if (arnFields.length < 4 || arnFields[3].isEmpty()) {
            // Without a region we cannot build a client; report instead of throwing.
            logger.error("Could not get creation time for task arn " + taskArn
                    + ": cannot extract region from ARN");
            return null;
        }
        AmazonECS client = null;
        try {
            client = AmazonECSClientBuilder.standard().withRegion(arnFields[3])
                    .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
            DescribeTasksRequest request = new DescribeTasksRequest().withCluster(RTFConfig.getExercisesCluster())
                    .withTasks(taskArn);
            DescribeTasksResult response = client.describeTasks(request);
            if (response.getTasks().isEmpty()) {
                // Unknown tasks come back as an empty list rather than an exception.
                logger.error("Could not get creation time for task arn " + taskArn + ": task not found");
                return null;
            }
            return response.getTasks().get(0).getCreatedAt();
        } catch (Exception e) {
            // Pass the exception itself so the cause is visible in the log.
            logger.error("Could not get creation time for task arn " + taskArn, e);
            return null;
        } finally {
            // A client is created per call, so release the resources it holds once done.
            if (client != null) {
                client.shutdown();
            }
        }
    }

    /**
     * Checks the progress of a reservation's ECS task and updates the
     * reservation in place.
     * <p>
     * Outcomes, as flags set on the returned reservation:
     * <ul>
     * <li>no ECS record on the reservation: error, not fulfilled, wait 0;</li>
     * <li>task not returned by ECS, or task in
     * {@link Constants#AWS_ECS_STATUS_STOPPED} state: error, fulfilled, wait 0;</li>
     * <li>host ports for container ports 3389 (RDP) and 8080 (HTTP) not both
     * bound yet: no error, not fulfilled, wait 10 seconds;</li>
     * <li>container instance or EC2 instance not found: error, not fulfilled,
     * wait 0;</li>
     * <li>success: host ports and the EC2 instance's private IP address are
     * stored on the ECS record; no error, fulfilled, wait 0;</li>
     * <li>any exception: logged; error, not fulfilled, wait 0.</li>
     * </ul>
     *
     * @param reservation reservation to poll; it is modified and returned
     * @return the same {@code reservation} instance
     */
    public RTFInstanceReservation pollReservation(RTFInstanceReservation reservation) {

        if (null == reservation.getEcs()) {
            reservation.setError(true);
            reservation.setFulfilled(false);
            reservation.setWaitSeconds(0);
            return reservation;
        }
        AmazonECS client = AmazonECSClientBuilder.standard().withRegion(reservation.getEcs().getRegion())
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
        // Created lazily below; tracked here so it can be released in finally.
        AmazonEC2 ec2 = null;

        DescribeTasksRequest request = new DescribeTasksRequest().withCluster(RTFConfig.getExercisesCluster())
                .withTasks(reservation.getEcs().getTaskArn());
        try {
            DescribeTasksResult response = client.describeTasks(request);

            if (response.getTasks().isEmpty()) {
                reservation.setError(true);
                reservation.setFulfilled(true);
                reservation.setWaitSeconds(0);
                return reservation;
            }

            Integer rdpPort = -1;
            Integer httpPort = -1;

            Task task = response.getTasks().get(0);

            if (task.getLastStatus().equalsIgnoreCase(Constants.AWS_ECS_STATUS_STOPPED)) {
                reservation.setError(true);
                reservation.setFulfilled(true);
                reservation.setWaitSeconds(0);
                return reservation;
            }

            // Only the first container of the task is inspected for port bindings.
            List<NetworkBinding> nb = task.getContainers().get(0).getNetworkBindings();
            if (nb.size() > 2) {
                logger.warn("More than two port bindings, only RDP 3389 and HTTP 8080 will be mapped");
            }
            for (NetworkBinding b : nb) {
                if (b.getContainerPort().equals(3389)) {
                    rdpPort = b.getHostPort();
                } else if (b.getContainerPort().equals(8080)) {
                    httpPort = b.getHostPort();
                }
            }
            if (rdpPort == -1 || httpPort == -1) {
                // Bindings are not available yet: ask the caller to poll again later.
                reservation.setFulfilled(false);
                reservation.setError(false);
                reservation.setWaitSeconds(10);
                return reservation;
            }
            reservation.getEcs().setHttpPort(httpPort);
            reservation.getEcs().setRdpPort(rdpPort);

            // Resolve task -> container instance -> EC2 instance to obtain the host address.
            String containerInstanceId = task.getContainerInstanceArn();
            DescribeContainerInstancesRequest containerRequest = new DescribeContainerInstancesRequest()
                    .withCluster(RTFConfig.getExercisesCluster()).withContainerInstances(containerInstanceId);
            DescribeContainerInstancesResult containerResponse = client
                    .describeContainerInstances(containerRequest);

            if (containerResponse.getContainerInstances().isEmpty()) {
                reservation.setFulfilled(false);
                reservation.setError(true);
                reservation.setWaitSeconds(0);
                return reservation;
            }
            String ec2InstanceId = containerResponse.getContainerInstances().get(0).getEc2InstanceId();

            ec2 = AmazonEC2ClientBuilder.standard().withRegion(reservation.getEcs().getRegion())
                    .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
            DescribeInstancesRequest instanceRequest = new DescribeInstancesRequest()
                    .withInstanceIds(ec2InstanceId);
            DescribeInstancesResult instanceInstances = ec2.describeInstances(instanceRequest);

            if (instanceInstances.getReservations().isEmpty()
                    || instanceInstances.getReservations().get(0).getInstances().isEmpty()) {
                reservation.setFulfilled(false);
                reservation.setError(true);
                reservation.setWaitSeconds(0);
                return reservation;
            }

            String ipAddress = instanceInstances.getReservations().get(0).getInstances().get(0)
                    .getPrivateIpAddress();
            reservation.getEcs().setIpAddress(ipAddress);
            reservation.setFulfilled(true);
            reservation.setError(false);
            reservation.setWaitSeconds(0);
            return reservation;

        } catch (Exception e) {
            // Pass the exception itself so the stack trace is not lost.
            logger.error("Error pollReservation " + e.getMessage(), e);
            reservation.setError(true);
            reservation.setFulfilled(false);
            reservation.setWaitSeconds(0);
            return reservation;
        } finally {
            // Clients are created per call, so release the resources they hold once done.
            client.shutdown();
            if (ec2 != null) {
                ec2.shutdown();
            }
        }
    }

    /**
     * Creates the CloudWatch Logs log group with the given name in the given
     * region if it does not exist yet.
     * <p>
     * This is a best-effort operation: an already existing group is reported
     * at debug level, any other failure is logged as a warning with its cause,
     * and nothing is propagated to the caller.
     *
     * @param region           region in which the log group is created
     * @param exerciseLogGroup name of the log group
     */
    public void createExerciseLogGroup(Regions region, String exerciseLogGroup) {
        AWSLogs logs = AWSLogsClientBuilder.standard().withRegion(region)
                .withCredentials(new DefaultAWSCredentialsProviderChain()).build();
        try {
            CreateLogGroupRequest logGroupRequest = new CreateLogGroupRequest(exerciseLogGroup);
            logs.createLogGroup(logGroupRequest);
            logger.debug("Log group created");
        } catch (ResourceAlreadyExistsException e) {
            // Expected whenever the group was created by an earlier call.
            logger.debug("Log group already existing.");
        } catch (Exception e) {
            // Anything else (permissions, throttling, ...) is a real failure and must be visible.
            logger.warn("Log group " + exerciseLogGroup + " could not be created in region " + region + ": "
                    + e.getMessage(), e);
        } finally {
            // A client is created per call, so release the resources it holds once done.
            logs.shutdown();
        }
    }
}