File size: 9,394 Bytes
87073e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab78607
87073e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ee2b25
87073e6
 
2ee2b25
 
87073e6
2ee2b25
 
 
 
 
 
 
 
 
 
 
 
 
 
87073e6
 
 
368fc0b
87073e6
 
 
 
 
 
 
 
 
368fc0b
87073e6
 
368fc0b
87073e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab78607
87073e6
 
 
 
 
 
ab78607
87073e6
 
ab78607
87073e6
 
 
 
 
 
 
ab78607
87073e6
 
ab78607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87073e6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import json
import os
from pathlib import Path

import structlog.stdlib
import uvicorn
from asgi_correlation_id import CorrelationIdMiddleware
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError
from samgis_core.utilities import create_folders_if_not_exists
from samgis_web.utilities import frontend_builder
from samgis_core.utilities.session_logger import setup_logging
from samgis_web.prediction_api.predictors import samexporter_predict
from samgis_web.utilities.type_hints import ApiRequestBody
from starlette.responses import JSONResponse


# Load environment variables (optionally from a .env file) and derive the app configuration.
load_dotenv()
project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
workdir = os.getenv("WORKDIR", project_root_folder)
# project_root_folder is already a Path: no need to re-wrap the joined path.
model_folder = project_root_folder / "machine_learning_models"

log_level = os.getenv("LOG_LEVEL", "INFO")
setup_logging(log_level=log_level)
app_logger = structlog.stdlib.get_logger()
app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.")

folders_map = os.getenv("FOLDERS_MAP", "{}")
markdown_text = os.getenv("MARKDOWN_TEXT", "")
# NOTE(review): an unset EXAMPLES_TEXT_LIST yields [""] (one empty string), not [] — confirm the gradio helper
# tolerates that.
examples_text_list = os.getenv("EXAMPLES_TEXT_LIST", "").split("\n")
example_body = json.loads(os.getenv("EXAMPLE_BODY", "{}"))
# bool() of any non-empty string is True, so "false"/"0" used to *enable* gradio. Treat the usual negative
# spellings (and empty/whitespace-only) as False; any other non-empty value keeps enabling it as before.
mount_gradio_app = os.getenv("MOUNT_GRADIO_APP", "").strip().lower() not in {"", "0", "false", "no", "off"}

static_dist_folder = project_root_folder / "static" / "dist"
input_css_path = os.getenv("INPUT_CSS_PATH", "src/input.css")
vite_gradio_url = os.getenv("VITE_GRADIO_URL", "/gradio")
vite_index_url = os.getenv("VITE_INDEX_URL", "/")
vite_samgis_url = os.getenv("VITE_SAMGIS_URL", "/samgis")
fastapi_title = "samgis"
app = FastAPI(title=fastapi_title, version="1.0")


@app.middleware("http")
async def request_middleware(request, call_next):
    """Wrap every HTTP request with samgis_web's logging middleware.

    The middleware module is imported lazily, at request time, as in the rest of this file.
    """
    from samgis_web.web import middlewares

    response = await middlewares.logging_middleware(request, call_next)
    return response


@app.get("/health")
async def health() -> JSONResponse:
    """Report liveness after checking that the ONNX model files are in place.

    Checks that ``model_folder`` is a directory and that both the encoder and the
    decoder model files exist inside it. On success, logs the onnxruntime/samgis
    versions and the model paths.

    Returns:
        JSONResponse: status 200 with ``{"msg": "still alive..."}``.

    Raises:
        HTTPException: status 500 whose detail names the failing check, either the
            missing folder or the specific missing model file.
    """
    from onnxruntime import __version__ as ort_version
    from samgis_web.__version__ import __version__ as version_web
    from samgis_core.__version__ import __version__ as version_core
    from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME

    encoder_model_path = model_folder / MODEL_ENCODER_NAME
    decoder_model_path = model_folder / MODEL_DECODER_NAME

    # Explicit checks instead of `assert`: asserts disappear under `python -O`,
    # which would turn this health check into an unconditional success.
    error_msg = None
    if not model_folder.is_dir():
        error_msg = f"health_check: model_folder:'{model_folder}' is not a directory."
    else:
        for model_path in (encoder_model_path, decoder_model_path):
            if not model_path.is_file():
                # Report the actual missing file, not the containing folder.
                error_msg = f"health_check: model_file:'{model_path}' not found."
                break
    if error_msg is not None:
        app_logger.error(error_msg)
        raise HTTPException(500, detail=error_msg)

    app_logger.info(f"still alive, version_onnxruntime:{ort_version}, version_web:{version_web}, version_core:{version_core}.")
    app_logger.info(f"still alive, encoder_model:{encoder_model_path}, decoder_model:{decoder_model_path}.")
    return JSONResponse(status_code=200, content={"msg": "still alive..."})


