aletrn commited on
Commit
cecaec0
·
1 Parent(s): 031719b

[feat] take advantage of re-usable image embeddings in SAM model

Browse files
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -11,15 +11,15 @@ bson = "^0.5.10"
11
  contextily = "^1.5.2"
12
  geopandas = "^0.14.3"
13
  loguru = "^0.7.2"
14
- numpy = "^1.26.4"
15
  onnxruntime = "1.16.3"
16
  opencv-python-headless = "^4.8.1.78"
17
  pillow = "^10.2.0"
18
- python = "^3.11"
19
  python-dotenv = "^1.0.1"
20
  rasterio = "^1.3.9"
21
  requests = "^2.31.0"
22
- samgis-core = "^1.0.3"
23
 
24
  [tool.poetry.group.aws_lambda]
25
  optional = true
 
11
  contextily = "^1.5.2"
12
  geopandas = "^0.14.3"
13
  loguru = "^0.7.2"
14
+ numpy = "~1.25.2"
15
  onnxruntime = "1.16.3"
16
  opencv-python-headless = "^4.8.1.78"
17
  pillow = "^10.2.0"
18
+ python = "~3.10"
19
  python-dotenv = "^1.0.1"
20
  rasterio = "^1.3.9"
21
  requests = "^2.31.0"
22
+ samgis-core = "^1.1.1"
23
 
24
  [tool.poetry.group.aws_lambda]
25
  optional = true
samgis/io/wrappers_helpers.py CHANGED
@@ -200,3 +200,25 @@ def get_url_tile(source_type: str):
200
 
201
  def check_source_type_is_terrain(source: str | TileProvider):
202
  return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  def check_source_type_is_terrain(source: str | TileProvider):
202
  return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
203
+
204
+
205
+ def get_source_name(source: str | TileProvider) -> str | bool:
206
+ try:
207
+ match source.lower():
208
+ case XYZDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
209
+ source_output = providers.query_name(XYZDefaultProvidersNames.DEFAULT_TILES_NAME)
210
+ case _:
211
+ source_output = providers.query_name(source)
212
+ if isinstance(source_output, str):
213
+ return source_output
214
+ try:
215
+ source_dict = dict(source_output)
216
+ app_logger.info(f"source_dict:{type(source_dict)}, {'name' in source_dict}, source_dict:{source_dict}.")
217
+ return source_dict["name"]
218
+ except KeyError as ke:
219
+ app_logger.error(f"ke:{ke}.")
220
+ except ValueError as ve:
221
+ app_logger.info(f"source name::{source}, ve:{ve}.")
222
+ app_logger.info(f"source name::{source}.")
223
+
224
+ return False
samgis/prediction_api/predictors.py CHANGED
@@ -6,12 +6,13 @@ from samgis.io.tms2geotiff import download_extent
6
  from samgis.io.wrappers_helpers import check_source_type_is_terrain
7
  from samgis.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
8
  from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
9
- from samgis_core.prediction_api.sam_onnx import get_raster_inference
10
  from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
11
  from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
12
 
13
 
14
  models_dict = {"fastsam": {"instance": None}}
 
15
 
16
 
17
  def samexporter_predict(
@@ -19,7 +20,8 @@ def samexporter_predict(
19
  prompt: list_dict,
20
  zoom: float,
21
  model_name: str = "fastsam",
22
- source: str = DEFAULT_URL_TILES
 
23
  ) -> dict_str_int:
24
  """
25
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
@@ -34,7 +36,8 @@ def samexporter_predict(
34
  prompt: machine learning input prompt
35
  zoom: Level of detail
36
  model_name: machine learning model name
37
- source: xyz
 
38
 
39
  Returns:
40
  Affine transform
@@ -62,9 +65,12 @@ def samexporter_predict(
62
 
63
  app_logger.info(
64
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
65
-
66
- mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
67
- app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
 
 
 
68
  return {
69
  "n_predictions": n_predictions,
70
  **get_vectorized_raster_as_geojson(mask, transform)
 
6
  from samgis.io.wrappers_helpers import check_source_type_is_terrain
7
  from samgis.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
8
  from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
9
+ from samgis_core.prediction_api.sam_onnx import get_raster_inference, get_raster_inference_with_embedding_from_dict
10
  from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
11
  from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
12
 
13
 
14
  models_dict = {"fastsam": {"instance": None}}
15
+ embedding_dict = {}
16
 
17
 
18
  def samexporter_predict(
 
20
  prompt: list_dict,
21
  zoom: float,
22
  model_name: str = "fastsam",
23
+ source: str = DEFAULT_URL_TILES,
24
+ source_name: str = None
25
  ) -> dict_str_int:
26
  """
27
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
 
36
  prompt: machine learning input prompt
37
  zoom: Level of detail
38
  model_name: machine learning model name
39
+ source: xyz tile provider object
40
+ source_name: name of tile provider
41
 
42
  Returns:
43
  Affine transform
 
65
 
66
  app_logger.info(
67
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
68
+ app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
69
+ embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
70
+ mask, n_predictions = get_raster_inference_with_embedding_from_dict(
71
+ img, prompt, models_instance, model_name, embedding_key, embedding_dict)
72
+ app_logger.info(f"created {n_predictions} masks, type {type(mask)}, size {mask.size}: preparing geojson conversion")
73
+ app_logger.info(f"mask shape:{mask.shape}.")
74
  return {
75
  "n_predictions": n_predictions,
76
  **get_vectorized_raster_as_geojson(mask, transform)
wrappers/fastapi_wrapper.py CHANGED
@@ -8,13 +8,14 @@ from fastapi.staticfiles import StaticFiles
8
  from pydantic import ValidationError
9
 
10
  from samgis import PROJECT_ROOT_FOLDER
11
- from samgis.io.wrappers_helpers import get_parsed_bbox_points
12
  from samgis.utilities.type_hints import ApiRequestBody
13
  from samgis_core.utilities.fastapi_logger import setup_logging
14
  from samgis.prediction_api.predictors import samexporter_predict
15
 
16
 
17
  app_logger = setup_logging(debug=True)
 
18
  app = FastAPI()
19
 
20
 
@@ -68,9 +69,11 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
68
  body_request = get_parsed_bbox_points(request_input)
69
  app_logger.info(f"body_request:{body_request}.")
70
  try:
 
 
71
  output = samexporter_predict(
72
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
73
- source=body_request["source"]
74
  )
75
  duration_run = time.time() - time_start_run
76
  app_logger.info(f"duration_run:{duration_run}.")
 
8
  from pydantic import ValidationError
9
 
10
  from samgis import PROJECT_ROOT_FOLDER
11
+ from samgis.io.wrappers_helpers import get_parsed_bbox_points, get_source_name
12
  from samgis.utilities.type_hints import ApiRequestBody
13
  from samgis_core.utilities.fastapi_logger import setup_logging
14
  from samgis.prediction_api.predictors import samexporter_predict
15
 
16
 
17
  app_logger = setup_logging(debug=True)
18
+ app_logger.info(f"PROJECT_ROOT_FOLDER:{PROJECT_ROOT_FOLDER}.")
19
  app = FastAPI()
20
 
21
 
 
69
  body_request = get_parsed_bbox_points(request_input)
70
  app_logger.info(f"body_request:{body_request}.")
71
  try:
72
+ source_name = get_source_name(request_input.source_type)
73
+ app_logger.info(f"source_name = {source_name}.")
74
  output = samexporter_predict(
75
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
76
+ source=body_request["source"], source_name=source_name
77
  )
78
  duration_run = time.time() - time_start_run
79
  app_logger.info(f"duration_run:{duration_run}.")