[feat] re-add samgeo with segment-anything-fast
Browse files- README.md +3 -0
 - dockerfiles/dockerfile-lambda-gdal-runner +7 -2
 - events/example_output.json +19 -0
 - requirements.txt +3 -0
 - requirements_dev.txt +1 -0
 - src/__init__.py +4 -0
 - src/app.py +66 -40
 - src/main.py +54 -144
 - src/prediction_api/predictor.py +5 -6
 - src/utilities/constants.py +3 -1
 
    	
        README.md
    CHANGED
    
    | 
         @@ -9,6 +9,9 @@ docker stop $(docker ps -a -q); docker rm $(docker ps -a -q) 
     | 
|
| 9 | 
         
             
            # build the base docker image with the docker aws repository tag
         
     | 
| 10 | 
         
             
            docker build . -f dockerfiles/dockerfile-lambda-gdal-runner --tag 686901913580.dkr.ecr.eu-west-1.amazonaws.com/lambda-gdal-runner
         
     | 
| 11 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 12 | 
         
             
            # build the final docker image
         
     | 
| 13 | 
         
             
            docker build . -f dockerfiles/dockerfile-lambda-samgeo-api --tag 686901913580.dkr.ecr.eu-west-1.amazonaws.com/lambda-samgeo-api
         
     | 
| 14 | 
         
             
            ```
         
     | 
| 
         | 
|
| 9 | 
         
             
            # build the base docker image with the docker aws repository tag
         
     | 
| 10 | 
         
             
            docker build . -f dockerfiles/dockerfile-lambda-gdal-runner --tag 686901913580.dkr.ecr.eu-west-1.amazonaws.com/lambda-gdal-runner
         
     | 
| 11 | 
         | 
| 12 | 
         
            +
            # OPTIONAL: to build the lambda-gdal-runner image on a x86 machine use the build arg `RIE="https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie"`:
         
     | 
| 13 | 
         
            +
            docker build . -f dockerfiles/dockerfile-lambda-gdal-runner --build-arg RIE="https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie" --tag 686901913580.dkr.ecr.eu-west-1.amazonaws.com/lambda-gdal-runner --progress=plain
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
             
            # build the final docker image
         
     | 
| 16 | 
         
             
            docker build . -f dockerfiles/dockerfile-lambda-samgeo-api --tag 686901913580.dkr.ecr.eu-west-1.amazonaws.com/lambda-samgeo-api
         
     | 
| 17 | 
         
             
            ```
         
     | 
    	
        dockerfiles/dockerfile-lambda-gdal-runner
    CHANGED
    
    | 
         @@ -5,12 +5,17 @@ ARG LAMBDA_TASK_ROOT="/var/task" 
     | 
|
| 5 | 
         
             
            ARG PYTHONPATH="${LAMBDA_TASK_ROOT}:${PYTHONPATH}:/usr/local/lib/python3/dist-packages"
         
     | 
| 6 | 
         
             
            ARG RIE="https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-arm64"
         
     | 
| 7 | 
         | 
| 
         | 
|
| 
         | 
|
| 8 | 
         
             
            # Set working directory to function root directory
         
     | 
| 9 | 
         
             
            WORKDIR ${LAMBDA_TASK_ROOT}
         
     | 
| 10 | 
         
            -
            COPY  
     | 
| 11 | 
         | 
| 12 | 
         
             
            RUN apt update && apt install -y curl python3-pip
         
     | 
| 13 | 
         
            -
            RUN python 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 14 | 
         | 
| 15 | 
         
             
            RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
         
     | 
| 16 | 
         
             
            RUN chmod +x /usr/local/bin/aws-lambda-rie
         
     | 
| 
         | 
|
| 5 | 
         
             
            ARG PYTHONPATH="${LAMBDA_TASK_ROOT}:${PYTHONPATH}:/usr/local/lib/python3/dist-packages"
         
     | 
| 6 | 
         
             
            ARG RIE="https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-arm64"
         
     | 
| 7 | 
         | 
| 8 | 
         
            +
            RUN echo "ENV RIE: $RIE ..."
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
             
            # Set working directory to function root directory
         
     | 
| 11 | 
         
             
            WORKDIR ${LAMBDA_TASK_ROOT}
         
     | 
| 12 | 
         
            +
            COPY requirements.txt ${LAMBDA_TASK_ROOT}/requirements.txt
         
     | 
| 13 | 
         | 
| 14 | 
         
             
            RUN apt update && apt install -y curl python3-pip
         
     | 
| 15 | 
         
            +
            RUN which python
         
     | 
| 16 | 
         
            +
            RUN python --version
         
     | 
| 17 | 
         
            +
            RUN python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
         
     | 
| 18 | 
         
            +
            RUN python -m pip install -r ${LAMBDA_TASK_ROOT}/requirements.txt --target ${LAMBDA_TASK_ROOT}
         
     | 
| 19 | 
         | 
| 20 | 
         
             
            RUN curl -Lo /usr/local/bin/aws-lambda-rie ${RIE}
         
     | 