def infer_samgis_fn(request_input: ApiRequestBody | str) -> str | JSONResponse:
    """Run a SAM inference for a request and return the result serialized as JSON.

    Args:
        request_input: the API request body, either as a parsed ``ApiRequestBody`` or as a
            string. Both forms are handed to ``get_parsed_bbox_points_with_dictlist_prompt``.

    Returns:
        str: ``json.dumps({"duration_run": <seconds>, "output": <prediction output>})``.

    Raises:
        RequestValidationError: when parsing the request raises a pydantic ``ValidationError``.
        HTTPException: status 500 when the prediction or the JSON serialization fails.
    """
    import time

    from samgis_web.web.web_helpers import get_parsed_bbox_points_with_dictlist_prompt

    app_logger.info("starting inference request...")
    time_start_run = time.time()

    # Phase 1: parse/validate the request. Only validation problems map to a 422-style error.
    try:
        body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        app_logger.error(f"ValidationError, request_input:{request_input}.")
        raise RequestValidationError("Unprocessable Entity") from va1
    app_logger.info(f"body_request:{body_request}.")

    # Phase 2: predict and serialize. Any failure here is reported as a generic 500,
    # chained to the original exception so the cause stays in the traceback.
    try:
        app_logger.info(f"source_name = {body_request['source_name']}.")
        output = samexporter_predict(
            bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
            source=body_request["source"], source_name=body_request['source_name'], model_folder=model_folder
        )
        duration_run = time.time() - time_start_run
        app_logger.info(f"duration_run:{duration_run}.")
        body = {
            "duration_run": duration_run,
            "output": output
        }
        dumped = json.dumps(body)
        app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
        app_logger.debug(f"complete json.dumps(body):{dumped}.")
        return dumped
    except Exception as inference_exception:
        app_logger.error(f"inference_exception:{inference_exception}.")
        app_logger.error(f"inference_exception, request_input:{request_input}.")
        raise HTTPException(status_code=500, detail="Internal Server Error") from inference_exception


@app.post("/infer_samgis")
def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
    """HTTP endpoint: run ``infer_samgis_fn`` and wrap its JSON string under the ``body`` key."""
    serialized_body = infer_samgis_fn(request_input=request_input)
    body_type, body_length = type(serialized_body), len(serialized_body)
    app_logger.info(f"json.dumps(body) type:{body_type}, len:{body_length}.")
    app_logger.debug(f"complete json.dumps(body):{serialized_body}.")
    payload = {"body": serialized_body}
    return JSONResponse(status_code=200, content=payload)


@app.exception_handler(RequestValidationError)
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Delegate RequestValidationError handling to samgis_web's shared handler (lazily imported)."""
    from samgis_web.web.exception_handlers import request_validation_exception_handler as delegate_handler

    return delegate_handler(request, exc)


@app.exception_handler(HTTPException)
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Delegate HTTPException handling to samgis_web's shared handler (lazily imported)."""
    from samgis_web.web.exception_handlers import http_exception_handler as delegate_handler

    return delegate_handler(request, exc)


