com.nike.cerberus.service.AuthenticationServiceTest.java Source code

Java tutorial

Introduction

Here is the source code for com.nike.cerberus.service.AuthenticationServiceTest.java

Source

/*
 * Copyright (c) 2017 Nike, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package com.nike.cerberus.service;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nike.cerberus.auth.connector.AuthConnector;
import com.nike.cerberus.aws.KmsClientFactory;
import com.nike.cerberus.dao.AwsIamRoleDao;
import com.nike.cerberus.dao.SafeDepositBoxDao;
import com.nike.cerberus.domain.IamPrincipalCredentials;
import com.nike.cerberus.record.AwsIamRoleKmsKeyRecord;
import com.nike.cerberus.record.AwsIamRoleRecord;
import com.nike.cerberus.record.SafeDepositBoxRoleRecord;
import com.nike.cerberus.security.VaultAuthPrincipal;
import com.nike.cerberus.server.config.CmsConfig;
import com.nike.cerberus.util.AwsIamRoleArnParser;
import com.nike.cerberus.util.DateTimeSupplier;
import com.nike.vault.client.VaultAdminClient;
import com.nike.vault.client.model.VaultAuthResponse;
import org.apache.commons.lang3.RandomStringUtils;
import org.assertj.core.util.Lists;
import org.assertj.core.util.Sets;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;

import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.HashMap;
import java.util.List;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

import static com.nike.cerberus.service.AuthenticationService.LOOKUP_SELF_POLICY;
import static com.nike.cerberus.util.AwsIamRoleArnParser.AWS_IAM_ROLE_ARN_TEMPLATE;
import static org.junit.Assert.assertEquals;

import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;

/**
 * Tests the AuthenticationService class
 */
public class AuthenticationServiceTest {

    @Mock
    private SafeDepositBoxDao safeDepositBoxDao;

    @Mock
    private AwsIamRoleDao awsIamRoleDao;

    @Mock
    private AuthConnector authConnector;

    @Mock
    private KmsService kmsService;

    @Mock
    private KmsClientFactory kmsClientFactory;

    @Mock
    private VaultAdminClient vaultAdminClient;

    @Mock
    private VaultPolicyService vaultPolicyService;

    private ObjectMapper objectMapper;

    @Mock
    private DateTimeSupplier dateTimeSupplier;

    @Mock
    private AwsIamRoleArnParser awsIamRoleArnParser;

    private AuthenticationService authenticationService;

    @Before
    public void setup() {
        initMocks(this);
        objectMapper = CmsConfig.configureObjectMapper();
        authenticationService = new AuthenticationService(safeDepositBoxDao, awsIamRoleDao, authConnector,
                kmsService, kmsClientFactory, vaultAdminClient, vaultPolicyService, objectMapper, "foo",
                dateTimeSupplier, awsIamRoleArnParser);
    }

    @Test
    public void tests_that_generateCommonVaultPrincipalAuthMetadata_contains_expected_fields() {

        String principalArn = "principal arn";
        String region = "region";

        Map<String, String> result = authenticationService.generateCommonVaultPrincipalAuthMetadata(principalArn,
                region);

        assertTrue(result.containsKey(VaultAuthPrincipal.METADATA_KEY_USERNAME));
        assertEquals(principalArn, result.get(VaultAuthPrincipal.METADATA_KEY_USERNAME));

        assertTrue(result.containsKey(VaultAuthPrincipal.METADATA_KEY_AWS_REGION));
        assertEquals(region, result.get(VaultAuthPrincipal.METADATA_KEY_AWS_REGION));

        assertTrue(result.containsKey(VaultAuthPrincipal.METADATA_KEY_GROUPS));

        assertTrue(result.containsKey(VaultAuthPrincipal.METADATA_KEY_IS_ADMIN));
    }

