import boto3
import json
import logging

# Module-wide logger: this is the root logger, with its threshold set so that
# records at INFO level and above are processed.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

def _find_endpoint_in_subnets(endpoints, subnet_ids):
    """Return the first endpoint attached to any of ``subnet_ids``, else None.

    ``endpoints`` is the ``VpcEndpoints`` list from a
    ``describe_vpc_endpoints`` response.
    """
    wanted = set(subnet_ids)
    for endpoint in endpoints:
        if wanted.intersection(endpoint.get('SubnetIds', ())):
            return endpoint
    return None


def _allow_queue_in_policy(policy, queue_arn):
    """Make the first ``Allow`` statement of ``policy`` cover ``queue_arn``.

    Only the first ``Allow`` statement that carries a ``Resource`` element is
    inspected. ``policy`` (a parsed policy document) is mutated in place.

    Returns:
        True if ``queue_arn`` was added and the policy must be written back,
        False if the statement already allowed it (explicitly or via ``'*'``)
        or no suitable statement exists.
    """
    statements = policy['Statement']
    if isinstance(statements, dict):
        # A policy may carry a single statement object instead of a list.
        statements = [statements]

    for statement in statements:
        if statement['Effect'] != 'Allow' or 'Resource' not in statement:
            continue

        resources = statement['Resource']
        logger.info(resources)
        if isinstance(resources, str):
            # "Resource" may be one string; normalise so the membership tests
            # below compare whole ARNs rather than substrings.
            resources = [resources]

        if '*' in resources:
            logger.info("'*' already allowed")
            return False
        if queue_arn in resources:
            logger.info("'%s' already allowed", queue_arn)
            return False

        statement['Resource'] = list(resources) + [queue_arn]
        return True

    return False


def _create_interface_endpoint(client, vpc_id, service_name, endpoint_name,
                               subnet_ids, security_group_ids):
    """Create a Name-tagged Interface VPC endpoint and return its ID."""
    logger.info("creating endpoint to %s", service_name)
    response = client.create_vpc_endpoint(
        VpcEndpointType='Interface',
        VpcId=vpc_id,
        ServiceName=service_name,
        SubnetIds=subnet_ids,
        SecurityGroupIds=security_group_ids,
        TagSpecifications=[
            {
                "ResourceType": "vpc-endpoint",
                "Tags": [{"Key": "Name", "Value": endpoint_name}],
            },
        ],
    )
    endpoint_id = response['VpcEndpoint']['VpcEndpointId']
    logger.info("created VPCE: %s", endpoint_id)
    return endpoint_id


def lambda_handler(event, context):
    """Create the VPC endpoints for an environment whose status is PENDING.

    When ``event['detail']['status']`` is ``"PENDING"`` the handler:

    1. Resolves the VPC from the first subnet returned for
       ``detail['networkConfiguration']['subnetIds']``.
    2. Looks for an existing SQS VPC endpoint in that VPC that shares a subnet
       with the environment and, if found, makes sure its policy allows
       ``detail['celeryExecutorQueue']``.
    3. Creates an Interface endpoint to ``detail['databaseVpcEndpointService']``
       named ``<name>-database``.
    4. If ``detail['webserverAccessMode']`` is ``"PRIVATE_ONLY"``, creates an
       Interface endpoint to ``detail['webserverVpcEndpointService']`` named
       ``<name>-webserver``.

    Any other status is a no-op.

    Args:
        event: Event dict with a ``detail`` mapping holding the keys above.
        context: Lambda context object (unused).

    Returns:
        ``{'statusCode': 200, 'body': <JSON-encoded status>}``.

    Raises:
        KeyError: If a required key is missing from ``event``.
        ValueError: If the subnet list in the event is empty.
        Errors raised by the EC2 client calls propagate unchanged.
    """
    detail = event['detail']
    status = detail['status']
    result = {
        'statusCode': 200,
        'body': json.dumps(status)
    }
    if status != "PENDING":
        return result

    name = detail['name']
    queue_arn = detail['celeryExecutorQueue']
    subnet_ids = detail['networkConfiguration']['subnetIds']
    security_group_ids = detail['networkConfiguration']['securityGroupIds']
    database_service = detail['databaseVpcEndpointService']

    # The web server endpoint is only needed for private web server access.
    webserver_service = None
    if detail['webserverAccessMode'] == "PRIVATE_ONLY":
        webserver_service = detail['webserverVpcEndpointService']

    if not subnet_ids:
        # An empty SubnetIds list would make describe_subnets return every
        # subnet, so the VPC picked below would be arbitrary.
        raise ValueError("networkConfiguration.subnetIds must not be empty")

    # MWAA does not need to store the VPC ID, but we can get it from the subnets
    client = boto3.client('ec2')
    subnets = client.describe_subnets(SubnetIds=subnet_ids)
    vpc_id = subnets['Subnets'][0]['VpcId']
    logger.info("vpcId: %s", vpc_id)

    endpoints = client.describe_vpc_endpoints(
        VpcEndpointIds=[],
        Filters=[
            {"Name": "vpc-id", "Values": [vpc_id]},
            {"Name": "service-name", "Values": ["*.sqs"]},
        ],
        MaxResults=1000
    )
    # The describe call is filtered by service name, so any match must be SQS.
    sqs_endpoint = _find_endpoint_in_subnets(
        endpoints['VpcEndpoints'], subnet_ids
    )

    if sqs_endpoint:
        logger.info("Found SQS endpoint: %s", sqs_endpoint['VpcEndpointId'])
        logger.info(sqs_endpoint)

        policy = json.loads(sqs_endpoint['PolicyDocument'])
        if _allow_queue_in_policy(policy, queue_arn):
            logger.info("Updating SQS policy to %s", policy)
            client.modify_vpc_endpoint(
                VpcEndpointId=sqs_endpoint['VpcEndpointId'],
                PolicyDocument=json.dumps(policy)
            )

    # create MWAA database endpoint
    _create_interface_endpoint(
        client, vpc_id, database_service, name + "-database",
        subnet_ids, security_group_ids,
    )

    # create MWAA web server endpoint (if private)
    if webserver_service:
        _create_interface_endpoint(
            client, vpc_id, webserver_service, name + "-webserver",
            subnet_ids, security_group_ids,
        )

    return result