alessandro trinca tornidor
commited on
Commit
·
a5e4002
1
Parent(s):
00f8875
refactor: remove unuseful app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU), fix logs, initialize gpu within infer_lisa_gradio()
Browse files- .idea/vcs.xml +1 -0
- app.py +19 -16
- samgis_lisa_on_zero/io_package/wrappers_helpers.py +4 -2
- samgis_lisa_on_zero/prediction_api/lisa.py +15 -7
.idea/vcs.xml
CHANGED
|
@@ -2,5 +2,6 @@
|
|
| 2 |
<project version="4">
|
| 3 |
<component name="VcsDirectoryMappings">
|
| 4 |
<mapping directory="" vcs="Git" />
|
|
|
|
| 5 |
</component>
|
| 6 |
</project>
|
|
|
|
| 2 |
<project version="4">
|
| 3 |
<component name="VcsDirectoryMappings">
|
| 4 |
<mapping directory="" vcs="Git" />
|
| 5 |
+
<mapping directory="$PROJECT_DIR$/sam-quantized" vcs="Git" />
|
| 6 |
</component>
|
| 7 |
</project>
|
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import uuid
|
|
| 5 |
from typing import Callable, NoReturn
|
| 6 |
|
| 7 |
import gradio as gr
|
|
|
|
| 8 |
import uvicorn
|
| 9 |
from fastapi import FastAPI, HTTPException, Request, status
|
| 10 |
from fastapi.exceptions import RequestValidationError
|
|
@@ -13,8 +14,6 @@ from fastapi.staticfiles import StaticFiles
|
|
| 13 |
from fastapi.templating import Jinja2Templates
|
| 14 |
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists
|
| 15 |
from pydantic import ValidationError
|
| 16 |
-
from spaces import GPU as SPACES_GPU
|
| 17 |
-
|
| 18 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
| 19 |
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR
|
| 20 |
from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
|
|
@@ -31,6 +30,11 @@ FASTAPI_TITLE = "samgis-lisa-on-zero"
|
|
| 31 |
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def get_gradio_interface_geojson(
|
| 35 |
fn_inference: Callable
|
| 36 |
):
|
|
@@ -143,13 +147,16 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> JSONResponse
|
|
| 143 |
time_start_run = time.time()
|
| 144 |
body_request = get_parsed_bbox_points_with_string_prompt(request_input)
|
| 145 |
app_logger.info(f"lisa body_request:{body_request}.")
|
| 146 |
-
app_logger.info(f"lisa module:{lisa}.")
|
| 147 |
try:
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
output = lisa.lisa_predict(
|
| 151 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
| 152 |
-
source=
|
| 153 |
)
|
| 154 |
duration_run = time.time() - time_start_run
|
| 155 |
app_logger.info(f"duration_run:{duration_run}.")
|
|
@@ -157,9 +164,10 @@ def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> JSONResponse
|
|
| 157 |
"duration_run": duration_run,
|
| 158 |
"output": output
|
| 159 |
}
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
| 163 |
except Exception as inference_exception:
|
| 164 |
handle_exception_response(inference_exception)
|
| 165 |
except ValidationError as va1:
|
|
@@ -187,7 +195,7 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
|
|
| 187 |
body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
|
| 188 |
app_logger.info(f"body_request:{body_request}.")
|
| 189 |
try:
|
| 190 |
-
source_name =
|
| 191 |
app_logger.info(f"source_name = {source_name}.")
|
| 192 |
output = predictors.samexporter_predict(
|
| 193 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
|
@@ -296,7 +304,7 @@ async def lisa() -> FileResponse:
|
|
| 296 |
return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
|
| 297 |
|
| 298 |
|
| 299 |
-
#
|
| 300 |
app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
|
| 301 |
|
| 302 |
|
|
@@ -305,12 +313,7 @@ async def index() -> FileResponse:
|
|
| 305 |
return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
|
| 306 |
|
| 307 |
|
| 308 |
-
args = app_helpers.parse_args([])
|
| 309 |
-
app_helpers.app_logger.info(f"prepared default arguments:{args}.")
|
| 310 |
-
inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU)
|
| 311 |
-
|
| 312 |
app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
|
| 313 |
-
# io_package = app_helpers.get_gradio_interface(inference_fn)
|
| 314 |
io = get_gradio_interface_geojson(infer_lisa_gradio)
|
| 315 |
app_helpers.app_logger.info(
|
| 316 |
f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
|
|
|
|
| 5 |
from typing import Callable, NoReturn
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
+
import spaces
|
| 9 |
import uvicorn
|
| 10 |
from fastapi import FastAPI, HTTPException, Request, status
|
| 11 |
from fastapi.exceptions import RequestValidationError
|
|
|
|
| 14 |
from fastapi.templating import Jinja2Templates
|
| 15 |
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists
|
| 16 |
from pydantic import ValidationError
|
|
|
|
|
|
|
| 17 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
| 18 |
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR
|
| 19 |
from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
|
|
|
|
| 30 |
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
|
| 31 |
|
| 32 |
|
| 33 |
+
@spaces.GPU
|
| 34 |
+
def gpu_initialization() -> None:
|
| 35 |
+
app_logger.info("GPU initialization...")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
def get_gradio_interface_geojson(
|
| 39 |
fn_inference: Callable
|
| 40 |
):
|
|
|
|
| 147 |
time_start_run = time.time()
|
| 148 |
body_request = get_parsed_bbox_points_with_string_prompt(request_input)
|
| 149 |
app_logger.info(f"lisa body_request:{body_request}.")
|
|
|
|
| 150 |
try:
|
| 151 |
+
source = body_request["source"]
|
| 152 |
+
source_name = body_request["source_name"]
|
| 153 |
+
app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
|
| 154 |
+
app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
|
| 155 |
+
app_logger.debug(f"lisa module:{lisa}.")
|
| 156 |
+
gpu_initialization()
|
| 157 |
output = lisa.lisa_predict(
|
| 158 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
| 159 |
+
source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
|
| 160 |
)
|
| 161 |
duration_run = time.time() - time_start_run
|
| 162 |
app_logger.info(f"duration_run:{duration_run}.")
|
|
|
|
| 164 |
"duration_run": duration_run,
|
| 165 |
"output": output
|
| 166 |
}
|
| 167 |
+
dumped = json.dumps(body)
|
| 168 |
+
app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
|
| 169 |
+
app_logger.debug(f"complete json.dumps(body):{dumped}.")
|
| 170 |
+
return dumped
|
| 171 |
except Exception as inference_exception:
|
| 172 |
handle_exception_response(inference_exception)
|
| 173 |
except ValidationError as va1:
|
|
|
|
| 195 |
body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
|
| 196 |
app_logger.info(f"body_request:{body_request}.")
|
| 197 |
try:
|
| 198 |
+
source_name = body_request["source_name"]
|
| 199 |
app_logger.info(f"source_name = {source_name}.")
|
| 200 |
output = predictors.samexporter_predict(
|
| 201 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
|
|
|
| 304 |
return FileResponse(path=static_dist_folder / "lisa.html", media_type="text/html")
|
| 305 |
|
| 306 |
|
| 307 |
+
# index.html (lisa.html copy)
|
| 308 |
app.mount(VITE_INDEX_URL, StaticFiles(directory=static_dist_folder, html=True), name="index")
|
| 309 |
|
| 310 |
|
|
|
|
| 313 |
return FileResponse(path=static_dist_folder / "index.html", media_type="text/html")
|
| 314 |
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
|
|
|
|
| 317 |
io = get_gradio_interface_geojson(infer_lisa_gradio)
|
| 318 |
app_helpers.app_logger.info(
|
| 319 |
f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
|
samgis_lisa_on_zero/io_package/wrappers_helpers.py
CHANGED
|
@@ -83,7 +83,8 @@ def get_parsed_bbox_points_with_string_prompt(request_input: StringPromptApiRequ
|
|
| 83 |
"bbox": [ne_latlng, sw_latlng],
|
| 84 |
"prompt": cleaned_prompt,
|
| 85 |
"zoom": new_zoom,
|
| 86 |
-
"source": get_url_tile(request_input.source_type)
|
|
|
|
| 87 |
}
|
| 88 |
|
| 89 |
|
|
@@ -119,7 +120,8 @@ def get_parsed_bbox_points_with_dictlist_prompt(request_input: ApiRequestBody) -
|
|
| 119 |
"bbox": [ne_latlng, sw_latlng],
|
| 120 |
"prompt": new_prompt_list,
|
| 121 |
"zoom": new_zoom,
|
| 122 |
-
"source": get_url_tile(request_input.source_type)
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
|
|
|
|
| 83 |
"bbox": [ne_latlng, sw_latlng],
|
| 84 |
"prompt": cleaned_prompt,
|
| 85 |
"zoom": new_zoom,
|
| 86 |
+
"source": get_url_tile(request_input.source_type),
|
| 87 |
+
"source_name": get_source_name(request_input.source_type)
|
| 88 |
}
|
| 89 |
|
| 90 |
|
|
|
|
| 120 |
"bbox": [ne_latlng, sw_latlng],
|
| 121 |
"prompt": new_prompt_list,
|
| 122 |
"zoom": new_zoom,
|
| 123 |
+
"source": get_url_tile(request_input.source_type),
|
| 124 |
+
"source_name": get_source_name(request_input.source_type)
|
| 125 |
}
|
| 126 |
|
| 127 |
|
samgis_lisa_on_zero/prediction_api/lisa.py
CHANGED
|
@@ -16,7 +16,9 @@ def load_model_and_inference_fn(inference_function_name_key: str):
|
|
| 16 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
| 17 |
|
| 18 |
if models_dict[inference_function_name_key]["inference"] is None:
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
parsed_args = app_helpers.parse_args([])
|
| 21 |
inference_fn = app_helpers.get_inference_model_by_args(
|
| 22 |
parsed_args,
|
|
@@ -57,10 +59,17 @@ def lisa_predict(
|
|
| 57 |
from samgis_lisa_on_zero import app_logger
|
| 58 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
| 59 |
|
|
|
|
|
|
|
|
|
|
| 60 |
app_logger.info("start lisa inference...")
|
|
|
|
|
|
|
|
|
|
| 61 |
load_model_and_inference_fn(inference_function_name_key)
|
| 62 |
-
app_logger.debug(f"using a {inference_function_name_key} instance model...")
|
| 63 |
inference_fn = models_dict[inference_function_name_key]["inference"]
|
|
|
|
| 64 |
|
| 65 |
pt0, pt1 = bbox
|
| 66 |
app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
|
|
@@ -80,15 +89,14 @@ def lisa_predict(
|
|
| 80 |
app_logger.info("keep all temp data in memory...")
|
| 81 |
|
| 82 |
app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
|
| 83 |
-
app_logger.info(f"lisa_zero, prompt
|
| 84 |
app_logger.info(f"lisa_zero, prompt:{prompt}.")
|
| 85 |
prompt_str = str(prompt)
|
| 86 |
-
app_logger.info(f"lisa_zero, img
|
| 87 |
embedding_key = f"{source_name}_z{zoom}_{prefix}"
|
| 88 |
_, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
|
| 89 |
-
app_logger.info(f"lisa_zero, output_string
|
| 90 |
-
app_logger.info(f"lisa_zero,
|
| 91 |
-
app_logger.info(f"lisa_zero, mask_output tpye:{type(mask)}.")
|
| 92 |
app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
|
| 93 |
return {
|
| 94 |
"output_string": output_string,
|
|
|
|
| 16 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
| 17 |
|
| 18 |
if models_dict[inference_function_name_key]["inference"] is None:
|
| 19 |
+
msg = f"missing inference function {inference_function_name_key}, "
|
| 20 |
+
msg += f"instantiating it now using inference_decorator {SPACES_GPU}!"
|
| 21 |
+
app_logger.info(msg)
|
| 22 |
parsed_args = app_helpers.parse_args([])
|
| 23 |
inference_fn = app_helpers.get_inference_model_by_args(
|
| 24 |
parsed_args,
|
|
|
|
| 59 |
from samgis_lisa_on_zero import app_logger
|
| 60 |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict
|
| 61 |
|
| 62 |
+
if source_name is None:
|
| 63 |
+
source_name = str(source)
|
| 64 |
+
|
| 65 |
app_logger.info("start lisa inference...")
|
| 66 |
+
app_logger.debug(f"type(source):{type(source)}, source:{source},")
|
| 67 |
+
app_logger.debug(f"type(source_name):{type(source_name)}, source_name:{source_name}.")
|
| 68 |
+
|
| 69 |
load_model_and_inference_fn(inference_function_name_key)
|
| 70 |
+
app_logger.debug(f"using a '{inference_function_name_key}' instance model...")
|
| 71 |
inference_fn = models_dict[inference_function_name_key]["inference"]
|
| 72 |
+
app_logger.info(f"loaded inference function '{inference_fn.__name__}'.")
|
| 73 |
|
| 74 |
pt0, pt1 = bbox
|
| 75 |
app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
|
|
|
|
| 89 |
app_logger.info("keep all temp data in memory...")
|
| 90 |
|
| 91 |
app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
|
| 92 |
+
app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
|
| 93 |
app_logger.info(f"lisa_zero, prompt:{prompt}.")
|
| 94 |
prompt_str = str(prompt)
|
| 95 |
+
app_logger.info(f"lisa_zero, img type:{type(img)}.")
|
| 96 |
embedding_key = f"{source_name}_z{zoom}_{prefix}"
|
| 97 |
_, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
|
| 98 |
+
app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
|
| 99 |
+
app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
|
|
|
|
| 100 |
app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
|
| 101 |
return {
|
| 102 |
"output_string": output_string,
|