import os
import json  # Used to load task_definition.json when present
from typing import List, Dict, Any
from aws_cdk import (
    Stack,
    Annotations,
    CfnOutput,
    Duration,
    RemovalPolicy,
    SecretValue,
    aws_ec2 as ec2,
    aws_ecr as ecr,
    aws_s3 as s3,
    aws_ecs as ecs,
    aws_iam as iam,
    aws_codebuild as codebuild,
    aws_cognito as cognito,
    aws_secretsmanager as secretsmanager,
    aws_cloudfront as cloudfront,
    aws_cloudfront_origins as origins,
    aws_elasticloadbalancingv2 as elbv2,
    aws_logs as logs,
    aws_wafv2 as wafv2,
    aws_dynamodb as dynamodb,
)
from constructs import Construct
from cdk_config import (
    CDK_PREFIX, VPC_NAME, AWS_MANAGED_TASK_ROLES_LIST, GITHUB_REPO_USERNAME,
    GITHUB_REPO_NAME, GITHUB_REPO_BRANCH, ECS_TASK_MEMORY_SIZE, ECS_TASK_CPU_SIZE,
    CUSTOM_HEADER, CUSTOM_HEADER_VALUE, AWS_REGION, CLOUDFRONT_GEO_RESTRICTION,
    DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS, GRADIO_SERVER_PORT, PUBLIC_SUBNETS_TO_USE,
    PUBLIC_SUBNET_CIDR_BLOCKS, PUBLIC_SUBNET_AVAILABILITY_ZONES, PRIVATE_SUBNETS_TO_USE,
    PRIVATE_SUBNET_CIDR_BLOCKS, PRIVATE_SUBNET_AVAILABILITY_ZONES, CODEBUILD_PROJECT_NAME,
    ECS_SECURITY_GROUP_NAME, ALB_NAME_SECURITY_GROUP_NAME, ALB_NAME,
    COGNITO_USER_POOL_NAME, COGNITO_USER_POOL_CLIENT_NAME,
    COGNITO_USER_POOL_CLIENT_SECRET_NAME, FARGATE_TASK_DEFINITION_NAME,
    ECS_SERVICE_NAME, WEB_ACL_NAME, CLOUDFRONT_DISTRIBUTION_NAME, ECS_TASK_ROLE_NAME,
    ALB_TARGET_GROUP_NAME, S3_LOG_CONFIG_BUCKET_NAME, S3_OUTPUT_BUCKET_NAME,
    ACM_CERTIFICATE_ARN, CLUSTER_NAME, CODEBUILD_ROLE_NAME, ECS_TASK_EXECUTION_ROLE_NAME,
    ECR_CDK_REPO_NAME, ECS_LOG_GROUP_NAME, SAVE_LOGS_TO_DYNAMODB,
    ACCESS_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
    USAGE_LOG_DYNAMODB_TABLE_NAME, TASK_DEFINITION_FILE_LOCATION, EXISTING_IGW_ID,
    SINGLE_NAT_GATEWAY_ID, NAT_GATEWAY_NAME, COGNITO_USER_POOL_DOMAIN_PREFIX,
    COGNITO_REDIRECTION_URL, AWS_ACCOUNT_ID, ECS_USE_FARGATE_SPOT,
    ECS_READ_ONLY_FILE_SYSTEM, USE_CLOUDFRONT, LOAD_BALANCER_WEB_ACL_NAME,
)
# Only CDK-native helper functions are kept here
from cdk_functions import (
    create_subnets,
    create_web_acl_with_common_rules,
    add_custom_policies,
    add_alb_https_listener_with_cert,
    create_nat_gateway,
)
def _get_env_list(env_var_value: str) -> List[str]:
    """Parses a bracketed, comma-separated string (e.g. "['a', 'b']") into a list of strings."""
    value = env_var_value.strip("[]").replace('"', '').replace("'", "")
    if not value:
        return []
    # Split by comma and filter out any empty strings that might result from extra commas
    return [s.strip() for s in value.split(',') if s.strip()]
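
# Example (illustrative values):
#   _get_env_list("['10.0.1.0/24', '10.0.2.0/24']") -> ['10.0.1.0/24', '10.0.2.0/24']
#   _get_env_list("[]") -> []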
# 1. Parse list-valued config entries (subnet names, CIDRs, AZs, role names) from their string form
if PUBLIC_SUBNETS_TO_USE: PUBLIC_SUBNETS_TO_USE = _get_env_list(PUBLIC_SUBNETS_TO_USE)
if PRIVATE_SUBNETS_TO_USE: PRIVATE_SUBNETS_TO_USE = _get_env_list(PRIVATE_SUBNETS_TO_USE)
if PUBLIC_SUBNET_CIDR_BLOCKS: PUBLIC_SUBNET_CIDR_BLOCKS = _get_env_list(PUBLIC_SUBNET_CIDR_BLOCKS)
if PUBLIC_SUBNET_AVAILABILITY_ZONES: PUBLIC_SUBNET_AVAILABILITY_ZONES = _get_env_list(PUBLIC_SUBNET_AVAILABILITY_ZONES)
if PRIVATE_SUBNET_CIDR_BLOCKS: PRIVATE_SUBNET_CIDR_BLOCKS = _get_env_list(PRIVATE_SUBNET_CIDR_BLOCKS)
if PRIVATE_SUBNET_AVAILABILITY_ZONES: PRIVATE_SUBNET_AVAILABILITY_ZONES = _get_env_list(PRIVATE_SUBNET_AVAILABILITY_ZONES)
if AWS_MANAGED_TASK_ROLES_LIST: AWS_MANAGED_TASK_ROLES_LIST = _get_env_list(AWS_MANAGED_TASK_ROLES_LIST)

class CdkStack(Stack):

    def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
        super().__init__(scope, construct_id, **kwargs)

        # --- Helpers to get context values ---
        # Explicit None checks are used so that falsy context values
        # (False, "", {}) are not silently replaced by the default.
        def get_context_bool(key: str, default: bool = False) -> bool:
            value = self.node.try_get_context(key)
            return default if value is None else bool(value)

        def get_context_str(key: str, default: str = None) -> str:
            value = self.node.try_get_context(key)
            return default if value is None else value

        def get_context_dict(key: str, default: dict = None) -> dict:
            value = self.node.try_get_context(key)
            return default if value is None else value

        def get_context_list_of_dicts(key: str) -> List[Dict[str, Any]]:
            ctx_value = self.node.try_get_context(key)
            if not isinstance(ctx_value, list):
                print(f"Warning: Context key '{key}' not found or not a list. Returning empty list.")
                return []
            # Optional: add validation that all items in the list are dicts
            return ctx_value
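
        # Context values used below are expected to be supplied before synth,
        # either in cdk.context.json or via `cdk synth/deploy -c key=value`.
        # The convention used throughout this stack for reusing existing
        # resources is a pair of context keys per resource name:
        #   exists:<resource-name>        -> bool
        #   arn:<resource-name> (or id:)  -> the identifier to import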

        # --- VPC Lookup (the VPC is always assumed to pre-exist; subnets are
        # imported or created by create_subnets below) ---
        try:
            vpc = ec2.Vpc.from_lookup(
                self,
                "VPC",
                vpc_name=VPC_NAME
            )
            print("Successfully looked up VPC:", vpc.vpc_id)
        except Exception as e:
            raise Exception(f"Could not look up VPC with name '{VPC_NAME}' due to: {e}") from e

        # --- Subnet Handling (check context, then import or create) ---
        # ISubnet objects (L2) are stored for consistency; CfnSubnet (L1) also
        # exposes a .subnet_id property. L1 CfnRouteTables are stored
        # explicitly so they can be referenced later.
        self.public_subnets: List[ec2.ISubnet] = []
        self.private_subnets: List[ec2.ISubnet] = []
        self.private_route_tables_cfn: List[ec2.CfnRouteTable] = []
        self.public_route_tables_cfn: List[ec2.CfnRouteTable] = []

        names_to_create_private = []
        names_to_create_public = []

        if not PUBLIC_SUBNETS_TO_USE and not PRIVATE_SUBNETS_TO_USE:
            print("Warning: No public or private subnets specified in *_SUBNETS_TO_USE. "
                  "Attempting to select from existing VPC subnets.")

            # Select one subnet per AZ for each subnet type the VPC exposes
            selected_public_subnets = vpc.select_subnets(subnet_type=ec2.SubnetType.PUBLIC, one_per_az=True)
            private_subnets_egress = vpc.select_subnets(subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS, one_per_az=True)
            private_subnets_isolated = vpc.select_subnets(subnet_type=ec2.SubnetType.PRIVATE_ISOLATED, one_per_az=True)

            combined_subnet_objects = []

            # Add the first PRIVATE_WITH_EGRESS subnet, if any
            if private_subnets_egress.subnets:
                combined_subnet_objects.append(private_subnets_egress.subnets[0])
            else:
                Annotations.of(self).add_warning("No PRIVATE_WITH_EGRESS subnets found to select the first one.")

            # Add all PRIVATE_ISOLATED subnets *except* the first one (if they exist)
            if len(private_subnets_isolated.subnets) > 1:
                combined_subnet_objects.extend(private_subnets_isolated.subnets[1:])
            elif private_subnets_isolated.subnets:
                Annotations.of(self).add_warning("Only one PRIVATE_ISOLATED subnet found; subnets[1:] is empty.")
            else:
                Annotations.of(self).add_warning("No PRIVATE_ISOLATED subnets found.")

            # Wrap the combined private subnet list in an ec2.SelectedSubnets object
            selected_private_subnets = vpc.select_subnets(subnets=combined_subnet_objects)

            print("selected_public_subnets:", selected_public_subnets)
            print("selected_private_subnets:", selected_private_subnets)

            if len(selected_public_subnets.subnet_ids) < 2 or len(selected_private_subnets.subnet_ids) < 2:
                raise Exception("Need at least two public and two private subnets in different availability zones")

            if not selected_public_subnets.subnets and not selected_private_subnets.subnets:
                # If no subnets could be found even with automatic selection, fail fast:
                # the stack cannot proceed without subnets.
                print("Error: No existing public or private subnets could be found in the VPC for automatic selection. "
                      "You must either specify subnets in *_SUBNETS_TO_USE or ensure the VPC has discoverable subnets.")
                raise RuntimeError("No suitable subnets found for automatic selection.")
            else:
                self.public_subnets = selected_public_subnets.subnets
                self.private_subnets = selected_private_subnets.subnets
                print(f"Automatically selected {len(self.public_subnets)} public and {len(self.private_subnets)} private subnets based on VPC discovery.")
                # With the *_SUBNETS_TO_USE lists empty, the import/create loops
                # below iterate over nothing, so this selection stands.

        checked_public_subnets_ctx = get_context_dict("checked_public_subnets")
        checked_private_subnets_ctx = get_context_dict("checked_private_subnets")
        public_subnets_data_for_creation_ctx = get_context_list_of_dicts("public_subnets_to_create")
        private_subnets_data_for_creation_ctx = get_context_list_of_dicts("private_subnets_to_create")
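
        # Expected context shapes (prepared outside this stack, e.g. by the app
        # entry point before synth):
        #   checked_public_subnets / checked_private_subnets:
        #       {"<subnet-name>": {"exists": bool, "id": "subnet-..."}}
        #   public_subnets_to_create / private_subnets_to_create:
        #       [{"name": str, "cidr": str, "az": str}, ...]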

        # --- 3. Process Public Subnets ---
        print("\n--- Processing Public Subnets ---")

        # Import existing public subnets
        if checked_public_subnets_ctx:
            for i, subnet_name in enumerate(PUBLIC_SUBNETS_TO_USE):
                subnet_info = checked_public_subnets_ctx.get(subnet_name)
                if subnet_info and subnet_info.get("exists"):
                    subnet_id = subnet_info.get("id")
                    if not subnet_id:
                        raise RuntimeError(f"Context for existing public subnet '{subnet_name}' is missing 'id'.")
                    try:
                        imported_subnet = ec2.Subnet.from_subnet_id(
                            self, f"ImportedPublicSubnet{subnet_name.replace('-', '')}{i}", subnet_id
                        )
                        self.public_subnets.append(imported_subnet)
                        print(f"Imported existing public subnet: {subnet_name} (ID: {subnet_id})")
                    except Exception as e:
                        raise RuntimeError(f"Failed to import public subnet '{subnet_name}' with ID '{subnet_id}'. Error: {e}") from e

        # Create new public subnets based on public_subnets_data_for_creation_ctx
        if public_subnets_data_for_creation_ctx:
            names_to_create_public = [s['name'] for s in public_subnets_data_for_creation_ctx]
            cidrs_to_create_public = [s['cidr'] for s in public_subnets_data_for_creation_ctx]
            azs_to_create_public = [s['az'] for s in public_subnets_data_for_creation_ctx]

            if names_to_create_public:
                print(f"Attempting to create {len(names_to_create_public)} new public subnets: {names_to_create_public}")
                newly_created_public_subnets, newly_created_public_rts_cfn = create_subnets(
                    self, vpc, CDK_PREFIX, names_to_create_public, cidrs_to_create_public, azs_to_create_public,
                    is_public=True,
                    internet_gateway_id=EXISTING_IGW_ID
                )
                self.public_subnets.extend(newly_created_public_subnets)
                self.public_route_tables_cfn.extend(newly_created_public_rts_cfn)

        if not self.public_subnets:
            raise Exception("No public subnets found or created, exiting.")

        # --- NAT Gateway Creation/Lookup ---
        self.single_nat_gateway_id = None
        nat_gw_id_from_config = SINGLE_NAT_GATEWAY_ID

        if nat_gw_id_from_config:
            print(f"Using existing NAT Gateway ID from config: {nat_gw_id_from_config}")
            self.single_nat_gateway_id = nat_gw_id_from_config
        elif self.public_subnets:
            # Not configured: create a new one, placed in the first available public subnet
            print("NAT Gateway ID not found in config. Creating a new one.")
            first_public_subnet = self.public_subnets[0]
            self.single_nat_gateway_id = create_nat_gateway(
                self,
                first_public_subnet,
                nat_gateway_name=NAT_GATEWAY_NAME,
                nat_gateway_id_context_key=SINGLE_NAT_GATEWAY_ID
            )
        else:
            print("WARNING: No public subnets available. Cannot create a NAT Gateway.")

        # --- 4. Process Private Subnets ---
        print("\n--- Processing Private Subnets ---")
        # (Importing of existing private subnets from checked_private_subnets_ctx
        # would mirror the public-subnet handling above.)

        # Create new private subnets
        if private_subnets_data_for_creation_ctx:
            names_to_create_private = [s['name'] for s in private_subnets_data_for_creation_ctx]
            cidrs_to_create_private = [s['cidr'] for s in private_subnets_data_for_creation_ctx]
            azs_to_create_private = [s['az'] for s in private_subnets_data_for_creation_ctx]

            if names_to_create_private:
                print(f"Attempting to create {len(names_to_create_private)} new private subnets: {names_to_create_private}")
                # The single NAT Gateway ID must be resolved before this call
                if not self.single_nat_gateway_id:
                    raise ValueError("A single NAT Gateway ID is required for private subnets but was not resolved.")

                newly_created_private_subnets_cfn, newly_created_private_rts_cfn = create_subnets(
                    self, vpc, CDK_PREFIX, names_to_create_private, cidrs_to_create_private, azs_to_create_private,
                    is_public=False,
                    single_nat_gateway_id=self.single_nat_gateway_id  # Pass the single NAT Gateway ID
                )
                self.private_subnets.extend(newly_created_private_subnets_cfn)
                self.private_route_tables_cfn.extend(newly_created_private_rts_cfn)
                print(f"Successfully defined {len(newly_created_private_subnets_cfn)} new private subnets and their route tables for creation.")
        else:
            print("No private subnets specified for creation in context ('private_subnets_to_create').")

        if not self.private_subnets:
            raise Exception("No private subnets found or created, exiting.")

        # --- 5. Sanity Check and Output ---
        # Output the single NAT Gateway ID for verification
        if self.single_nat_gateway_id:
            CfnOutput(self, "SingleNatGatewayId", value=self.single_nat_gateway_id,
                      description="ID of the single NAT Gateway used for private subnets.")
        else:
            raise Exception("No single NAT Gateway was created or resolved.")

        # --- Parameters shared with other stacks/regions ---
        # These are crucial for cross-stack, cross-region referencing
        self.params = dict()
        self.params["vpc_id"] = vpc.vpc_id
        self.params["private_subnets"] = self.private_subnets
        self.params["private_route_tables"] = self.private_route_tables_cfn
        self.params["public_subnets"] = self.public_subnets
        self.params["public_route_tables"] = self.public_route_tables_cfn

        private_subnet_selection = ec2.SubnetSelection(subnets=self.private_subnets)
        public_subnet_selection = ec2.SubnetSelection(subnets=self.public_subnets)

        for sub in private_subnet_selection.subnets:
            print("private subnet:", sub.subnet_id, "is in availability zone:", sub.availability_zone)
        for sub in public_subnet_selection.subnets:
            print("public subnet:", sub.subnet_id, "is in availability zone:", sub.availability_zone)
print("Private subnet route tables:", self.private_route_tables_cfn) | |
# Add the S3 Gateway Endpoint to the VPC | |
if names_to_create_private: | |
try: | |
s3_gateway_endpoint = vpc.add_gateway_endpoint( | |
"S3GatewayEndpoint", | |
service=ec2.GatewayVpcEndpointAwsService.S3, subnets=[private_subnet_selection]) | |
except Exception as e: | |
print("Could not add S3 gateway endpoint to subnets due to:", e) | |
#Output some useful information | |
CfnOutput(self, "VpcIdOutput", value=vpc.vpc_id, | |
description="The ID of the VPC where the S3 Gateway Endpoint is deployed.") | |
CfnOutput(self, "S3GatewayEndpointService", value=s3_gateway_endpoint.vpc_endpoint_id, | |
description="The id for the S3 Gateway Endpoint.") # Specify the S3 service | |

        # --- IAM Roles ---
        try:
            codebuild_role_name = CODEBUILD_ROLE_NAME

            custom_sts_kms_policy = """{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "STSCallerIdentity",
            "Effect": "Allow",
            "Action": [
                "sts:GetCallerIdentity"
            ],
            "Resource": "*"
        },
        {
            "Sid": "KMSAccess",
            "Effect": "Allow",
            "Action": [
                "kms:Encrypt",
                "kms:Decrypt",
                "kms:GenerateDataKey"
            ],
            "Resource": "*"
        }
    ]
}"""
if get_context_bool(f"exists:{codebuild_role_name}"): | |
# If exists, lookup/import the role using ARN from context | |
role_arn = get_context_str(f"arn:{codebuild_role_name}") | |
if not role_arn: | |
raise ValueError(f"Context value 'arn:{codebuild_role_name}' is required if role exists.") | |
codebuild_role = iam.Role.from_role_arn(self, "CodeBuildRole", role_arn=role_arn) | |
print("Using existing CodeBuild role") | |
else: | |
# If not exists, create the role | |
codebuild_role = iam.Role( | |
self, "CodeBuildRole", # Logical ID | |
role_name=codebuild_role_name, # Explicit resource name | |
assumed_by=iam.ServicePrincipal("codebuild.amazonaws.com") | |
) | |
codebuild_role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name(f"EC2InstanceProfileForImageBuilderECRContainerBuilds")) | |
print("Successfully created new CodeBuild role") | |

            task_role_name = ECS_TASK_ROLE_NAME
            if get_context_bool(f"exists:{task_role_name}"):
                role_arn = get_context_str(f"arn:{task_role_name}")
                if not role_arn:
                    raise ValueError(f"Context value 'arn:{task_role_name}' is required if role exists.")
                task_role = iam.Role.from_role_arn(self, "TaskRole", role_arn=role_arn)
                print("Using existing ECS task role")
            else:
                task_role = iam.Role(
                    self, "TaskRole",  # Logical ID
                    role_name=task_role_name,  # Explicit resource name
                    assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com")
                )
                for role in AWS_MANAGED_TASK_ROLES_LIST:
                    print(f"Adding {role} to policy")
                    task_role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name(role))
                task_role = add_custom_policies(self, task_role, custom_policy_text=custom_sts_kms_policy)
                print("Successfully created new ECS task role")

            execution_role_name = ECS_TASK_EXECUTION_ROLE_NAME
            if get_context_bool(f"exists:{execution_role_name}"):
                role_arn = get_context_str(f"arn:{execution_role_name}")
                if not role_arn:
                    raise ValueError(f"Context value 'arn:{execution_role_name}' is required if role exists.")
                execution_role = iam.Role.from_role_arn(self, "ExecutionRole", role_arn=role_arn)
                print("Using existing ECS execution role")
            else:
                execution_role = iam.Role(
                    self, "ExecutionRole",  # Logical ID
                    role_name=execution_role_name,  # Explicit resource name
                    assumed_by=iam.ServicePrincipal("ecs-tasks.amazonaws.com")
                )
                for role in AWS_MANAGED_TASK_ROLES_LIST:
                    execution_role.add_managed_policy(iam.ManagedPolicy.from_aws_managed_policy_name(role))
                execution_role = add_custom_policies(self, execution_role, custom_policy_text=custom_sts_kms_policy)
                print("Successfully created new ECS execution role")
        except Exception as e:
            raise Exception(f"Failed at IAM role step due to: {e}") from e

        # --- S3 Buckets ---
        try:
            log_bucket_name = S3_LOG_CONFIG_BUCKET_NAME
            if get_context_bool(f"exists:{log_bucket_name}"):
                bucket = s3.Bucket.from_bucket_name(self, "LogConfigBucket", bucket_name=log_bucket_name)
                print("Using existing S3 bucket", log_bucket_name)
            else:
                bucket = s3.Bucket(
                    self, "LogConfigBucket",
                    bucket_name=log_bucket_name,  # Explicit resource name
                    versioned=False,  # Set to True if you need versioning
                    removal_policy=RemovalPolicy.DESTROY,
                    auto_delete_objects=True  # Empty the bucket before deletion
                )
                print("Created S3 bucket", log_bucket_name)

            # Add policies - these apply to both created and imported buckets;
            # CDK handles idempotent policy additions
            bucket.add_to_resource_policy(
                iam.PolicyStatement(
                    effect=iam.Effect.ALLOW,
                    principals=[task_role],  # Pass the role object directly
                    actions=["s3:GetObject", "s3:PutObject"],
                    resources=[f"{bucket.bucket_arn}/*"]
                )
            )
            bucket.add_to_resource_policy(
                iam.PolicyStatement(
                    effect=iam.Effect.ALLOW,
                    principals=[task_role],
                    actions=["s3:ListBucket"],
                    resources=[bucket.bucket_arn]
                )
            )

            output_bucket_name = S3_OUTPUT_BUCKET_NAME
            if get_context_bool(f"exists:{output_bucket_name}"):
                output_bucket = s3.Bucket.from_bucket_name(self, "OutputBucket", bucket_name=output_bucket_name)
                print("Using existing Output bucket", output_bucket_name)
            else:
                output_bucket = s3.Bucket(
                    self, "OutputBucket",
                    bucket_name=output_bucket_name,
                    lifecycle_rules=[
                        s3.LifecycleRule(
                            expiration=Duration.days(int(DAYS_TO_DISPLAY_WHOLE_DOCUMENT_JOBS))
                        )
                    ],
                    versioned=False,  # Set to True if you need versioning
                    removal_policy=RemovalPolicy.DESTROY,
                    auto_delete_objects=True  # Empty the bucket before deletion
                )
                print("Created Output bucket:", output_bucket_name)

            # Add policies to output bucket
            output_bucket.add_to_resource_policy(
                iam.PolicyStatement(
                    effect=iam.Effect.ALLOW,
                    principals=[task_role],
                    actions=["s3:GetObject", "s3:PutObject"],
                    resources=[f"{output_bucket.bucket_arn}/*"]
                )
            )
            output_bucket.add_to_resource_policy(
                iam.PolicyStatement(
                    effect=iam.Effect.ALLOW,
                    principals=[task_role],
                    actions=["s3:ListBucket"],
                    resources=[output_bucket.bucket_arn]
                )
            )
        except Exception as e:
            raise Exception(f"Could not handle S3 buckets due to: {e}") from e

        # --- Elastic Container Registry ---
        try:
            full_ecr_repo_name = ECR_CDK_REPO_NAME
            if get_context_bool(f"exists:{full_ecr_repo_name}"):
                ecr_repo = ecr.Repository.from_repository_name(self, "ECRRepo", repository_name=full_ecr_repo_name)
                print("Using existing ECR repository")
            else:
                ecr_repo = ecr.Repository(self, "ECRRepo", repository_name=full_ecr_repo_name)  # Explicit repository name
                print("Created ECR repository", full_ecr_repo_name)
            ecr_image_loc = ecr_repo.repository_uri
        except Exception as e:
            raise Exception(f"Could not handle ECR repo due to: {e}") from e

        # --- CODEBUILD ---
        try:
            codebuild_project_name = CODEBUILD_PROJECT_NAME
            if get_context_bool(f"exists:{codebuild_project_name}"):
                # Lookup CodeBuild project by ARN from context
                project_arn = get_context_str(f"arn:{codebuild_project_name}")
                if not project_arn:
                    raise ValueError(f"Context value 'arn:{codebuild_project_name}' is required if project exists.")
                codebuild_project = codebuild.Project.from_project_arn(self, "CodeBuildProject", project_arn=project_arn)
                print("Using existing CodeBuild project")
            else:
                codebuild_project = codebuild.Project(
                    self,
                    "CodeBuildProject",  # Logical ID
                    project_name=codebuild_project_name,  # Explicit resource name
                    source=codebuild.Source.git_hub(
                        owner=GITHUB_REPO_USERNAME,
                        repo=GITHUB_REPO_NAME,
                        branch_or_ref=GITHUB_REPO_BRANCH
                    ),
                    environment=codebuild.BuildEnvironment(
                        build_image=codebuild.LinuxBuildImage.STANDARD_7_0,
                        privileged=True,  # Required for Docker builds
                        environment_variables={
                            "ECR_REPO_NAME": codebuild.BuildEnvironmentVariable(value=full_ecr_repo_name),
                            "AWS_DEFAULT_REGION": codebuild.BuildEnvironmentVariable(value=AWS_REGION),
                            "AWS_ACCOUNT_ID": codebuild.BuildEnvironmentVariable(value=AWS_ACCOUNT_ID)
                        }
                    ),
                    build_spec=codebuild.BuildSpec.from_object({
                        "version": "0.2",
                        "phases": {
                            "pre_build": {
                                "commands": [
                                    "echo Logging in to Amazon ECR",
                                    "aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com"
                                ]
                            },
                            "build": {
                                "commands": [
                                    "echo Building the Docker image",
                                    "docker build -t $ECR_REPO_NAME:latest .",
                                    "docker tag $ECR_REPO_NAME:latest $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest"
                                ]
                            },
                            "post_build": {
                                "commands": [
                                    "echo Pushing the Docker image",
                                    "docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$ECR_REPO_NAME:latest"
                                ]
                            }
                        }
                    })
                )
                print("Successfully created CodeBuild project", codebuild_project_name)

            # Grant permissions - applies to both created and imported project roles
            ecr_repo.grant_pull_push(codebuild_project.role)
        except Exception as e:
            raise Exception(f"Could not handle CodeBuild project due to: {e}") from e

        # --- Security Groups ---
        try:
            ecs_security_group_name = ECS_SECURITY_GROUP_NAME
            # Note: looking up existing security groups by name
            # (SecurityGroup.from_lookup_by_name) proved unreliable,
            # so these groups are always created here.
            try:
                ecs_security_group = ec2.SecurityGroup(
                    self,
                    "ECSSecurityGroup",  # Logical ID
                    security_group_name=ecs_security_group_name,  # Explicit resource name
                    vpc=vpc,
                )
                print(f"Created Security Group: {ecs_security_group_name}")
            except Exception as e:
                print("Failed to create ECS security group due to:", e)

            alb_security_group_name = ALB_NAME_SECURITY_GROUP_NAME
            try:
                alb_security_group = ec2.SecurityGroup(
                    self,
                    "ALBSecurityGroup",  # Logical ID
                    security_group_name=alb_security_group_name,  # Explicit resource name
                    vpc=vpc
                )
                print(f"Created Security Group: {alb_security_group_name}")
            except Exception as e:
                print("Failed to create ALB security group due to:", e)

            # Define ingress rules - CDK will manage adding/removing these as needed
            ec2_port_gradio_server_port = ec2.Port.tcp(int(GRADIO_SERVER_PORT))  # Ensure port is int

            ecs_security_group.add_ingress_rule(
                peer=alb_security_group,
                connection=ec2_port_gradio_server_port,
                description="ALB traffic",
            )

            alb_security_group.add_ingress_rule(
                peer=ec2.Peer.prefix_list("pl-93a247fa"),
                connection=ec2.Port.all_traffic(),
                description="CloudFront traffic",
            )
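
            # "pl-93a247fa" is intended to be the AWS-managed prefix list for
            # CloudFront origin-facing servers
            # ("com.amazonaws.global.cloudfront.origin-facing"). Prefix list IDs
            # are region-specific, so verify this ID in the target region.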
        except Exception as e:
            raise Exception(f"Could not handle security groups due to: {e}") from e

        # --- DynamoDB tables for logs (optional) ---
        if SAVE_LOGS_TO_DYNAMODB == 'True':
            try:
                print("Creating DynamoDB tables for logs")
                dynamodb_table_access = dynamodb.Table(
                    self, "RedactionAccessDataTable",
                    table_name=ACCESS_LOG_DYNAMODB_TABLE_NAME,
                    partition_key=dynamodb.Attribute(name="id", type=dynamodb.AttributeType.STRING),
                    billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
                    removal_policy=RemovalPolicy.DESTROY
                )
                dynamodb_table_feedback = dynamodb.Table(
                    self, "RedactionFeedbackDataTable",
                    table_name=FEEDBACK_LOG_DYNAMODB_TABLE_NAME,
                    partition_key=dynamodb.Attribute(name="id", type=dynamodb.AttributeType.STRING),
                    billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
                    removal_policy=RemovalPolicy.DESTROY
                )
                dynamodb_table_usage = dynamodb.Table(
                    self, "RedactionUsageDataTable",
                    table_name=USAGE_LOG_DYNAMODB_TABLE_NAME,
                    partition_key=dynamodb.Attribute(name="id", type=dynamodb.AttributeType.STRING),
                    billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
                    removal_policy=RemovalPolicy.DESTROY
                )
            except Exception as e:
                raise Exception(f"Could not create DynamoDB tables due to: {e}") from e

        # --- ALB ---
        try:
            load_balancer_name = ALB_NAME
            # ALB names are limited to 32 characters; keep the trailing portion
            if len(load_balancer_name) > 32:
                load_balancer_name = load_balancer_name[-32:]

            if get_context_bool(f"exists:{load_balancer_name}"):
                # Lookup ALB by ARN from context
                alb_arn = get_context_str(f"arn:{load_balancer_name}")
                if not alb_arn:
                    raise ValueError(f"Context value 'arn:{load_balancer_name}' is required if ALB exists.")
                alb = elbv2.ApplicationLoadBalancer.from_lookup(
                    self, "ALB",  # Logical ID
                    load_balancer_arn=alb_arn
                )
                print(f"Using existing Application Load Balancer {load_balancer_name}.")
            else:
                alb = elbv2.ApplicationLoadBalancer(
                    self,
                    "ALB",  # Logical ID
                    load_balancer_name=load_balancer_name,  # Explicit resource name
                    vpc=vpc,
                    internet_facing=True,
                    security_group=alb_security_group,  # Link to SG
                    vpc_subnets=public_subnet_selection  # Link to subnets
                )
                print("Successfully created new Application Load Balancer")
        except Exception as e:
            raise Exception(f"Could not handle application load balancer due to: {e}") from e

        # --- Cognito User Pool ---
        try:
            if get_context_bool(f"exists:{COGNITO_USER_POOL_NAME}"):
                # Lookup by ID from context
                user_pool_id = get_context_str(f"id:{COGNITO_USER_POOL_NAME}")
                if not user_pool_id:
                    raise ValueError(f"Context value 'id:{COGNITO_USER_POOL_NAME}' is required if User Pool exists.")
                user_pool = cognito.UserPool.from_user_pool_id(self, "UserPool", user_pool_id=user_pool_id)
                print(f"Using existing user pool {user_pool_id}.")
            else:
                user_pool = cognito.UserPool(
                    self, "UserPool",
                    user_pool_name=COGNITO_USER_POOL_NAME,
                    mfa=cognito.Mfa.OFF,  # Adjust as needed
                    sign_in_aliases=cognito.SignInAliases(email=True),
                    removal_policy=RemovalPolicy.DESTROY  # Adjust as needed
                )
                print(f"Created new user pool {user_pool.user_pool_id}.")

            # If a certificate is configured, the ALB's built-in Cognito login is
            # assumed; that flow needs an extra redirect URL to accept the token
            # returned by Cognito authentication.
            if ACM_CERTIFICATE_ARN:
                redirect_uris = [COGNITO_REDIRECTION_URL, COGNITO_REDIRECTION_URL + "/oauth2/idpresponse"]
            else:
                redirect_uris = [COGNITO_REDIRECTION_URL]

            user_pool_client_name = COGNITO_USER_POOL_CLIENT_NAME
            if get_context_bool(f"exists:{user_pool_client_name}"):
                # Lookup by ID from context
                user_pool_client_id = get_context_str(f"id:{user_pool_client_name}")
                if not user_pool_client_id:
                    raise ValueError(f"Context value 'id:{user_pool_client_name}' is required if User Pool Client exists.")
                user_pool_client = cognito.UserPoolClient.from_user_pool_client_id(self, "UserPoolClient", user_pool_client_id=user_pool_client_id)
                print(f"Using existing user pool client {user_pool_client_id}.")
            else:
                user_pool_client = cognito.UserPoolClient(
                    self, "UserPoolClient",
                    auth_flows=cognito.AuthFlow(user_srp=True, user_password=True),  # Enable SRP for secure sign-in
                    user_pool=user_pool,
                    generate_secret=True,
                    user_pool_client_name=user_pool_client_name,
                    supported_identity_providers=[cognito.UserPoolClientIdentityProvider.COGNITO],
                    o_auth=cognito.OAuthSettings(
                        flows=cognito.OAuthFlows(authorization_code_grant=True),
                        scopes=[cognito.OAuthScope.OPENID, cognito.OAuthScope.EMAIL, cognito.OAuthScope.PROFILE],
                        callback_urls=redirect_uris
                    )
                )
                CfnOutput(self, "CognitoAppClientId", value=user_pool_client.user_pool_client_id)
                print(f"Created new user pool client {user_pool_client.user_pool_client_id}.")

            # Add a domain to the User Pool (crucial for ALB integration)
            user_pool_domain = user_pool.add_domain(
                "UserPoolDomain",
                cognito_domain=cognito.CognitoDomainOptions(
                    domain_prefix=COGNITO_USER_POOL_DOMAIN_PREFIX)
            )
            # Apply removal_policy to the created UserPoolDomain construct
            user_pool_domain.apply_removal_policy(policy=RemovalPolicy.DESTROY)
            CfnOutput(self, "CognitoUserPoolLoginUrl", value=user_pool_domain.base_url())
        except Exception as e:
            raise Exception(f"Could not handle Cognito resources due to: {e}") from e

        # --- Secrets Manager Secret ---
        try:
            secret_name = COGNITO_USER_POOL_CLIENT_SECRET_NAME
            if get_context_bool(f"exists:{secret_name}"):
                # Lookup by name
                secret = secretsmanager.Secret.from_secret_name_v2(self, "CognitoSecret", secret_name=secret_name)
                print(f"Using existing Secret {secret_name}.")
            else:
                secret = secretsmanager.Secret(
                    self, "CognitoSecret",  # Logical ID
                    secret_name=secret_name,  # Explicit resource name
                    secret_object_value={
                        "REDACTION_USER_POOL_ID": SecretValue.unsafe_plain_text(user_pool.user_pool_id),  # Use the CDK attribute
                        "REDACTION_CLIENT_ID": SecretValue.unsafe_plain_text(user_pool_client.user_pool_client_id),  # Use the CDK attribute
                        "REDACTION_CLIENT_SECRET": user_pool_client.user_pool_client_secret  # Use the CDK attribute
                    }
                )
                print(f"Created new secret {secret_name}.")
        except Exception as e:
            raise Exception(f"Could not handle Secrets Manager secret due to: {e}") from e

        # --- Fargate Task Definition ---
        try:
            # Re-creating a task definition with the same logical ID produces new
            # revisions, so defining it here gives standard revision management.
            # (To pin a *specific existing revision*, look it up by ARN instead.)
            fargate_task_definition_name = FARGATE_TASK_DEFINITION_NAME

            read_only_file_system = ECS_READ_ONLY_FILE_SYSTEM == 'True'

            # The task-level ephemeral volume that backs all writable mount points
            ephemeral_storage_volume_name = "appEphemeralVolume"

            if os.path.exists(TASK_DEFINITION_FILE_LOCATION):
                with open(TASK_DEFINITION_FILE_LOCATION) as f:
                    task_def_params = json.load(f)
                # Ensure taskRoleArn and executionRoleArn in the JSON are valid ARN strings
            else:
                task_def_params = {}
                task_def_params['taskRoleArn'] = task_role.role_arn  # Use CDK role object ARN
                task_def_params['executionRoleArn'] = execution_role.role_arn  # Use CDK role object ARN
                task_def_params['memory'] = ECS_TASK_MEMORY_SIZE
                task_def_params['cpu'] = ECS_TASK_CPU_SIZE

                # Every path the app writes to is mounted from the ephemeral volume
                writable_container_paths = [
                    "/home/user/app/logs",
                    "/home/user/app/feedback",
                    "/home/user/app/usage",
                    "/home/user/app/input",
                    "/home/user/app/output",
                    "/home/user/app/tmp",
                    "/home/user/app/config",
                    "/tmp/matplotlib_cache",
                    "/tmp",
                    "/var/tmp",
                    "/tmp/tld",
                    "/tmp/gradio_tmp",
                ]

                container_def = {
                    "name": full_ecr_repo_name,
                    "image": ecr_image_loc + ":latest",
                    "essential": True,
                    "portMappings": [{
                        "containerPort": int(GRADIO_SERVER_PORT),
                        "hostPort": int(GRADIO_SERVER_PORT),
                        "protocol": "tcp",
                        "appProtocol": "http"
                    }],
                    "logConfiguration": {
                        "logDriver": "awslogs",
                        "options": {
                            "awslogs-group": ECS_LOG_GROUP_NAME,
                            "awslogs-region": AWS_REGION,
                            "awslogs-stream-prefix": "ecs"
                        }
                    },
                    "environmentFiles": [{"value": bucket.bucket_arn + "/config.env", "type": "s3"}],
                    "memoryReservation": int(task_def_params['memory']) - 512,  # Reserve some memory for the container
                    "mountPoints": [
                        {
                            "sourceVolume": ephemeral_storage_volume_name,
                            "containerPath": path,
                            "readOnly": False
                        }
                        for path in writable_container_paths
                    ],
                    "readonlyRootFilesystem": read_only_file_system,
                }
                task_def_params['containerDefinitions'] = [container_def]
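
            # With readonlyRootFilesystem enabled, the container can only write
            # to explicitly mounted paths, which is why every temp/log/config
            # directory above is backed by the ephemeral volume.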

            log_group_name_from_config = task_def_params['containerDefinitions'][0]['logConfiguration']['options']['awslogs-group']

            cdk_managed_log_group = logs.LogGroup(
                self, "MyTaskLogGroup",  # CDK Logical ID
                log_group_name=log_group_name_from_config,
                retention=logs.RetentionDays.ONE_MONTH,  # Example: set retention
                removal_policy=RemovalPolicy.DESTROY  # Delete the log group with the stack
            )

            ephemeral_storage_volume_cdk_obj = ecs.Volume(
                name=ephemeral_storage_volume_name
            )

            fargate_task_definition = ecs.FargateTaskDefinition(
                self,
                "FargateTaskDefinition",  # Logical ID
                family=fargate_task_definition_name,
                cpu=int(task_def_params['cpu']),
                memory_limit_mib=int(task_def_params['memory']),
                task_role=task_role,
                execution_role=execution_role,
                runtime_platform=ecs.RuntimePlatform(
                    cpu_architecture=ecs.CpuArchitecture.X86_64,
                    operating_system_family=ecs.OperatingSystemFamily.LINUX
                ),
                # 1. Specify the total ephemeral storage for the task (minimum is 21 GiB)
                ephemeral_storage_gib=21,
                # 2. Define the volume at the task level; it draws from the
                #    ephemeral storage configured above.
                volumes=[ephemeral_storage_volume_cdk_obj]
            )
            print("Fargate task definition defined.")

            # Add container definitions to the task definition object
            if task_def_params['containerDefinitions']:
                container_def_params = task_def_params['containerDefinitions'][0]

                env_files = []
                if container_def_params.get('environmentFiles'):
                    for env_file_param in container_def_params['environmentFiles']:
                        # Parse the ARN to get the bucket name and object key
                        env_file_arn_parts = env_file_param['value'].split(":::")
                        bucket_name_and_key = env_file_arn_parts[-1]
                        env_bucket_name, env_key = bucket_name_and_key.split("/", 1)

                        env_file = ecs.EnvironmentFile.from_bucket(bucket, env_key)
                        env_files.append(env_file)

                container = fargate_task_definition.add_container(
                    container_def_params['name'],
                    image=ecs.ContainerImage.from_registry(container_def_params['image']),
                    logging=ecs.LogDriver.aws_logs(
                        stream_prefix=container_def_params['logConfiguration']['options']['awslogs-stream-prefix'],
                        log_group=cdk_managed_log_group
                    ),
                    secrets={
                        "AWS_USER_POOL_ID": ecs.Secret.from_secrets_manager(secret, "REDACTION_USER_POOL_ID"),
                        "AWS_CLIENT_ID": ecs.Secret.from_secrets_manager(secret, "REDACTION_CLIENT_ID"),
                        "AWS_CLIENT_SECRET": ecs.Secret.from_secrets_manager(secret, "REDACTION_CLIENT_SECRET")
                    },
                    environment_files=env_files,
                    readonly_root_filesystem=read_only_file_system
                )

                for port_mapping in container_def_params['portMappings']:
                    container.add_port_mappings(
                        ecs.PortMapping(
                            container_port=int(port_mapping['containerPort']),
                            host_port=int(port_mapping['hostPort']),
                            name="port-" + str(port_mapping['containerPort']),
                            app_protocol=ecs.AppProtocol.http,
                            protocol=ecs.Protocol.TCP
                        )
                    )
                container.add_port_mappings(
                    ecs.PortMapping(
                        container_port=80,
                        host_port=80,
                        name="port-80",
                        app_protocol=ecs.AppProtocol.http,
                        protocol=ecs.Protocol.TCP
                    )
                )

                if container_def_params.get('mountPoints'):
                    mount_points = []
                    for mount_point in container_def_params['mountPoints']:
                        mount_points.append(ecs.MountPoint(
                            container_path=mount_point['containerPath'],
                            read_only=mount_point['readOnly'],
                            source_volume=ephemeral_storage_volume_name
                        ))
                    container.add_mount_points(*mount_points)
        except Exception as e:
            raise Exception(f"Could not handle Fargate task definition due to: {e}") from e

        # --- ECS Cluster ---
        try:
            cluster = ecs.Cluster(
                self,
                "ECSCluster",  # Logical ID
                cluster_name=CLUSTER_NAME,  # Explicit resource name
                enable_fargate_capacity_providers=True,
                vpc=vpc
            )
            print("Successfully created new ECS cluster")
        except Exception as e:
            raise Exception(f"Could not handle ECS cluster due to: {e}") from e

        # --- ECS Service ---
        try:
            ecs_service_name = ECS_SERVICE_NAME
            use_fargate_spot = "FARGATE_SPOT" if ECS_USE_FARGATE_SPOT == 'True' else "FARGATE"

            if get_context_bool(f"exists:{ecs_service_name}"):
                # Import the existing service (requires the cluster object)
                ecs_service = ecs.FargateService.from_fargate_service_attributes(
                    self, "ECSService",  # Logical ID
                    cluster=cluster,
                    service_name=ecs_service_name
                )
                print(f"Using existing ECS service {ecs_service_name}.")
            else:
                # The service is created with a desired count of 0 because the
                # initial Docker image has not yet been built by CodeBuild
                ecs_service = ecs.FargateService(
                    self,
                    "ECSService",  # Logical ID
                    service_name=ecs_service_name,  # Explicit resource name
                    platform_version=ecs.FargatePlatformVersion.LATEST,
                    capacity_provider_strategies=[ecs.CapacityProviderStrategy(capacity_provider=use_fargate_spot, base=0, weight=1)],
                    cluster=cluster,
                    task_definition=fargate_task_definition,  # Link to TD
                    security_groups=[ecs_security_group],  # Link to SG
                    vpc_subnets=ec2.SubnetSelection(subnets=self.private_subnets),  # Link to subnets
                    min_healthy_percent=0,
                    max_healthy_percent=100,
                    desired_count=0
                )
                print("Successfully created new ECS service")
                # Note: auto-scaling setup would typically go here if needed
        except Exception as e:
            raise Exception(f"Could not handle ECS service due to: {e}") from e

        # --- Grant Secret Read Access (applies to both created and imported roles) ---
        try:
            secret.grant_read(task_role)
            secret.grant_read(execution_role)
        except Exception as e:
            raise Exception(f"Could not grant access to Secrets Manager due to: {e}") from e

        # --- ALB TARGET GROUPS AND LISTENERS ---
        # This section defines the resources managed by this stack; CDK handles
        # adding/removing targets and actions on updates. If they might
        # pre-exist outside the stack, lookups are needed instead.
        cookie_duration = Duration.hours(12)
        target_group_name = ALB_TARGET_GROUP_NAME  # Explicit resource name
        # Placeholder; replace afterwards with the actual cloudfront_distribution.domain_name
        cloudfront_distribution_url = "cloudfront_placeholder.net"

        try:
            # --- Create the target group and add the CloudFront listener rule ---
            target_group = elbv2.ApplicationTargetGroup(
                self,
                "AppTargetGroup",  # Logical ID
                target_group_name=target_group_name,  # Explicit resource name
                port=int(GRADIO_SERVER_PORT),  # Ensure port is int
                protocol=elbv2.ApplicationProtocol.HTTP,
                targets=[ecs_service],  # Link to ECS Service
                stickiness_cookie_duration=cookie_duration,
                vpc=vpc,  # Target groups need a VPC
            )
            print(f"ALB target group {target_group_name} defined.")

            # First HTTP
            listener_port = 80
            http_listener = alb.add_listener(
                "HttpListener",  # Logical ID
                port=listener_port,
                open=False,  # Be cautious with open=True; usually restrict the source SG
            )
            print(f"ALB listener on port {listener_port} defined.")

            if ACM_CERTIFICATE_ARN:
                # With a certificate available, HTTP simply redirects to HTTPS
                http_listener.add_action(
                    "DefaultAction",  # Logical ID for the default action
                    action=elbv2.ListenerAction.redirect(
                        protocol='HTTPS',
                        host='#{host}',
                        port='443',
                        path='/#{path}',
                        query='#{query}'
                    )
                )
            else:
                if USE_CLOUDFRONT == 'True':
                    # Deny by default; only requests carrying the CloudFront
                    # host header (rule below) are forwarded
                    http_listener.add_action(
                        "DefaultAction",  # Logical ID for the default action
                        action=elbv2.ListenerAction.fixed_response(
                            status_code=403,
                            content_type="text/plain",
                            message_body="Access denied",
                        ),
                    )
                    # Add the listener rule for the specific CloudFront host header
                    http_listener.add_action(
                        "CloudFrontHostHeaderRule",
                        action=elbv2.ListenerAction.forward(target_groups=[target_group], stickiness_duration=cookie_duration),
                        priority=1,  # Example priority; lower is evaluated first
                        conditions=[
                            # May have to redefine the URL in the console afterwards
                            # if it is not specified in the config file
                            elbv2.ListenerCondition.host_headers([cloudfront_distribution_url])
                        ]
                    )
                else:
                    http_listener.add_action(
                        "CloudFrontHostHeaderRule",
                        action=elbv2.ListenerAction.forward(target_groups=[target_group], stickiness_duration=cookie_duration)
                    )
            print("Added targets and actions to ALB HTTP listener.")

            # Now the same for HTTPS, if an ACM certificate is configured
            if ACM_CERTIFICATE_ARN:
                listener_port_https = 443
                https_listener = add_alb_https_listener_with_cert(
                    self,
                    "MyHttpsListener",  # Logical ID for the HTTPS listener
                    alb,
                    acm_certificate_arn=ACM_CERTIFICATE_ARN,
                    default_target_group=target_group,
                    enable_cognito_auth=True,
                    cognito_user_pool=user_pool,
                    cognito_user_pool_client=user_pool_client,
                    cognito_user_pool_domain=user_pool_domain,
                    listener_open_to_internet=True,
                    stickiness_cookie_duration=cookie_duration
                )
                if https_listener:
                    CfnOutput(self, "HttpsListenerArn", value=https_listener.listener_arn)
                print(f"ALB listener on port {listener_port_https} defined.")
print("Added targets and actions to ALB HTTPS listener.") | |
except Exception as e: | |
raise Exception("Could not handle ALB target groups and listeners due to:", e) | |

        # Create WAF to attach to the load balancer
        try:
            web_acl_name = LOAD_BALANCER_WEB_ACL_NAME
            if get_context_bool(f"exists:{web_acl_name}"):
                # Use the existing Web ACL's ARN from context
                web_acl_arn = get_context_str(f"arn:{web_acl_name}")
                if not web_acl_arn:
                    raise ValueError(f"Context value 'arn:{web_acl_name}' is required if Web ACL exists.")
                print(f"Using existing ALB WAF web ACL {web_acl_name}.")
            else:
                web_acl = create_web_acl_with_common_rules(self, web_acl_name, waf_scope="REGIONAL")
                web_acl_arn = web_acl.attr_arn
                print(f"Created ALB WAF web ACL {web_acl_name}.")

            alb_waf_association = wafv2.CfnWebACLAssociation(
                self, id="alb_waf_association",
                resource_arn=alb.load_balancer_arn,
                web_acl_arn=web_acl_arn
            )
        except Exception as e:
            raise Exception(f"Could not handle ALB WAF web ACL due to: {e}") from e

        # --- Outputs for other stacks/regions ---
        # Extend self.params (rather than re-creating it) so the VPC/subnet
        # entries set earlier are preserved.
        self.params["alb_arn_output"] = alb.load_balancer_arn
        self.params["alb_security_group_id"] = alb_security_group.security_group_id
        self.params["alb_dns_name"] = alb.load_balancer_dns_name

        CfnOutput(self, "AlbArnOutput",
                  value=alb.load_balancer_arn,
                  description="ARN of the Application Load Balancer",
                  export_name=f"{self.stack_name}-AlbArn")  # Export names must be unique within the account/region
        CfnOutput(self, "AlbSecurityGroupIdOutput",
                  value=alb_security_group.security_group_id,
                  description="ID of the ALB's Security Group",
                  export_name=f"{self.stack_name}-AlbSgId")
        CfnOutput(self, "ALBName", value=alb.load_balancer_name)
        CfnOutput(self, "RegionalAlbDnsName", value=alb.load_balancer_dns_name)
        CfnOutput(self, "CognitoPoolId", value=user_pool.user_pool_id)
        # Add other outputs if needed
        CfnOutput(self, "ECRRepoUri", value=ecr_repo.repository_uri)

# --- CLOUDFRONT DISTRIBUTION in a separate stack (us-east-1 required for its WAF) ---
class CdkStackCloudfront(Stack):

    def __init__(self, scope: Construct, construct_id: str, alb_arn: str, alb_sec_group_id: str, alb_dns_name: str, **kwargs) -> None:
        super().__init__(scope, construct_id, **kwargs)

        # --- Helpers to get context values ---
        def get_context_bool(key: str, default: bool = False) -> bool:
            value = self.node.try_get_context(key)
            return default if value is None else bool(value)

        def get_context_str(key: str, default: str = None) -> str:
            value = self.node.try_get_context(key)
            return default if value is None else value

        print(f"CloudFront Stack: Received ALB ARN: {alb_arn}")
        print(f"CloudFront Stack: Received ALB Security Group ID: {alb_sec_group_id}")

        if not alb_arn:
            raise ValueError("ALB ARN must be provided to CloudFront stack")
        if not alb_sec_group_id:
            raise ValueError("ALB Security Group ID must be provided to CloudFront stack")

        # Import the existing ALB by its attributes so this stack can reference
        # it; CloudFormation resolves the reference at deploy time.
        alb = elbv2.ApplicationLoadBalancer.from_application_load_balancer_attributes(
            self, "ImportedAlb",
            load_balancer_arn=alb_arn,
            security_group_id=alb_sec_group_id,
            load_balancer_dns_name=alb_dns_name
        )

        try:
            web_acl_name = WEB_ACL_NAME
            if get_context_bool(f"exists:{web_acl_name}"):
                # Use the existing Web ACL's ARN from context
                web_acl_arn = get_context_str(f"arn:{web_acl_name}")
                if not web_acl_arn:
                    raise ValueError(f"Context value 'arn:{web_acl_name}' is required if Web ACL exists.")
                print(f"Using existing CloudFront WAF web ACL {web_acl_name}.")
            else:
                web_acl = create_web_acl_with_common_rules(self, web_acl_name)  # Assumed to default to the CLOUDFRONT scope
                web_acl_arn = web_acl.attr_arn
                print(f"Created CloudFront WAF web ACL {web_acl_name}.")

            # Add the ALB as the CloudFront origin
            origin = origins.LoadBalancerV2Origin(
                alb,  # Use the imported ALB object
                custom_headers={CUSTOM_HEADER: CUSTOM_HEADER_VALUE},
                origin_shield_enabled=False,
                protocol_policy=cloudfront.OriginProtocolPolicy.HTTP_ONLY,
            )

            geo_restrict = cloudfront.GeoRestriction.allowlist(CLOUDFRONT_GEO_RESTRICTION) if CLOUDFRONT_GEO_RESTRICTION else None

            cloudfront_distribution = cloudfront.Distribution(
                self,
                "CloudFrontDistribution",  # Logical ID
                comment=CLOUDFRONT_DISTRIBUTION_NAME,  # Use the name as a comment for easier identification
                geo_restriction=geo_restrict,
                default_behavior=cloudfront.BehaviorOptions(
                    origin=origin,
                    viewer_protocol_policy=cloudfront.ViewerProtocolPolicy.REDIRECT_TO_HTTPS,
                    allowed_methods=cloudfront.AllowedMethods.ALLOW_ALL,
                    cache_policy=cloudfront.CachePolicy.CACHING_DISABLED,
                    origin_request_policy=cloudfront.OriginRequestPolicy.ALL_VIEWER,
                ),
                web_acl_id=web_acl_arn
            )
            print(f"Cloudfront distribution {CLOUDFRONT_DISTRIBUTION_NAME} defined.")
        except Exception as e:
            raise Exception(f"Could not handle Cloudfront distribution due to: {e}") from e

        # --- Outputs ---
        CfnOutput(self, "CloudFrontDistributionURL",
                  value=cloudfront_distribution.domain_name)
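
# For reference, a minimal app entry point wiring these two stacks together
# might look like the sketch below (stack names and the app file are
# assumptions, not part of this module):
#
#   from aws_cdk import App, Environment
#   app = App()
#   main_stack = CdkStack(app, "RedactionStack",
#                         env=Environment(account=AWS_ACCOUNT_ID, region=AWS_REGION))
#   if USE_CLOUDFRONT == 'True':
#       CdkStackCloudfront(app, "RedactionCloudfrontStack",
#                          alb_arn=main_stack.params["alb_arn_output"],
#                          alb_sec_group_id=main_stack.params["alb_security_group_id"],
#                          alb_dns_name=main_stack.params["alb_dns_name"],
#                          env=Environment(account=AWS_ACCOUNT_ID, region="us-east-1"),
#                          cross_region_references=True)
#   app.synth()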