[feat] take advantage of re-usable image embeddings in SAM model
Browse files- poetry.lock +0 -0
- pyproject.toml +3 -3
- samgis/io/wrappers_helpers.py +22 -0
- samgis/prediction_api/predictors.py +12 -6
- wrappers/fastapi_wrapper.py +5 -2
poetry.lock
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
CHANGED
@@ -11,15 +11,15 @@ bson = "^0.5.10"
|
|
11 |
contextily = "^1.5.2"
|
12 |
geopandas = "^0.14.3"
|
13 |
loguru = "^0.7.2"
|
14 |
-
numpy = "
|
15 |
onnxruntime = "1.16.3"
|
16 |
opencv-python-headless = "^4.8.1.78"
|
17 |
pillow = "^10.2.0"
|
18 |
-
python = "
|
19 |
python-dotenv = "^1.0.1"
|
20 |
rasterio = "^1.3.9"
|
21 |
requests = "^2.31.0"
|
22 |
-
samgis-core = "^1.
|
23 |
|
24 |
[tool.poetry.group.aws_lambda]
|
25 |
optional = true
|
|
|
11 |
contextily = "^1.5.2"
|
12 |
geopandas = "^0.14.3"
|
13 |
loguru = "^0.7.2"
|
14 |
+
numpy = "~1.25.2"
|
15 |
onnxruntime = "1.16.3"
|
16 |
opencv-python-headless = "^4.8.1.78"
|
17 |
pillow = "^10.2.0"
|
18 |
+
python = "~3.10"
|
19 |
python-dotenv = "^1.0.1"
|
20 |
rasterio = "^1.3.9"
|
21 |
requests = "^2.31.0"
|
22 |
+
samgis-core = "^1.1.1"
|
23 |
|
24 |
[tool.poetry.group.aws_lambda]
|
25 |
optional = true
|
samgis/io/wrappers_helpers.py
CHANGED
@@ -200,3 +200,25 @@ def get_url_tile(source_type: str):
|
|
200 |
|
201 |
def check_source_type_is_terrain(source: str | TileProvider):
|
202 |
return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
def check_source_type_is_terrain(source: str | TileProvider):
|
202 |
return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
|
203 |
+
|
204 |
+
|
205 |
+
def get_source_name(source: str | TileProvider) -> str | bool:
|
206 |
+
try:
|
207 |
+
match source.lower():
|
208 |
+
case XYZDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
|
209 |
+
source_output = providers.query_name(XYZDefaultProvidersNames.DEFAULT_TILES_NAME)
|
210 |
+
case _:
|
211 |
+
source_output = providers.query_name(source)
|
212 |
+
if isinstance(source_output, str):
|
213 |
+
return source_output
|
214 |
+
try:
|
215 |
+
source_dict = dict(source_output)
|
216 |
+
app_logger.info(f"source_dict:{type(source_dict)}, {'name' in source_dict}, source_dict:{source_dict}.")
|
217 |
+
return source_dict["name"]
|
218 |
+
except KeyError as ke:
|
219 |
+
app_logger.error(f"ke:{ke}.")
|
220 |
+
except ValueError as ve:
|
221 |
+
app_logger.info(f"source name::{source}, ve:{ve}.")
|
222 |
+
app_logger.info(f"source name::{source}.")
|
223 |
+
|
224 |
+
return False
|
samgis/prediction_api/predictors.py
CHANGED
@@ -6,12 +6,13 @@ from samgis.io.tms2geotiff import download_extent
|
|
6 |
from samgis.io.wrappers_helpers import check_source_type_is_terrain
|
7 |
from samgis.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
|
8 |
from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
|
9 |
-
from samgis_core.prediction_api.sam_onnx import get_raster_inference
|
10 |
from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
|
11 |
from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
|
12 |
|
13 |
|
14 |
models_dict = {"fastsam": {"instance": None}}
|
|
|
15 |
|
16 |
|
17 |
def samexporter_predict(
|
@@ -19,7 +20,8 @@ def samexporter_predict(
|
|
19 |
prompt: list_dict,
|
20 |
zoom: float,
|
21 |
model_name: str = "fastsam",
|
22 |
-
source: str = DEFAULT_URL_TILES
|
|
|
23 |
) -> dict_str_int:
|
24 |
"""
|
25 |
Return predictions as a geojson from a geo-referenced image using the given input prompt.
|
@@ -34,7 +36,8 @@ def samexporter_predict(
|
|
34 |
prompt: machine learning input prompt
|
35 |
zoom: Level of detail
|
36 |
model_name: machine learning model name
|
37 |
-
source: xyz
|
|
|
38 |
|
39 |
Returns:
|
40 |
Affine transform
|
@@ -62,9 +65,12 @@ def samexporter_predict(
|
|
62 |
|
63 |
app_logger.info(
|
64 |
f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
return {
|
69 |
"n_predictions": n_predictions,
|
70 |
**get_vectorized_raster_as_geojson(mask, transform)
|
|
|
6 |
from samgis.io.wrappers_helpers import check_source_type_is_terrain
|
7 |
from samgis.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
|
8 |
from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
|
9 |
+
from samgis_core.prediction_api.sam_onnx import get_raster_inference, get_raster_inference_with_embedding_from_dict
|
10 |
from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
|
11 |
from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
|
12 |
|
13 |
|
14 |
models_dict = {"fastsam": {"instance": None}}
|
15 |
+
embedding_dict = {}
|
16 |
|
17 |
|
18 |
def samexporter_predict(
|
|
|
20 |
prompt: list_dict,
|
21 |
zoom: float,
|
22 |
model_name: str = "fastsam",
|
23 |
+
source: str = DEFAULT_URL_TILES,
|
24 |
+
source_name: str = None
|
25 |
) -> dict_str_int:
|
26 |
"""
|
27 |
Return predictions as a geojson from a geo-referenced image using the given input prompt.
|
|
|
36 |
prompt: machine learning input prompt
|
37 |
zoom: Level of detail
|
38 |
model_name: machine learning model name
|
39 |
+
source: xyz tile provider object
|
40 |
+
source_name: name of tile provider
|
41 |
|
42 |
Returns:
|
43 |
Affine transform
|
|
|
65 |
|
66 |
app_logger.info(
|
67 |
f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
|
68 |
+
app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
|
69 |
+
embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
|
70 |
+
mask, n_predictions = get_raster_inference_with_embedding_from_dict(
|
71 |
+
img, prompt, models_instance, model_name, embedding_key, embedding_dict)
|
72 |
+
app_logger.info(f"created {n_predictions} masks, type {type(mask)}, size {mask.size}: preparing geojson conversion")
|
73 |
+
app_logger.info(f"mask shape:{mask.shape}.")
|
74 |
return {
|
75 |
"n_predictions": n_predictions,
|
76 |
**get_vectorized_raster_as_geojson(mask, transform)
|
wrappers/fastapi_wrapper.py
CHANGED
@@ -8,13 +8,14 @@ from fastapi.staticfiles import StaticFiles
|
|
8 |
from pydantic import ValidationError
|
9 |
|
10 |
from samgis import PROJECT_ROOT_FOLDER
|
11 |
-
from samgis.io.wrappers_helpers import get_parsed_bbox_points
|
12 |
from samgis.utilities.type_hints import ApiRequestBody
|
13 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
14 |
from samgis.prediction_api.predictors import samexporter_predict
|
15 |
|
16 |
|
17 |
app_logger = setup_logging(debug=True)
|
|
|
18 |
app = FastAPI()
|
19 |
|
20 |
|
@@ -68,9 +69,11 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
|
|
68 |
body_request = get_parsed_bbox_points(request_input)
|
69 |
app_logger.info(f"body_request:{body_request}.")
|
70 |
try:
|
|
|
|
|
71 |
output = samexporter_predict(
|
72 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
73 |
-
source=body_request["source"]
|
74 |
)
|
75 |
duration_run = time.time() - time_start_run
|
76 |
app_logger.info(f"duration_run:{duration_run}.")
|
|
|
8 |
from pydantic import ValidationError
|
9 |
|
10 |
from samgis import PROJECT_ROOT_FOLDER
|
11 |
+
from samgis.io.wrappers_helpers import get_parsed_bbox_points, get_source_name
|
12 |
from samgis.utilities.type_hints import ApiRequestBody
|
13 |
from samgis_core.utilities.fastapi_logger import setup_logging
|
14 |
from samgis.prediction_api.predictors import samexporter_predict
|
15 |
|
16 |
|
17 |
app_logger = setup_logging(debug=True)
|
18 |
+
app_logger.info(f"PROJECT_ROOT_FOLDER:{PROJECT_ROOT_FOLDER}.")
|
19 |
app = FastAPI()
|
20 |
|
21 |
|
|
|
69 |
body_request = get_parsed_bbox_points(request_input)
|
70 |
app_logger.info(f"body_request:{body_request}.")
|
71 |
try:
|
72 |
+
source_name = get_source_name(request_input.source_type)
|
73 |
+
app_logger.info(f"source_name = {source_name}.")
|
74 |
output = samexporter_predict(
|
75 |
bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
|
76 |
+
source=body_request["source"], source_name=source_name
|
77 |
)
|
78 |
duration_run = time.time() - time_start_run
|
79 |
app_logger.info(f"duration_run:{duration_run}.")
|