    @Test
    public void test_that_getKeyId_only_validates_kms_policy_one_time_within_interval() {

        String principalArn = "principal arn";
        String region = "region";
        String iamRoleId = "iam role id";
        String kmsKeyId = "kms id";
        String cmkId = "key id";

        // ensure that validate interval is passed
        OffsetDateTime dateTime = OffsetDateTime.of(2016, 1, 1, 1, 1, 1, 1, ZoneOffset.UTC);
        OffsetDateTime now = OffsetDateTime.now();

        IamPrincipalCredentials iamPrincipalCredentials = new IamPrincipalCredentials();
        iamPrincipalCredentials.setIamPrincipalArn(principalArn);
        iamPrincipalCredentials.setRegion(region);

        AwsIamRoleRecord awsIamRoleRecord = new AwsIamRoleRecord().setAwsIamRoleArn(principalArn);
        awsIamRoleRecord.setAwsIamRoleArn(principalArn);
        awsIamRoleRecord.setId(iamRoleId);
        when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.of(awsIamRoleRecord));

        AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord = new AwsIamRoleKmsKeyRecord();
        awsIamRoleKmsKeyRecord.setId(kmsKeyId);
        awsIamRoleKmsKeyRecord.setAwsKmsKeyId(cmkId);
        awsIamRoleKmsKeyRecord.setLastValidatedTs(dateTime);

        when(awsIamRoleDao.getKmsKey(iamRoleId, region)).thenReturn(Optional.of(awsIamRoleKmsKeyRecord));

        when(dateTimeSupplier.get()).thenReturn(now);

        String result = authenticationService.getKeyId(iamPrincipalCredentials);