# Create the folders described by FOLDERS_MAP. If WRITE_TMP_ON_DISK names a folder, expose its
# content as static files under /vis_output, plus an HTML listing page at the same path.
create_folders_if_not_exists.folders_creation(folders_map)
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
if write_tmp_on_disk:
    try:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not Path(write_tmp_on_disk).is_dir():
            raise NotADirectoryError(f"WRITE_TMP_ON_DISK:'{write_tmp_on_disk}' is not an existing directory")
        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
        templates = Jinja2Templates(directory=str(project_root_folder / "static"))


        @app.get("/vis_output", response_class=HTMLResponse)
        def list_files(request: Request):
            """Render ``list_files.html`` with a sorted list of URLs for the files in write_tmp_on_disk."""
            # str(request.url) is the public equivalent of the private `request.url._url`.
            files_paths = sorted(f"{request.url}/{f}" for f in os.listdir(write_tmp_on_disk))
            # Log instead of a leftover debug print().
            app_logger.debug(f"vis_output files_paths:{files_paths}.")
            return templates.TemplateResponse(
                "list_files.html", {"request": request, "files": files_paths}
            )
    except (NotADirectoryError, RuntimeError) as rerr:
        app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
        raise

# Build the frontend assets into static_dist_folder, then serve that folder (html=True) under
# three url prefixes.
frontend_builder.build_frontend(
    project_root_folder=workdir,
    input_css_path=input_css_path,
    output_dist_folder=static_dist_folder,
)
app_logger.info("build_frontend ok!")

# "/static" is possibly needed for the tailwindcss output.css.
# NOTE(review): with the default VITE_INDEX_URL="/" the index mount is registered before the gradio
# mounts and the index route further down; confirm those later routes are still reachable.
for static_mount_path, static_mount_name in (
    ("/static", "static"),
    (vite_index_url, "index"),
    (vite_gradio_url, "gradio"),
):
    app.mount(static_mount_path, StaticFiles(directory=static_dist_folder, html=True), name=static_mount_name)
del static_mount_path, static_mount_name


@app.get(vite_index_url)
async def index() -> FileResponse:
    """Serve the built frontend entry page (``index.html`` from the static dist folder)."""
    index_html_path = static_dist_folder / "index.html"
    return FileResponse(index_html_path, media_type="text/html")


# Optionally build the gradio interface around infer_samgis_fn and mount it inside the FastAPI app.
app_logger.info(f"Mounted index on url path {vite_index_url} .")
app_logger.info(f"There is need to create and mount gradio app interface? {mount_gradio_app}...")
if mount_gradio_app:
    # Keep the try narrow: only the optional imports are expected to fail here.
    try:
        import gradio as gr
        from samgis_web.web.gradio_helpers import get_gradio_interface_geojson
    except ImportError as import_error:  # ModuleNotFoundError is a subclass of ImportError
        app_logger.error("cannot import gradio, have you installed it if you want to mount a gradio app?")
        app_logger.error(import_error)
        raise

    app_logger.info("creating gradio interface...")
    gr_interface = get_gradio_interface_geojson(
        infer_samgis_fn,
        markdown_text,
        examples_text_list,
        example_body
    )
    app_logger.info(f"gradio interface created, mounting gradio app on url path {vite_gradio_url} within FastAPI.")
    app_logger.debug(f"gr_interface vars:{vars(gr_interface)}.")
    app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url)
    # Also keep the legacy "/gradio" path reachable, but only when VITE_GRADIO_URL points elsewhere:
    # previously the same interface was mounted twice on "/gradio" with the default configuration.
    if vite_gradio_url.rstrip("/") != "/gradio":
        app = gr.mount_gradio_app(app, gr_interface, path="/gradio")
    app_logger.info(f"mounted gradio app within fastapi, url path {vite_gradio_url} .")


# Register CorrelationIdMiddleware only AFTER the @app.middleware("http") function above, so that the
# request id is already set when logging_middleware runs and log records do not miss it.
# NOTE(review): this relies on Starlette wrapping later-added middleware around earlier ones — confirm
# against the Starlette version in use.
app.add_middleware(CorrelationIdMiddleware)


if __name__ == '__main__':
    try:
        # Pass the app object rather than the "app:app" import string. With the string, uvicorn
        # imports this file a second time under the module name "app", re-running all module-level
        # setup (folder creation, frontend build, gradio mount). The string form also fails outright
        # when this file is not named app.py.
        uvicorn.run(app, host="0.0.0.0", port=7860)
    except Exception as ex:
        app_logger.error(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
        print(f"fastapi/gradio application {fastapi_title}, exception:{ex}.")
        raise