| 21 | 
         
             
            RUN chmod +x /usr/local/bin/aws-lambda-rie
         
     | 
    	
        events/example_output.json
    ADDED
    
    | 
         @@ -0,0 +1,19 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "type": "FeatureCollection",
         
     | 
| 3 | 
         
            +
              "features": [
         
     | 
| 4 | 
         
            +
                {
         
     | 
| 5 | 
         
            +
                  "type": "Feature",
         
     | 
| 6 | 
         
            +
                  "geometry": {
         
     | 
| 7 | 
         
            +
                    "type": "Polygon",
         
     | 
| 8 | 
         
            +
                    "coordinates": [
         
     | 
| 9 | 
         
            +
                      [
         
     | 
| 10 | 
         
            +
                        [46.143, 9.361],
         
     | 
| 11 | 
         
            +
                        [46.151, 9.401],
         
     | 
| 12 | 
         
            +
                        [46.137, 9.353]
         
     | 
| 13 | 
         
            +
                      ]
         
     | 
| 14 | 
         
            +
                    ]
         
     | 
| 15 | 
         
            +
                  },
         
     | 
| 16 | 
         
            +
                  "properties": {"name": "Colico"}
         
     | 
| 17 | 
         
            +
                }
         
     | 
| 18 | 
         
            +
              ]
         
     | 
| 19 | 
         
            +
            }
         
     | 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -1,4 +1,7 @@ 
     | 
|
| 
         | 
|
| 1 | 
         
             
            aws-lambda-powertools
         
     | 
| 2 | 
         
             
            bson
         
     | 
| 
         | 
|
| 3 | 
         
             
            python-dotenv
         
     | 
| 
         | 
|
| 4 | 
         
             
            segment-geospatial
         
     | 
| 
         | 
|
| 1 | 
         
            +
            awslambdaric
         
     | 
| 2 | 
         
             
            aws-lambda-powertools
         
     | 
| 3 | 
         
             
            bson
         
     | 
| 4 | 
         
            +
            geojson-pydantic
         
     | 
| 5 | 
         
             
            python-dotenv
         
     | 
| 6 | 
         
            +
            segment-anything-fast
         
     | 
| 7 | 
         
             
            segment-geospatial
         
     | 
    	
        requirements_dev.txt
    CHANGED
    
    | 
         @@ -1,6 +1,7 @@ 
     | 
|
| 1 | 
         
             
            awslambdaric
         
     | 
| 2 | 
         
             
            aws_lambda_powertools
         
     | 
| 3 | 
         
             
            fastjsonschema
         
     | 
| 
         | 
|
| 4 | 
         
             
            jmespath
         
     | 
| 5 | 
         
             
            pydantic
         
     | 
| 6 | 
         
             
            requests
         
     | 
| 
         | 
|
| 1 | 
         
             
            awslambdaric
         
     | 
| 2 | 
         
             
            aws_lambda_powertools
         
     | 
| 3 | 
         
             
            fastjsonschema
         
     | 
| 4 | 
         
            +
            geojson-pydantic
         
     | 
| 5 | 
         
             
            jmespath
         
     | 
| 6 | 
         
             
            pydantic
         
     | 
| 7 | 
         
             
            requests
         
     | 
    	
        src/__init__.py
    CHANGED
    
    | 
         @@ -0,0 +1,4 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from aws_lambda_powertools import Logger
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            app_logger = Logger()
         
     | 
    	
        src/app.py
    CHANGED
    
    | 
         @@ -1,28 +1,42 @@ 
     | 
|
| 1 | 
         
             
            import json
         
     | 
| 2 | 
         
             
            import time
         
     | 
| 3 | 
         
             
            from http import HTTPStatus
         
     | 
| 4 | 
         
            -
            from  
     | 
| 
         | 
|
| 5 | 
         
             
            from aws_lambda_powertools.event_handler import content_types
         
     | 
| 6 | 
         
             
            from aws_lambda_powertools.utilities.typing import LambdaContext
         
     | 
| 
         | 
|
| 7 | 
         
             
            from pydantic import BaseModel, ValidationError
         
     | 
| 8 | 
         | 
| 9 | 
         
            -
            from src 
     | 
| 10 | 
         
            -
            from src. 
     | 
| 
         | 
|
| 11 | 
         
             
            from src.utilities.utilities import base64_decode
         
     | 
| 12 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 13 | 
         | 
| 14 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 15 | 
         | 
| 16 | 
         | 
| 17 | 
         
            -
            class  
     | 
| 18 | 
         
            -
                 
     | 
| 19 | 
         
            -
                 
     | 
| 20 | 
         
            -
                duration_run: float 
     | 
| 21 | 
         
            -
                message: str 
     | 
| 22 | 
         
            -
                request_id: str = ""
         
     | 
| 23 | 
         | 
| 24 | 
         | 