        // verify validate is called once interval has passed
        assertEquals(cmkId, result);
        verify(kmsService, times(1)).validatePolicy(awsIamRoleKmsKeyRecord, principalArn);
    }

    @Test
    public void test_that_buildCompleteSetOfPolicies_returns_all_policies() {

        String accountId = "0000000000";
        String roleName = "role/path";
        String principalArn = String.format("arn:aws:iam::%s:instance-profile/%s", accountId, roleName);

        String roleArn = String.format(AWS_IAM_ROLE_ARN_TEMPLATE, accountId, roleName);
        when(awsIamRoleArnParser.isRoleArn(principalArn)).thenReturn(false);
        when(awsIamRoleArnParser.convertPrincipalArnToRoleArn(principalArn)).thenReturn(roleArn);

        String principalPolicy1 = "principal policy 1";
        String principalPolicy2 = "principal policy 2";
        String principalArnSdb1 = "principal arn sdb 1";
        String principalArnSdb2 = "principal arn sdb 2";
        SafeDepositBoxRoleRecord principalArnRecord1 = new SafeDepositBoxRoleRecord().setRoleName(roleName)
                .setSafeDepositBoxName(principalArnSdb1);
        SafeDepositBoxRoleRecord principalArnRecord2 = new SafeDepositBoxRoleRecord().setRoleName(roleName)
                .setSafeDepositBoxName(principalArnSdb2);
        List<SafeDepositBoxRoleRecord> principalArnRecords = Lists.newArrayList(principalArnRecord1,
                principalArnRecord2);
        when(safeDepositBoxDao.getIamRoleAssociatedSafeDepositBoxRoles(principalArn))
                .thenReturn(principalArnRecords);
        when(vaultPolicyService.buildPolicyName(principalArnSdb1, roleName)).thenReturn(principalPolicy1);
        when(vaultPolicyService.buildPolicyName(principalArnSdb2, roleName)).thenReturn(principalPolicy2);

        String rolePolicy = "role policy";
        String roleArnSdb = "role arn sdb";
        SafeDepositBoxRoleRecord roleArnRecord = new SafeDepositBoxRoleRecord().setRoleName(roleName)
                .setSafeDepositBoxName(roleArnSdb);
        List<SafeDepositBoxRoleRecord> roleArnRecords = Lists.newArrayList(roleArnRecord);
        when(safeDepositBoxDao.getIamRoleAssociatedSafeDepositBoxRoles(roleArn)).thenReturn(roleArnRecords);
        when(vaultPolicyService.buildPolicyName(roleArnSdb, roleName)).thenReturn(rolePolicy);

        List<String> expectedPolicies = Lists.newArrayList(principalPolicy1, principalPolicy2, rolePolicy,
                LOOKUP_SELF_POLICY);
        Set<String> expected = Sets.newHashSet(expectedPolicies);
        Set<String> result = authenticationService.buildCompleteSetOfPolicies(principalArn);

        assertEquals(expected, result);
    }

    @Test
    public void test_that_findIamRoleAssociatedWithSdb_returns_first_matching_iam_role_record_if_found() {

        String principalArn = "principal arn";
        AwsIamRoleRecord awsIamRoleRecord = mock(AwsIamRoleRecord.class);
        when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.of(awsIamRoleRecord));

        Optional<AwsIamRoleRecord> result = authenticationService.findIamRoleAssociatedWithSdb(principalArn);

        assertEquals(awsIamRoleRecord, result.get());
    }

    @Test
    public void test_that_findIamRoleAssociatedWithSdb_returns_generic_role_when_iam_principal_not_found() {

        String accountId = "0000000000";
        String roleName = "role/path";
        String principalArn = String.format("arn:aws:iam::%s:instance-profile/%s", accountId, roleName);
        String roleArn = String.format(AWS_IAM_ROLE_ARN_TEMPLATE, accountId, roleName);

        AwsIamRoleRecord awsIamRoleRecord = mock(AwsIamRoleRecord.class);
        when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.empty());
        when(awsIamRoleDao.getIamRole(roleArn)).thenReturn(Optional.of(awsIamRoleRecord));

        when(awsIamRoleArnParser.isRoleArn(principalArn)).thenReturn(false);
        when(awsIamRoleArnParser.convertPrincipalArnToRoleArn(principalArn)).thenReturn(roleArn);

        Optional<AwsIamRoleRecord> result = authenticationService.findIamRoleAssociatedWithSdb(principalArn);

        assertEquals(awsIamRoleRecord, result.get());
    }

    @Test
    public void test_that_findIamRoleAssociatedWithSdb_returns_empty_optional_when_roles_not_found() {

        String accountId = "0000000000";
        String roleName = "role/path";
        String principalArn = String.format("arn:aws:iam::%s:instance-profile/%s", accountId, roleName);
        String roleArn = String.format("arn:aws:iam::%s:role/%s", accountId, roleName);

        when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.empty());
        when(awsIamRoleDao.getIamRole(roleArn)).thenReturn(Optional.empty());

        when(awsIamRoleArnParser.isRoleArn(principalArn)).thenReturn(false);
        when(awsIamRoleArnParser.convertPrincipalArnToRoleArn(principalArn)).thenReturn(roleArn);

        Optional<AwsIamRoleRecord> result = authenticationService.findIamRoleAssociatedWithSdb(principalArn);

        assertFalse(result.isPresent());
    }

    @Test
    public void tests_that_validateAuthPayloadSizeAndTruncateIfLargerThanMaxKmsSupportedSize_returns_the_original_payload_if_the_size_can_be_encrypted_by_kms()
            throws JsonProcessingException {
        VaultAuthResponse response = new VaultAuthResponse().setClientToken(UUID.randomUUID().toString())
                .setLeaseDuration(3600).setMetadata(new HashMap<>()).setPolicies(new HashSet<>())
                .setRenewable(false);

        byte[] serializedAuth = new ObjectMapper().writeValueAsBytes(response);

        byte[] actual = authenticationService
                .validateAuthPayloadSizeAndTruncateIfLargerThanMaxKmsSupportedSize(serializedAuth, response, "foo");

        assertEquals(serializedAuth, actual);
    }

    @Test
    public void tests_that_validateAuthPayloadSizeAndTruncateIfLargerThanMaxKmsSupportedSize_returns_a_truncated_payload_if_the_size_cannot_be_encrypted_by_kms()
            throws JsonProcessingException {
        Map<String, String> meta = new HashMap<>();
        Set<String> policies = new HashSet<>();
        for (int i = 0; i < 100; i++) {
            policies.add(RandomStringUtils.random(25));
        }

        VaultAuthResponse response = new VaultAuthResponse().setClientToken(UUID.randomUUID().toString())
                .setLeaseDuration(3600).setMetadata(meta).setPolicies(policies).setRenewable(false);

        byte[] serializedAuth = new ObjectMapper().writeValueAsBytes(response);
        assertTrue(serializedAuth.length > AuthenticationService.KMS_SIZE_LIMIT);

        byte[] actual = authenticationService
                .validateAuthPayloadSizeAndTruncateIfLargerThanMaxKmsSupportedSize(serializedAuth, response, "foo");

        assertNotEquals(serializedAuth, actual);
        assertTrue(actual.length < AuthenticationService.KMS_SIZE_LIMIT);
    }
}