| 25 | 
         
            -
            def get_response(status: int, start_time: float, request_id: str,  
     | 
| 26 | 
         
             
                """
         
     | 
| 27 | 
         
             
                Return a response for frontend clients.
         
     | 
| 28 | 
         | 
| 
         @@ -30,68 +44,80 @@ def get_response(status: int, start_time: float, request_id: str, output: BBoxWi 
     | 
|
| 30 | 
         
             
                    status: status response
         
     | 
| 31 | 
         
             
                    start_time: request start time (float)
         
     | 
| 32 | 
         
             
                    request_id: str
         
     | 
| 33 | 
         
            -
                     
     | 
| 34 | 
         | 
| 35 | 
         
             
                Returns:
         
     | 
| 36 | 
         
            -
                     
     | 
| 37 | 
         | 
| 38 | 
         
             
                """
         
     | 
| 39 | 
         
            -
                duration_run = time.time() - start_time
         
     | 
| 40 | 
         
            -
                 
     | 
| 41 | 
         
            -
                 
     | 
| 42 | 
         
            -
             
     | 
| 43 | 
         
            -
                    output.message = CUSTOM_RESPONSE_MESSAGES[status]
         
     | 
| 44 | 
         
            -
                    output.request_id = request_id
         
     | 
| 45 | 
         
            -
                    body = output.model_dump_json()
         
     | 
| 46 | 
         
            -
                elif status == 200:
         
     | 
| 47 | 
         
            -
                    # should never be here...
         
     | 
| 48 | 
         
            -
                    raise KeyError("status 200, but missing BBoxWithPointInput argument.")
         
     | 
| 49 | 
         
             
                response = {
         
     | 
| 50 | 
         
             
                    "statusCode": status,
         
     | 
| 51 | 
         
            -
                    "header": {"Content-Type": content_types.APPLICATION_JSON 
     | 
| 52 | 
         
            -
                    "body":  
     | 
| 53 | 
         
             
                    "isBase64Encoded": False
         
     | 
| 54 | 
         
             
                }
         
     | 
| 55 | 
         
            -
                 
     | 
| 56 | 
         
             
                return json.dumps(response)
         
     | 
| 57 | 
         | 
| 58 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 59 | 
         
             
            def lambda_handler(event: dict, context: LambdaContext):
         
     | 
| 60 | 
         
            -
                 
     | 
| 61 | 
         
             
                start_time = time.time()
         
     | 
| 62 | 
         | 
| 63 | 
         
             
                if "version" in event:
         
     | 
| 64 | 
         
            -
                     
     | 
| 65 | 
         | 
| 66 | 
         
             
                try:
         
     | 
| 67 | 
         
            -
                     
     | 
| 68 | 
         
            -
                     
     | 
| 69 | 
         | 
| 70 | 
         
             
                    try:
         
     | 
| 71 | 
         
             
                        body = event["body"]
         
     | 
| 72 | 
         
             
                    except Exception as e_constants1:
         
     | 
| 73 | 
         
            -
                         
     | 
| 74 | 
         
             
                        body = event
         
     | 
| 75 | 
         | 
| 76 | 
         
            -
                     
     | 
| 77 | 
         | 
| 78 | 
         
             
                    if isinstance(body, str):
         
     | 
| 79 | 
         
             
                        body_decoded_str = base64_decode(body)
         
     | 
| 80 | 
         
            -
                         
     | 
| 81 | 
         
             
                        body = json.loads(body_decoded_str)
         
     | 
| 82 | 
         | 
| 83 | 
         
            -
                     
     | 
| 84 | 
         | 
| 85 | 
         
             
                    try:
         
     | 
| 86 | 
         
            -
                         
     | 
| 87 | 
         
            -
                         
     | 
| 88 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 89 | 
         
             
                    except ValidationError as ve:
         
     | 
| 90 | 
         
            -
                         
     | 
| 91 | 
         
             
                        response = get_response(HTTPStatus.UNPROCESSABLE_ENTITY.value, start_time, context.aws_request_id)
         
     | 
| 92 | 
         
             
                except Exception as e:
         
     | 
| 93 | 
         
            -
                     
     | 
| 94 | 
         
             
                    response = get_response(HTTPStatus.INTERNAL_SERVER_ERROR.value, start_time, context.aws_request_id)
         
     | 
| 95 | 
         | 
| 96 | 
         
            -
                 
     | 
| 97 | 
         
             
                return response
         
     | 
| 
         | 
|
| 1 | 
         
             
            import json
         
     | 
| 2 | 
         
             
            import time
         
     | 
| 3 | 
         
             
            from http import HTTPStatus
         
     | 
| 4 | 
         
            +
            from typing import Dict, List
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
             
            from aws_lambda_powertools.event_handler import content_types
         
     | 
| 7 | 
         
             
            from aws_lambda_powertools.utilities.typing import LambdaContext
         
     | 
| 8 | 
         
            +
            from geojson_pydantic import FeatureCollection, Feature, Polygon
         
     | 
| 9 | 
         
             
            from pydantic import BaseModel, ValidationError
         
     | 
| 10 | 
         | 
| 11 | 
         
            +
            from src import app_logger
         
     | 
| 12 | 
         
            +
            from src.prediction_api.predictor import base_predict
         
     | 
| 13 | 
         
            +
            from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, MODEL_NAME, ZOOM
         
     | 
| 14 | 
         
             
            from src.utilities.utilities import base64_decode
         
     | 
| 15 | 
         | 
| 16 | 
         
            +
            PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            class LatLngTupleLeaflet(BaseModel):
         
     | 
| 20 | 
         
            +
                lat: float
         
     | 
| 21 | 
         
            +
                lng: float
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         | 
| 24 | 
         
            +
            class RequestBody(BaseModel):
         
     | 
| 25 | 
         
            +
                ne: LatLngTupleLeaflet
         
     | 
| 26 | 
         
            +
                sw: LatLngTupleLeaflet
         
     | 
| 27 | 
         
            +
                points: List[LatLngTupleLeaflet]
         
     | 
| 28 | 
         
            +
                model: str = MODEL_NAME
         
     | 
| 29 | 
         
            +
                zoom: float = ZOOM
         
     | 
| 30 | 
         | 
| 31 | 
         | 
| 32 | 
         
            +
            class ResponseBody(BaseModel):
         
     | 
| 33 | 
         
            +
                geojson: Dict = None
         
     | 
| 34 | 
         
            +
                request_id: str
         
     | 
| 35 | 
         
            +
                duration_run: float
         
     | 
| 36 | 
         
            +
                message: str
         
     | 
| 
         | 
|
| 37 | 
         | 
| 38 | 
         | 
| 39 | 
         
            +
            def get_response(status: int, start_time: float, request_id: str, response_body: ResponseBody = None) -> str:
         
     | 
| 40 | 
         
             
                """
         
     | 
| 41 | 
         
             
                Return a response for frontend clients.
         
     | 
| 42 | 
         | 
| 
         | 
|
| 44 | 
         
             
                    status: status response
         
     | 
| 45 | 
         
             
                    start_time: request start time (float)
         
     | 
| 46 | 
         
             
                    request_id: str
         
     | 
| 47 | 
         
            +
                    response_body: dict we embed into our response
         
     | 
| 48 | 
         | 
| 49 | 
         
             
                Returns:
         
     | 
| 50 | 
         
            +
                    str: json response
         
     | 
| 51 | 
         | 
| 52 | 
         
             
                """
         
     | 
| 53 | 
         
            +
                response_body.duration_run = time.time() - start_time
         
     | 
| 54 | 
         
            +
                response_body.message = CUSTOM_RESPONSE_MESSAGES[status]
         
     | 
| 55 | 
         
            +
                response_body.request_id = request_id
         
     | 
| 56 | 
         
            +
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 57 | 
         
             
                response = {
         
     | 
| 58 | 
         
             
                    "statusCode": status,
         
     | 
| 59 | 
         
            +
                    "header": {"Content-Type": content_types.APPLICATION_JSON},
         
     | 
| 60 | 
         
            +
                    "body": response_body.model_dump_json(),
         
     | 
| 61 | 
         
             
                    "isBase64Encoded": False
         
     | 
| 62 | 
         
             
                }
         
     | 
| 63 | 
         
            +
                app_logger.info(f"response type:{type(response)} => {response}.")
         
     | 
| 64 | 
         
             
                return json.dumps(response)
         
     | 
| 65 | 
         | 
| 66 | 
         | 
| 67 | 
         
            +
            def get_parsed_bbox_points(request_input: RequestBody) -> Dict:
         
     | 
| 68 | 
         
            +
                return {
         
     | 
| 69 | 
         
            +
                    "bbox": [
         
     | 
| 70 | 
         
            +
                        request_input.ne.lat, request_input.sw.lat,
         
     | 
| 71 | 
         
            +
                        request_input.ne.lng, request_input.sw.lng
         
     | 
| 72 | 
         
            +
                    ],
         
     | 
| 73 | 
         
            +
                    "points": [[p.lat, p.lng] for p in request_input.points]
         
     | 
| 74 | 
         
            +
                }
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
             
            def lambda_handler(event: dict, context: LambdaContext):
         
     | 
| 78 | 
         
            +
                app_logger.info(f"start with aws_request_id:{context.aws_request_id}.")
         
     | 
| 79 | 
         
             
                start_time = time.time()
         
     | 
| 80 | 
         | 
| 81 | 
         
             
                if "version" in event:
         
     | 
| 82 | 
         
            +
                    app_logger.info(f"event version: {event['version']}.")
         
     | 
| 83 | 
         | 
| 84 | 
         
             
                try:
         
     | 
| 85 | 
         
            +
                    app_logger.info(f"event:{json.dumps(event)}...")
         
     | 
| 86 | 
         
            +
                    app_logger.info(f"context:{context}...")
         
     | 
| 87 | 
         | 
| 88 | 
         
             
                    try:
         
     | 
| 89 | 
         
             
                        body = event["body"]
         
     | 
| 90 | 
         
             
                    except Exception as e_constants1:
         
     | 
| 91 | 
         
            +
                        app_logger.error(f"e_constants1:{e_constants1}.")
         
     | 
| 92 | 
         
             
                        body = event
         
     | 
| 93 | 
         | 
| 94 | 
         
            +
                    app_logger.info(f"body: {type(body)}, {body}...")
         
     | 
| 95 | 
         | 
| 96 | 
         
             
                    if isinstance(body, str):
         
     | 
| 97 | 
         
             
                        body_decoded_str = base64_decode(body)
         
     | 
| 98 | 
         
            +
                        app_logger.info(f"body_decoded_str: {type(body_decoded_str)}, {body_decoded_str}...")
         
     | 
| 99 | 
         
             
                        body = json.loads(body_decoded_str)
         
     | 
| 100 | 
         | 
| 101 | 
         
            +
                    app_logger.info(f"body:{body}...")
         
     | 
| 102 | 
         | 
| 103 | 
         
             
                    try:
         
     | 
| 104 | 
         
            +
                        model_name = body["model"] if "model" in body else MODEL_NAME
         
     | 
| 105 | 
         
            +
                        zoom = body["zoom"] if "zoom" in body else ZOOM
         
     | 
| 106 | 
         
            +
                        body_request_validated = RequestBody(ne=body["ne"], sw=body["sw"], points=body["points"], model=model_name, zoom=zoom)
         
     | 
| 107 | 
         
            +
                        body_request = get_parsed_bbox_points(body_request_validated)
         
     | 
| 108 | 
         
            +
                        app_logger.info(f"validation ok - body_request:{body_request}, starting prediction...")
         
     | 
| 109 | 
         
            +
                        output_geojson_dict = base_predict(bbox=body_request["bbox"], model_name=body_request_validated["model"], zoom=body_request_validated["zoom"])
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                        # raise ValidationError in case this is not a valid geojson by GeoJSON specification rfc7946
         
     | 
| 112 | 
         
            +
                        PolygonFeatureCollectionModel(**output_geojson_dict)
         
     | 
| 113 | 
         
            +
                        body_response = ResponseBody(geojson=output_geojson_dict)
         
     | 
| 114 | 
         
            +
                        response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
         
     | 
| 115 | 
         
             
                    except ValidationError as ve:
         
     | 
| 116 | 
         
            +
                        app_logger.error(f"validation error:{ve}.")
         
     | 
| 117 | 
         
             
                        response = get_response(HTTPStatus.UNPROCESSABLE_ENTITY.value, start_time, context.aws_request_id)
         
     | 
| 118 | 
         
             
                except Exception as e:
         
     | 
| 119 | 
         
            +
                    app_logger.error(f"exception:{e}.")
         
     | 
| 120 | 
         
             
                    response = get_response(HTTPStatus.INTERNAL_SERVER_ERROR.value, start_time, context.aws_request_id)
         
     | 
| 121 | 
         | 
| 122 | 
         
            +
                app_logger.info(f"response_dumped:{response}...")
         
     | 
| 123 | 
         
             
                return response
         
     | 
    	
        src/main.py
    CHANGED
    
    | 
         @@ -1,147 +1,57 @@ 
     | 
|
| 1 | 
         
            -
            import  
     | 
| 2 | 
         
            -
            import uuid
         
     | 
| 3 | 
         | 
| 4 | 
         
            -
            from  
     | 
| 5 | 
         
            -
            from fastapi.exceptions import RequestValidationError
         
     | 
| 6 | 
         
            -
            from fastapi.responses import FileResponse, JSONResponse
         
     | 
| 7 | 
         
            -
            from fastapi.staticfiles import StaticFiles
         
     | 
| 8 | 
         
             
            from pydantic import BaseModel
         
     | 
| 9 | 
         | 
| 10 | 
         
            -
             
     | 
| 11 | 
         
            -
             
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
            -
             
     | 
| 14 | 
         
            -
             
     | 
| 15 | 
         
            -
             
     | 
| 16 | 
         
            -
             
     | 
| 17 | 
         
            -
             
     | 
| 18 | 
         
            -
             
     | 
| 19 | 
         
            -
             
     | 
| 20 | 
         
            -
             
     | 
| 21 | 
         
            -
             
     | 
| 22 | 
         
            -
             
     | 
| 23 | 
         
            -
             
     | 
| 24 | 
         
            -
             
     | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
                     
     | 
| 27 | 
         
            -
             
     | 
| 28 | 
         
            -
                         
     | 
| 29 | 
         
            -
             
     | 
| 30 | 
         
            -
             
     | 
| 31 | 
         
            -
             
     | 
| 32 | 
         
            -
             
     | 
| 33 | 
         
            -
                         
     | 
| 34 | 
         
            -
             
     | 
| 35 | 
         
            -
             
     | 
| 36 | 
         
            -
             
     | 
| 37 | 
         
            -
             
     | 
| 38 | 
         
            -
             
     | 
| 39 | 
         
            -
             
     | 
| 40 | 
         
            -
             
     | 
| 41 | 
         
            -
             
     | 
| 42 | 
         
            -
             
     | 
| 43 | 
         
            -
             
     | 
| 44 | 
         
            -
             
     | 
| 45 | 
         
            -
             
     | 
| 46 | 
         
            -
             
     | 
| 47 | 
         
            -
             
     | 
| 48 | 
         
            -
             
     | 
| 49 | 
         
            -
             
     | 
| 50 | 
         
            -
             
     | 
| 51 | 
         
            -
             
     | 
| 52 | 
         
            -
             
     | 
| 53 | 
         
            -
             
     | 
| 54 | 
         
            -
             
     | 
| 55 | 
         
            -
             
     | 
| 56 | 
         
            -
             
     | 
| 57 | 
         
            -
                 
     | 
| 58 | 
         
            -
             
     | 
| 59 | 
         
            -
             
     | 
| 60 | 
         
            -
             
     | 
| 61 | 
         
            -
             
     | 
| 62 | 
         
            -
                app_logger.info(f"start:{request_input}.")
         
     | 
| 63 | 
         
            -
                request_body = get_parsed_bbox_points(request_input)
         
     | 
| 64 | 
         
            -
                app_logger.info(f"request_body:{request_body}.")
         
     | 
| 65 | 
         
            -
                return JSONResponse(
         
     | 
| 66 | 
         
            -
                    status_code=200,
         
     | 
| 67 | 
         
            -
                    content=request_body
         
     | 
| 68 | 
         
            -
                )
         
     | 
| 69 | 
         
            -
             
     | 
| 70 | 
         
            -
             
     | 
| 71 | 
         
            -
            @app.get("/hello")
         
     | 
| 72 | 
         
            -
            async def hello() -> JSONResponse:
         
     | 
| 73 | 
         
            -
                app_logger.info(f"hello")
         
     | 
| 74 | 
         
            -
                return JSONResponse(status_code=200, content={"msg": "hello"})
         
     | 
| 75 | 
         
            -
             
     | 
| 76 | 
         
            -
             
     | 
| 77 | 
         
            -
            @app.post("/infer_samgeo")
         
     | 
| 78 | 
         
            -
            def samgeo(request_input: Input):
         
     | 
| 79 | 
         
            -
                import subprocess
         
     | 
| 80 | 
         
            -
             
     | 
| 81 | 
         
            -
                from src.prediction_api.predictor import base_predict
         
     | 
| 82 | 
         
            -
             
     | 
| 83 | 
         
            -
                app_logger.info("starting inference request...")
         
     | 
| 84 | 
         
            -
             
     | 
| 85 | 
         
            -
                try:
         
     | 
| 86 | 
         
            -
                    import time
         
     | 
| 87 | 
         
            -
             
     | 
| 88 | 
         
            -
                    time_start_run = time.time()
         
     | 
| 89 | 
         
            -
                    request_body = get_parsed_bbox_points(request_input)
         
     | 
| 90 | 
         
            -
                    app_logger.info(f"request_body:{request_body}.")
         
     | 
| 91 | 
         
            -
                    try:
         
     | 
| 92 | 
         
            -
                        output = base_predict(
         
     | 
| 93 | 
         
            -
                            bbox=request_body["bbox"],
         
     | 
| 94 | 
         
            -
                            point_coords=request_body["point"]
         
     | 
| 95 | 
         
            -
                        )
         
     | 
| 96 | 
         
            -
                        duration_run = time.time() - time_start_run
         
     | 
| 97 | 
         
            -
                        app_logger.info(f"duration_run:{duration_run}.")
         
     | 
| 98 | 
         
            -
                        body = {
         
     | 
| 99 | 
         
            -
                            "duration_run": duration_run,
         
     | 
| 100 | 
         
            -
                            "output": output
         
     | 
| 101 | 
         
            -
                        }
         
     | 
| 102 | 
         
            -
                        return JSONResponse(status_code=200, content={"body": json.dumps(body)})
         
     | 
| 103 | 
         
            -
                    except Exception as inference_exception:
         
     | 
| 104 | 
         
            -
                        home_content = subprocess.run("ls -l /home/user", shell=True, universal_newlines=True, stdout=subprocess.PIPE)
         
     | 
| 105 | 
         
            -
                        app_logger.error(f"/home/user ls -l: {home_content.stdout}.")
         
     | 
| 106 | 
         
            -
                        app_logger.error(f"inference error:{inference_exception}.")
         
     | 
| 107 | 
         
            -
                        return HTTPException(status_code=500, detail="Internal server error on inference")
         
     | 
| 108 | 
         
            -
                except Exception as generic_exception:
         
     | 
| 109 | 
         
            -
                    app_logger.error(f"generic error:{generic_exception}.")
         
     | 
| 110 | 
         
            -
                    return HTTPException(status_code=500, detail="Generic internal server error")
         
     | 
| 111 | 
         
            -
             
     | 
| 112 | 
         
            -
             
     | 
| 113 | 
         
            -
            @app.exception_handler(RequestValidationError)
         
     | 
| 114 | 
         
            -
            async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
         
     | 
| 115 | 
         
            -
                app_logger.error(f"exception errors: {exc.errors()}.")
         
     | 
| 116 | 
         
            -
                app_logger.error(f"exception body: {exc}.")
         
     | 
| 117 | 
         
            -
                headers = request.headers.items()
         
     | 
| 118 | 
         
            -
                app_logger.error(f"request header: {dict(headers)}.")
         
     | 
| 119 | 
         
            -
                params = request.query_params.items()
         
     | 
| 120 | 
         
            -
                app_logger.error(f'request query params: {dict(params)}.')
         
     | 
| 121 | 
         
            -
                return JSONResponse(
         
     | 
| 122 | 
         
            -
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
         
     | 
| 123 | 
         
            -
                    content={"msg": "Error - Unprocessable Entity"}
         
     | 
| 124 | 
         
            -
                )
         
     | 
| 125 | 
         
            -
             
     | 
| 126 | 
         
            -
             
     | 
| 127 | 
         
            -
            @app.exception_handler(HTTPException)
         
     | 
| 128 | 
         
            -
            async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
         
     | 
| 129 | 
         
            -
                app_logger.error(f"exception: {str(exc)}.")
         
     | 
| 130 | 
         
            -
                headers = request.headers.items()
         
     | 
| 131 | 
         
            -
                app_logger.error(f'request header: {dict(headers)}.' )
         
     | 
| 132 | 
         
            -
                params = request.query_params.items()
         
     | 
| 133 | 
         
            -
                app_logger.error(f'request query params: {dict(params)}.')
         
     | 
| 134 | 
         
            -
                return JSONResponse(
         
     | 
| 135 | 
         
            -
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
         
     | 
| 136 | 
         
            -
                    content={"msg": "Error - Internal Server Error"}
         
     | 
| 137 | 
         
            -
                )
         
     | 
| 138 | 
         
            -
             
     | 
| 139 | 
         
            -
             
     | 
| 140 | 
         
            -
            # important: the index() function and the app.mount MUST be at the end
         
     | 
| 141 | 
         
            -
            app.mount("/", StaticFiles(directory="static", html=True), name="static")
         
     | 
| 142 | 
         
            -
             
     | 
| 143 | 
         
            -
             
     | 
| 144 | 
         
            -
            @app.get("/")
         
     | 
| 145 | 
         
            -
            def index() -> FileResponse:
         
     | 
| 146 | 
         
            -
                return FileResponse(path="/app/static/index.html", media_type="text/html")
         
     | 
| 147 | 
         
            -
             
     | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Dict
         
     | 
| 
         | 
|
| 2 | 
         | 
| 3 | 
         
            +
            from geojson_pydantic import Feature, Polygon, FeatureCollection
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 4 | 
         
             
            from pydantic import BaseModel
         
     | 
| 5 | 
         | 
| 6 | 
         
            +
            g1 = {
         
     | 
| 7 | 
         
            +
                "type": "FeatureCollection",
         
     | 
| 8 | 
         
            +
                "features": [{
         
     | 
| 9 | 
         
            +
                    "type": "Feature",
         
     | 
| 10 | 
         
            +
                    "geometry": {
         
     | 
| 11 | 
         
            +
                        "type": "Polygon",
         
     | 
| 12 | 
         
            +
                        "coordinates": [
         
     | 
| 13 | 
         
            +
                            [
         
     | 
| 14 | 
         
            +
                                [13.1, 52.46385],
         
     | 
| 15 | 
         
            +
                                [13.42786, 52.6],
         
     | 
| 16 | 
         
            +
                                [13.2, 52.5],
         
     | 
| 17 | 
         
            +
                                [13.38272, 52.4],
         
     | 
| 18 | 
         
            +
                                [13.43, 52.46385],
         
     | 
| 19 | 
         
            +
                                [13.1, 52.46385]
         
     | 
| 20 | 
         
            +
                            ]
         
     | 
| 21 | 
         
            +
                        ],
         
     | 
| 22 | 
         
            +
                    },
         
     | 
| 23 | 
         
            +
                    "properties": {
         
     | 
| 24 | 
         
            +
                        "name": "uno",
         
     | 
| 25 | 
         
            +
                    },
         
     | 
| 26 | 
         
            +
                }, {
         
     | 
| 27 | 
         
            +
                    "type": "Feature",
         
     | 
| 28 | 
         
            +
                    "geometry": {
         
     | 
| 29 | 
         
            +
                        "type": "Polygon",
         
     | 
| 30 | 
         
            +
                        "coordinates": [
         
     | 
| 31 | 
         
            +
                            [
         
     | 
| 32 | 
         
            +
                                [13.77, 52.8],
         
     | 
| 33 | 
         
            +
                                [13.88, 52.77],
         
     | 
| 34 | 
         
            +
                                [13.99, 52.66],
         
     | 
| 35 | 
         
            +
                                [13.11, 52.55],
         
     | 
| 36 | 
         
            +
                                [13.33, 52.44],
         
     | 
| 37 | 
         
            +
                                [13.77, 52.8]
         
     | 
| 38 | 
         
            +
                            ]
         
     | 
| 39 | 
         
            +
                        ],
         
     | 
| 40 | 
         
            +
                    },
         
     | 
| 41 | 
         
            +
                    "properties": {
         
     | 
| 42 | 
         
            +
                        "name": "due",
         
     | 
| 43 | 
         
            +
                    },
         
     | 
| 44 | 
         
            +
                }]
         
     | 
| 45 | 
         
            +
            }
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
            PolygonFeatureCollectionModel = FeatureCollection[Feature[Polygon, Dict]]
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 50 | 
         
            +
                feat = PolygonFeatureCollectionModel(**g1)
         
     | 
| 51 | 
         
            +
                print(feat)
         
     | 
| 52 | 
         
            +
                print("feat")
         
     | 
| 53 | 
         
            +
                """
         
     | 
| 54 | 
         
            +
                
         
     | 
| 55 | 
         
            +
                point:  {"lat":12.425847783029134,"lng":53.887939453125} 
         
     | 
| 56 | 
         
            +
                map:ne:{"lat":17.895114303749143,"lng":58.27148437500001} sw:{"lat":0.6591651462894632,"lng":34.01367187500001}.
         
     | 
| 57 | 
         
            +
                """
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
    	
        src/prediction_api/predictor.py
    CHANGED
    
    | 
         @@ -2,15 +2,14 @@ 
     | 
|
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         | 
| 4 | 
         
             
            from src import app_logger
         
     | 
| 5 | 
         
            -
            from src.utilities.constants import ROOT
         
     | 
| 6 | 
         
             
            from src.utilities.type_hints import input_floatlist, input_floatlist2
         
     | 
| 7 | 
         | 
| 8 | 
         | 
| 9 | 
         
            -
            def base_predict(
         
     | 
| 10 | 
         
            -
                    bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
         
     | 
| 11 | 
         
            -
            ) -> str:
         
     | 
| 12 | 
         
             
                import tempfile
         
     | 
| 13 | 
         
            -
                from samgeo import  
     | 
| 
         | 
|
| 14 | 
         | 
| 15 | 
         
             
                with tempfile.NamedTemporaryFile(prefix="satellite_", suffix=".tif", dir=root_folder) as image_input_tmp:
         
     | 
| 16 | 
         
             
                    app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
         
     | 
| 
         @@ -32,7 +31,7 @@ def base_predict( 
     | 
|
| 32 | 
         | 
| 33 | 
         
             
                    with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
         
     | 
| 34 | 
         
             
                        app_logger.info(f"done set_image, start prediction using {image_output_tmp.name} as output...")
         
     | 
| 35 | 
         
            -
                        predictor. 
     | 
| 36 | 
         | 
| 37 | 
         
             
                        # geotiff to geojson
         
     | 
| 38 | 
         
             
                        with tempfile.NamedTemporaryFile(prefix="feats_", suffix=".geojson", dir=root_folder) as vector_tmp:
         
     | 
| 
         | 
|
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         | 
| 4 | 
         
             
            from src import app_logger
         
     | 
| 5 | 
         
            +
            from src.utilities.constants import ROOT, MODEL_NAME, ZOOM
         
     | 
| 6 | 
         
             
            from src.utilities.type_hints import input_floatlist, input_floatlist2
         
     | 
| 7 | 
         | 
| 8 | 
         | 
| 9 | 
         
            +
            def base_predict(bbox: input_floatlist, zoom: float = ZOOM, model_name: str = MODEL_NAME, root_folder: str = ROOT) -> dict:
         
     | 
| 
         | 
|
| 
         | 
|
| 10 | 
         
             
                import tempfile
         
     | 
| 11 | 
         
            +
                from samgeo import tms_to_geotiff
         
     | 
| 12 | 
         
            +
                from samgeo.fast_sam import SamGeo
         
     | 
| 13 | 
         | 
| 14 | 
         
             
                with tempfile.NamedTemporaryFile(prefix="satellite_", suffix=".tif", dir=root_folder) as image_input_tmp:
         
     | 
| 15 | 
         
             
                    app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
         
     | 
| 
         | 
|
| 31 | 
         | 
| 32 | 
         
             
                    with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
         
     | 
| 33 | 
         
             
                        app_logger.info(f"done set_image, start prediction using {image_output_tmp.name} as output...")
         
     | 
| 34 | 
         
            +
                        predictor.everything_prompt(output=image_output_tmp.name)
         
     | 
| 35 | 
         | 
| 36 | 
         
             
                        # geotiff to geojson
         
     | 
| 37 | 
         
             
                        with tempfile.NamedTemporaryFile(prefix="feats_", suffix=".geojson", dir=root_folder) as vector_tmp:
         
     | 
    	
        src/utilities/constants.py
    CHANGED
    
    | 
         @@ -21,4 +21,6 @@ CUSTOM_RESPONSE_MESSAGES = { 
     | 
|
| 21 | 
         
             
                200: "ok",
         
     | 
| 22 | 
         
             
                422: "Missing required parameter",
         
     | 
| 23 | 
         
             
                500: "Internal server error"
         
     | 
| 24 | 
         
            -
            }
         
     | 
| 
         | 
|
| 
         | 
| 
         | 
|
| 21 | 
         
             
                200: "ok",
         
     | 
| 22 | 
         
             
                422: "Missing required parameter",
         
     | 
| 23 | 
         
             
                500: "Internal server error"
         
     | 
| 24 | 
         
            +
            }
         
     | 
| 25 | 
         
            +
            MODEL_NAME = "FastSAM-s.pt"
         
     | 
| 26 | 
         
            +
            ZOOM = 13
         
     |