aletrn commited on
Commit
5b88544
·
1 Parent(s): b241742

[refactor] remove and transform unuseful logs to debug

Browse files
src/app.py CHANGED
@@ -1,17 +1,15 @@
1
  import json
2
  import time
3
  from http import HTTPStatus
4
- from pathlib import Path
5
  from typing import Dict
6
 
7
  from aws_lambda_powertools.event_handler import content_types
8
  from aws_lambda_powertools.utilities.typing import LambdaContext
9
 
10
  from src import app_logger
11
- from src.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates, get_point_latlng_to_pixel_coordinates
12
  from src.prediction_api.predictors import samexporter_predict
13
- from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES, ROOT
14
- from src.utilities.serialize import serialize
15
  from src.utilities.utilities import base64_decode
16
 
17
 
@@ -59,32 +57,40 @@ def get_parsed_bbox_points(request_input: Dict) -> Dict:
59
  for prompt in request_input["prompt"]:
60
  app_logger.info(f"current prompt: {type(prompt)}, value:{prompt}.")
61
  data = prompt["data"]
62
- if prompt["type"] == "rectangle":
63
- app_logger.info(f"current data points: {type(data)}, value:{data}.")
64
- data_ne = data["ne"]
65
- app_logger.info(f"current data_ne point: {type(data_ne)}, value:{data_ne}.")
66
- data_sw = data["sw"]
67
- app_logger.info(f"current data_sw point: {type(data_sw)}, value:{data_sw}.")
68
-
69
- diff_pixel_coords_origin_data_ne = get_latlng_to_pixel_coordinates(ne, sw, data_ne, zoom, "ne")
70
- app_logger.info(f'current diff prompt ne: {type(data)}, {data} => {diff_pixel_coords_origin_data_ne}.')
71
- diff_pixel_coords_origin_data_sw = get_latlng_to_pixel_coordinates(ne, sw, data_sw, zoom, "sw")
72
- app_logger.info(f'current diff prompt sw: {type(data)}, {data} => {diff_pixel_coords_origin_data_sw}.')
73
- prompt["data"] = [
74
- diff_pixel_coords_origin_data_ne["x"], diff_pixel_coords_origin_data_ne["y"],
75
- diff_pixel_coords_origin_data_sw["x"], diff_pixel_coords_origin_data_sw["y"]
76
- ]
77
- elif prompt["type"] == "point":
 
 
 
 
 
 
 
 
78
  current_point = get_latlng_to_pixel_coordinates(ne, sw, data, zoom, "point")
79
  app_logger.info(f"current prompt: {type(current_point)}, value:{current_point}.")
80
  new_prompt_data = [current_point['x'], current_point['y']]
81
  app_logger.info(f"new_prompt_data: {type(new_prompt_data)}, value:{new_prompt_data}.")
82
  prompt["data"] = new_prompt_data
83
  else:
84
- raise ValueError("valid prompt types are only 'point' and 'rectangle'")
85
 
86
  app_logger.info(f"bbox => {bbox}.")
87
- app_logger.info(f'## request_input["prompt"] updated => {request_input["prompt"]}.')
88
 
89
  app_logger.info(f"unpacking elaborated {request_input}...")
90
  return {
@@ -125,9 +131,7 @@ def lambda_handler(event: dict, context: LambdaContext):
125
  app_logger.info(f"prompt_latlng:{prompt_latlng}.")
126
  body_request = get_parsed_bbox_points(body)
127
  app_logger.info(f"body_request=> {type(body_request)}, {body_request}.")
128
- body_response = samexporter_predict(
129
- body_request["bbox"], body_request["prompt"], body_request["zoom"], prompt_latlng
130
- )
131
  app_logger.info(f"output body_response:{body_response}.")
132
  response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
133
  except Exception as ex2:
 
1
  import json
2
  import time
3
  from http import HTTPStatus
 
4
  from typing import Dict
5
 
6
  from aws_lambda_powertools.event_handler import content_types
7
  from aws_lambda_powertools.utilities.typing import LambdaContext
8
 
9
  from src import app_logger
10
+ from src.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
11
  from src.prediction_api.predictors import samexporter_predict
12
+ from src.utilities.constants import CUSTOM_RESPONSE_MESSAGES
 
13
  from src.utilities.utilities import base64_decode
14
 
15
 
 
57
  for prompt in request_input["prompt"]:
58
  app_logger.info(f"current prompt: {type(prompt)}, value:{prompt}.")
59
  data = prompt["data"]
60
+ # if prompt["type"] == "rectangle":
61
+ # app_logger.info(f"current data points: {type(data)}, value:{data}.")
62
+ # data_ne = data["ne"]
63
+ # app_logger.info(f"current data_ne point: {type(data_ne)}, value:{data_ne}.")
64
+ # data_sw = data["sw"]
65
+ # app_logger.info(f"current data_sw point: {type(data_sw)}, value:{data_sw}.")
66
+ #
67
+ # diff_pixel_coords_origin_data_ne = get_latlng_to_pixel_coordinates(ne, sw, data_ne, zoom, "ne")
68
+ # app_logger.info(f'current diff prompt ne: {type(data)}, {data} => {diff_pixel_coords_origin_data_ne}.')
69
+ # diff_pixel_coords_origin_data_sw = get_latlng_to_pixel_coordinates(ne, sw, data_sw, zoom, "sw")
70
+ # app_logger.info(f'current diff prompt sw: {type(data)}, {data} => {diff_pixel_coords_origin_data_sw}.')
71
+ # prompt["data"] = [
72
+ # diff_pixel_coords_origin_data_ne["x"], diff_pixel_coords_origin_data_ne["y"],
73
+ # diff_pixel_coords_origin_data_sw["x"], diff_pixel_coords_origin_data_sw["y"]
74
+ # ]
75
+ # # rect_diffs_input = str(Path(ROOT) / "rect_diffs_input.json")
76
+ # # with open(rect_diffs_input, "w") as jj_out3:
77
+ # # json.dump({
78
+ # # "prompt_data": serialize(prompt["data"]),
79
+ # # "diff_pixel_coords_origin_data_ne": serialize(diff_pixel_coords_origin_data_ne),
80
+ # # "diff_pixel_coords_origin_data_sw": serialize(diff_pixel_coords_origin_data_sw),
81
+ # # }, jj_out3)
82
+ # # app_logger.info(f"written json:{rect_diffs_input}.")
83
+ if prompt["type"] == "point":
84
  current_point = get_latlng_to_pixel_coordinates(ne, sw, data, zoom, "point")
85
  app_logger.info(f"current prompt: {type(current_point)}, value:{current_point}.")
86
  new_prompt_data = [current_point['x'], current_point['y']]
87
  app_logger.info(f"new_prompt_data: {type(new_prompt_data)}, value:{new_prompt_data}.")
88
  prompt["data"] = new_prompt_data
89
  else:
90
+ raise ValueError("valid prompt type is only 'point'")
91
 
92
  app_logger.info(f"bbox => {bbox}.")
93
+ app_logger.info(f'## request_input-prompt updated => {request_input["prompt"]}.')
94
 
95
  app_logger.info(f"unpacking elaborated {request_input}...")
96
  return {
 
131
  app_logger.info(f"prompt_latlng:{prompt_latlng}.")
132
  body_request = get_parsed_bbox_points(body)
133
  app_logger.info(f"body_request=> {type(body_request)}, {body_request}.")
134
+ body_response = samexporter_predict(body_request["bbox"], body_request["prompt"], body_request["zoom"])
 
 
135
  app_logger.info(f"output body_response:{body_response}.")
136
  response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
137
  except Exception as ex2:
src/io/tiles_to_tiff.py CHANGED
@@ -1,15 +1,14 @@
1
  """Async download raster tiles"""
2
- import os
3
  from pathlib import Path
4
 
5
  import numpy as np
6
 
7
  from src import app_logger, PROJECT_ROOT_FOLDER
8
- from src.io.helpers import get_lat_lon_coords, merge_tiles, get_geojson_square_angles, crop_raster
9
- from src.io.tms2geotiff import download_extent, save_geotiff_gdal
10
  from src.utilities.constants import COMPLETE_URL_TILES, DEFAULT_TMS
11
  from src.utilities.type_hints import ts_llist2
12
 
 
13
  COOKIE_SESSION = {
14
  "Accept": "*/*",
15
  "Accept-Encoding": "gzip, deflate",
@@ -58,137 +57,34 @@ def convert(bounding_box: ts_llist2, zoom: int) -> tuple:
58
  dict: uploaded_file_name (str), bucket_name (str), prediction_georef (dict), n_total_obj_prediction (str)
59
 
60
  """
61
- import tempfile
62
-
63
- # from src.surferdtm_prediction_api.utilities.constants import NODATA_VALUES
64
- # from src.surferdtm_prediction_api.utilities.utilities import setup_logging
65
- # from src.surferdtm_prediction_api.raster.elaborate_images import elaborate_images.get_rgb_prediction_image
66
- # from src.surferdtm_prediction_api.raster.prediction import model_prediction
67
- # from src.surferdtm_prediction_api.geo.helpers import get_lat_lon_coords, merge_tiles, crop_raster, get_prediction_georeferenced, \
68
- # get_geojson_square_angles, get_perc
69
-
70
- # app_logger = setup_logging(debug)
71
- ext = "tif"
72
- debug = False
73
  tile_source = COMPLETE_URL_TILES
74
  app_logger.info(f"start_args: tile_source:{tile_source},bounding_box:{bounding_box},zoom:{zoom}.")
75
 
76
  try:
77
  import rasterio
78
 
79
- lon_min, lat_min, lon_max, lat_max = get_lat_lon_coords(bounding_box)
80
-
81
- with tempfile.TemporaryDirectory() as input_tmp_dir:
82
- # with tempfile.TemporaryDirectory() as output_tmp_dir:
83
- output_tmp_dir = input_tmp_dir
84
- app_logger.info(f'tile_source: {tile_source}!')
85
- app_logger.info(f'created temporary input/output directory: {input_tmp_dir} => {output_tmp_dir}!')
86
- pt0, pt1 = bounding_box
87
- app_logger.info("downloading...")
88
- img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
89
-
90
- app_logger.info(f'img: type {type(img)}, len_matrix:{len(matrix)}, matrix {matrix}.')
91
- app_logger.info(f'img: size (shape if PIL) {img.size}.')
92
- try:
93
- np_img = np.array(img)
94
- app_logger.info(f'img: shape (numpy) {np_img.shape}.')
95
- except Exception as e_shape:
96
- app_logger.info(f'e_shape {e_shape}.')
97
- raise e_shape
98
- img.save(f"/tmp/downloaded_{pt0[0]}_{pt0[1]}_{pt1[0]}_{pt1[1]}.png")
99
- app_logger.info("saved PIL image")
100
-
101
- return img, matrix
102
- # app_logger.info("prepare writing...")
103
- # app_logger.info(f'img: type {type(img)}, len_matrix:{len(matrix)}, matrix {matrix}.')
104
- #
105
- # rio_output = str(Path(output_tmp_dir) / "downloaded_rio.tif")
106
- # app_logger.info(f'writing to disk img, output file {rio_output}.')
107
- # save_geotiff_gdal(img, rio_output, matrix)
108
- # app_logger.info(f'img written to output file {rio_output}.')
109
- #
110
- # source_tiles = os.path.join(input_tmp_dir, f"*.{ext}")
111
- # suffix_raster_filename = f"{lon_min},{lat_min},{lon_max},{lat_max}_{zoom}"
112
- # merged_raster_filename = f"merged_{suffix_raster_filename}.{ext}"
113
- # masked_raster_filename = f"masked_{suffix_raster_filename}.{ext}"
114
- # output_merged_path = os.path.join(output_tmp_dir, merged_raster_filename)
115
- #
116
- # app_logger.info(f"try merging tiles to:{output_merged_path}.")
117
- # merge_tiles(source_tiles, output_merged_path, input_tmp_dir)
118
- # app_logger.info(f"Merge complete, try crop...")
119
- # geojson = get_geojson_square_angles(bounding_box, name=suffix_raster_filename, debug=debug)
120
- # app_logger.info(f"geojson to convert:{geojson}.")
121
- #
122
- # crop_raster_output = crop_raster(output_merged_path, geojson, debug=False)
123
- # masked_raster = crop_raster_output["masked_raster"]
124
- # masked_meta = crop_raster_output["masked_meta"]
125
- # masked_transform = crop_raster_output["masked_transform"]
126
- #
127
- # return masked_raster, masked_transform
128
-
129
- # app_logger.info(f"resampling -32768 values as NaN for file:{masked_raster_filename}.")
130
- # masked_raster = masked_raster[0].astype(float)
131
- # masked_raster[masked_raster == NODATA_VALUES] = 0
132
- # # info
133
- # nan_count = np.count_nonzero(~np.isnan(masked_raster))
134
- # total_count = masked_raster.shape[-1] * masked_raster.shape[-2]
135
- # perc = get_perc(nan_count, total_count)
136
- # msg = f"img:{masked_raster_filename}, shape:{masked_raster.shape}: found {nan_count} not-NaN values / {total_count} total, %:{perc}."
137
- # app_logger.info(msg)
138
- #
139
- # app_logger.info(f"crop complete, shape:{masked_raster.shape}, dtype:{masked_raster.dtype}. Create RGB image...")
140
- # # rgb_filename, rgb_path = elaborate_images.get_rgb_prediction_image(masked_raster, slope_cellsize, suffix_raster_filename, output_tmp_dir, debug=debug)
141
- # # prediction = model_prediction(rgb_path, project_name=model_project_name, version=model_version, api_key=model_api_key, debug=False)
142
- #
143
- # mask_vectorizing = np.ones(masked_raster.shape).astype(rasterio.uint8)
144
- # app_logger.info(f"prediction success, try to geo-referencing it with transform:{masked_transform}.")
145
- #
146
- # app_logger.info(
147
- # f"image/geojson origin matrix:, masked_transform:{masked_transform}: create shapes_generator...")
148
- # app_logger.info(f"raster mask to vectorize, type:{type(mask_vectorizing)}.")
149
- # app_logger.info(f"raster mask to vectorize: shape:{mask_vectorizing.shape}, {mask_vectorizing.dtype}.")
150
- #
151
- # shapes_generator = ({
152
- # 'properties': {'raster_val': v}, 'geometry': s}
153
- # for i, (s, v)
154
- # in enumerate(shapes(mask_vectorizing, mask=mask_vectorizing, transform=masked_transform))
155
- # )
156
- # shapes_list = list(shapes_generator)
157
- # app_logger.info(f"created {len(shapes_list)} polygons.")
158
- # gpd_polygonized_raster = GeoDataFrame.from_features(shapes_list, crs="EPSG:3857")
159
- # app_logger.info(f"created a GeoDataFrame: type {type(gpd_polygonized_raster)}.")
160
- # geojson = gpd_polygonized_raster.to_json(to_wgs84=True)
161
- # app_logger.info(f"created geojson: type {type(geojson)}, len:{len(geojson)}.")
162
- # serialized_geojson = serialize.serialize(geojson)
163
- # app_logger.info(f"created serialized_geojson: type {type(serialized_geojson)}, len:{len(serialized_geojson)}.")
164
- # loaded_geojson = json.loads(geojson)
165
- # app_logger.info(f"loaded_geojson: type {type(loaded_geojson)}, loaded_geojson:{loaded_geojson}.")
166
- # n_feats = len(loaded_geojson["features"])
167
- # app_logger.info(f"created geojson: n_feats {n_feats}.")
168
- #
169
- # output_geojson = str(Path(ROOT) / "geojson_output.json")
170
- # with open(output_geojson, "w") as jj_out:
171
- # app_logger.info(f"writing geojson file to {output_geojson}.")
172
- # json.dump(loaded_geojson, jj_out)
173
- # app_logger.info(f"geojson file written to {output_geojson}.")
174
- #
175
- # # prediction_georef = helpers.get_prediction_georeferenced(prediction, masked_transform, skip_conditions_list, debug=debug)
176
- # app_logger.info(f"success on geo-referencing prediction.")
177
- # # app_logger.info(f"success on creating file {rgb_filename}, now try upload it to bucket_name {bucket_name}...")
178
- # return {
179
- # # "uploaded_file_name": rgb_filename,
180
- # "geojson": loaded_geojson,
181
- # # "prediction_georef": prediction_georef,
182
- # "n_total_obj_prediction": n_feats
183
- # }
184
  except ImportError as e_import_convert:
185
  app_logger.error(f"e0:{e_import_convert}.")
186
  raise e_import_convert
187
 
188
 
189
  if __name__ == '__main__':
190
- from PIL import Image
191
-
192
  npy_file = "prediction_masks_46.27697017893455_9.616470336914064_46.11441972281433_9.264907836914064.npy"
193
  prediction_masks = np.load(Path(PROJECT_ROOT_FOLDER) / "tmp" / "try_by_steps" / "t0" / npy_file)
194
 
 
1
  """Async download raster tiles"""
 
2
  from pathlib import Path
3
 
4
  import numpy as np
5
 
6
  from src import app_logger, PROJECT_ROOT_FOLDER
7
+ from src.io.tms2geotiff import download_extent
 
8
  from src.utilities.constants import COMPLETE_URL_TILES, DEFAULT_TMS
9
  from src.utilities.type_hints import ts_llist2
10
 
11
+
12
  COOKIE_SESSION = {
13
  "Accept": "*/*",
14
  "Accept-Encoding": "gzip, deflate",
 
57
  dict: uploaded_file_name (str), bucket_name (str), prediction_georef (dict), n_total_obj_prediction (str)
58
 
59
  """
60
+
 
 
 
 
 
 
 
 
 
 
 
61
  tile_source = COMPLETE_URL_TILES
62
  app_logger.info(f"start_args: tile_source:{tile_source},bounding_box:{bounding_box},zoom:{zoom}.")
63
 
64
  try:
65
  import rasterio
66
 
67
+ app_logger.info(f'tile_source: {tile_source}!')
68
+ pt0, pt1 = bounding_box
69
+ app_logger.info("downloading...")
70
+ img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
71
+
72
+ app_logger.info(f'img: type {type(img)}, len_matrix:{len(matrix)}, matrix {matrix}.')
73
+ app_logger.info(f'img: size (shape if PIL) {img.size}.')
74
+ try:
75
+ np_img = np.array(img)
76
+ app_logger.info(f'img: shape (numpy) {np_img.shape}.')
77
+ except Exception as e_shape:
78
+ app_logger.info(f'e_shape {e_shape}.')
79
+ raise e_shape
80
+
81
+ return img, matrix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  except ImportError as e_import_convert:
83
  app_logger.error(f"e0:{e_import_convert}.")
84
  raise e_import_convert
85
 
86
 
87
  if __name__ == '__main__':
 
 
88
  npy_file = "prediction_masks_46.27697017893455_9.616470336914064_46.11441972281433_9.264907836914064.npy"
89
  prediction_masks = np.load(Path(PROJECT_ROOT_FOLDER) / "tmp" / "try_by_steps" / "t0" / npy_file)
90
 
src/main.py DELETED
@@ -1,14 +0,0 @@
1
- import rasterio
2
- from affine import loadsw
3
-
4
- from src import PROJECT_ROOT_FOLDER
5
-
6
- if __name__ == '__main__':
7
- with open(PROJECT_ROOT_FOLDER / "tmp" / "japan_out_main.pgw") as pgw:
8
- pgw_file = pgw.read()
9
- a = loadsw(pgw_file)
10
- with rasterio.open(PROJECT_ROOT_FOLDER / "tmp" / "japan_out_main.png", "r") as src:
11
- src_transform = src.transform
12
- print(a, src_transform)
13
- print(a, src_transform)
14
- print("a, src_tranform")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/prediction_api/predictors.py CHANGED
@@ -1,18 +1,16 @@
1
  # Press the green button in the gutter to run the script.
2
- import json
3
  from pathlib import Path
4
  from typing import List
5
 
6
  import numpy as np
7
  import rasterio
8
- from PIL import Image
9
 
10
  from src import app_logger, MODEL_FOLDER
11
  from src.io.tiles_to_tiff import convert
12
  from src.io.tms2geotiff import save_geotiff_gdal
13
  from src.prediction_api.sam_onnx import SegmentAnythingONNX
14
- from src.utilities.constants import MODEL_ENCODER_NAME, ZOOM, MODEL_DECODER_NAME, ROOT
15
- from src.utilities.serialize import serialize
16
 
17
 
18
  models_dict = {"fastsam": {"instance": None}}
@@ -51,128 +49,87 @@ def load_affine_transformation_from_matrix(matrix_source_coeffs: List):
51
  app_logger.error(f"exception:{e}, check https://github.com/rasterio/affine project for updates")
52
 
53
 
54
- def samexporter_predict(bbox, prompt: list[dict], zoom: float = ZOOM, model_name: str = "fastsam") -> dict:
55
  try:
56
  from rasterio.features import shapes
57
  from geopandas import GeoDataFrame
58
 
59
  if models_dict[model_name]["instance"] is None:
60
- app_logger.info(f"missing instance model {model_name}, instantiating it now")
61
  model_instance = SegmentAnythingONNX(
62
  encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
63
  decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
64
  )
65
  models_dict[model_name]["instance"] = model_instance
66
- app_logger.info(f"using a {model_name} instance model...")
67
  models_instance = models_dict[model_name]["instance"]
68
 
69
- img, matrix = convert(
70
- bounding_box=bbox,
71
- zoom=int(zoom)
72
- )
73
-
74
- pt0, pt1 = bbox
75
- rio_output = f"/tmp/downloaded_rio_{pt0[0]}_{pt0[1]}_{pt1[0]}_{pt1[1]}.tif"
76
- save_geotiff_gdal(img, rio_output, matrix)
77
- app_logger.info(f"saved downloaded geotiff image to {rio_output}...")
78
-
79
- np_img = np.array(img)
80
- app_logger.info(f"## img type {type(np_img)}, prompt:{prompt}.")
81
-
82
- app_logger.info(f"onnxruntime input shape/size (shape if PIL) {np_img.size},"
83
- f"start to initialize SamGeo instance:")
84
- try:
85
- app_logger.info(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
86
- except Exception as e_shape:
87
- app_logger.error(f"e_shape:{e_shape}.")
88
- app_logger.info(f"use {model_name} model, ENCODER model {MODEL_ENCODER_NAME} and"
89
- f" {MODEL_DECODER_NAME} from {MODEL_FOLDER}): model instantiated, creating embedding...")
90
- embedding = models_instance.encode(np_img)
91
- app_logger.info(f"embedding created, running predict_masks...")
92
- prediction_masks = models_instance.predict_masks(embedding, prompt)
93
- app_logger.info(f"predict_masks terminated...")
94
- app_logger.info(f"predict_masks terminated, prediction masks shape:{prediction_masks.shape}, {prediction_masks.dtype}.")
95
- pt0, pt1 = bbox
96
- prediction_masks_output = f"/tmp/prediction_masks_{pt0[0]}_{pt0[1]}_{pt1[0]}_{pt1[1]}.npy"
97
- np.save(
98
- prediction_masks_output,
99
- prediction_masks, allow_pickle=True, fix_imports=True
100
- )
101
- app_logger.info(f"saved prediction_masks:{prediction_masks_output}.")
102
-
103
- mask = np.zeros((prediction_masks.shape[2], prediction_masks.shape[3]), dtype=np.uint8)
104
- app_logger.info(f"output mask shape:{mask.shape}, {mask.dtype}.")
105
- for n, m in enumerate(prediction_masks[0, :, :, :]):
106
- app_logger.info(f"## {n} mask => m shape:{mask.shape}, {mask.dtype}.")
107
- mask[m > 0.0] = 255
108
- # prediction_masks0 = prediction_masks[0]
109
- # app_logger.info(f"prediction_masks0 shape:{prediction_masks0.shape}.")
110
- #
111
- # try:
112
- # pmf = np.sum(prediction_masks0, axis=0).astype(np.uint8)
113
- # except Exception as e_sum_pmf:
114
- # app_logger.error(f"e_sum_pmf:{e_sum_pmf}.")
115
- # pmf = prediction_masks0[0]
116
- # app_logger.info(f"creating pil image from prediction mask with shape {pmf.shape}.")
117
- # pil_pmf = Image.fromarray(pmf)
118
- # pil_pmf_output = f"/tmp/pil_pmf_{pmf.shape[0]}_{pmf.shape[1]}.png"
119
- # pil_pmf.save(pil_pmf_output)
120
- # app_logger.info(f"saved pil_pmf:{pil_pmf_output}.")
121
- #
122
- # mask = np.zeros(pmf.shape, dtype=np.uint8)
123
- # mask[pmf > 0] = 255
124
-
125
- # cv2.imwrite(f"/tmp/cv2_mask_predicted_{mask.shape[0]}_{mask.shape[1]}_{mask.shape[2]}.png", mask)
126
- pil_mask = Image.fromarray(mask)
127
- pil_mask_predicted_output = f"/tmp/pil_mask_predicted_{mask.shape[0]}_{mask.shape[1]}.png"
128
- pil_mask.save(pil_mask_predicted_output)
129
- app_logger.info(f"saved pil_mask_predicted:{pil_mask_predicted_output}.")
130
-
131
- mask_unique_values, mask_unique_values_count = serialize(np.unique(mask, return_counts=True))
132
- app_logger.info(f"mask_unique_values:{mask_unique_values}.")
133
- app_logger.info(f"mask_unique_values_count:{mask_unique_values_count}.")
134
-
135
- app_logger.info(f"read geotiff:{rio_output}: create shapes_generator...")
136
- # app_logger.info(f"image/geojson transform:{transform}: create shapes_generator...")
137
- with rasterio.open(rio_output, "r", driver="GTiff") as rio_src:
138
- band = rio_src.read()
139
- try:
140
- transform = load_affine_transformation_from_matrix(matrix)
141
- app_logger.info(f"geotiff band:{band.shape}, type: {type(band)}, dtype: {band.dtype}.")
142
- app_logger.info(f"geotiff band:{mask.shape}.")
143
- app_logger.info(f"transform from matrix:{transform}.")
144
- app_logger.info(f"rio_src crs:{rio_src.crs}.")
145
- app_logger.info(f"rio_src transform:{rio_src.transform}.")
146
- except Exception as e_shape_band:
147
- app_logger.error(f"e_shape_band:{e_shape_band}.")
148
- raise e_shape_band
149
- # mask_band = band != 0
150
- shapes_generator = ({
151
- 'properties': {'raster_val': v}, 'geometry': s}
152
- for i, (s, v)
153
- # in enumerate(shapes(mask, mask=(band != 0), transform=rio_src.transform))
154
- # use mask=None to avoid using source
155
- in enumerate(shapes(mask, mask=None, transform=rio_src.transform))
156
  )
157
- app_logger.info(f"created shapes_generator.")
158
- shapes_list = list(shapes_generator)
159
- app_logger.info(f"created {len(shapes_list)} polygons.")
160
- gpd_polygonized_raster = GeoDataFrame.from_features(shapes_list, crs="EPSG:3857")
161
- app_logger.info(f"created a GeoDataFrame...")
162
- geojson = gpd_polygonized_raster.to_json(to_wgs84=True)
163
- app_logger.info(f"created geojson...")
164
-
165
- output_geojson = str(Path(ROOT) / "geojson_output.json")
166
- with open(output_geojson, "w") as jj_out:
167
- app_logger.info(f"writing geojson file to {output_geojson}.")
168
- json.dump(json.loads(geojson), jj_out)
169
- app_logger.info(f"geojson file written to {output_geojson}.")
170
-
171
- return {
172
- "geojson": geojson,
173
- "n_shapes_geojson": len(shapes_list),
174
- "n_predictions": len(prediction_masks),
175
- # "n_pixels_predictions": zip_arrays(mask_unique_values, mask_unique_values_count),
176
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  except ImportError as e:
178
  app_logger.error(f"Error trying import module:{e}.")
 
1
  # Press the green button in the gutter to run the script.
2
+ import tempfile
3
  from pathlib import Path
4
  from typing import List
5
 
6
  import numpy as np
7
  import rasterio
 
8
 
9
  from src import app_logger, MODEL_FOLDER
10
  from src.io.tiles_to_tiff import convert
11
  from src.io.tms2geotiff import save_geotiff_gdal
12
  from src.prediction_api.sam_onnx import SegmentAnythingONNX
13
+ from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME
 
14
 
15
 
16
  models_dict = {"fastsam": {"instance": None}}
 
49
  app_logger.error(f"exception:{e}, check https://github.com/rasterio/affine project for updates")
50
 
51
 
52
+ def samexporter_predict(bbox, prompt: list[dict], zoom: float, model_name: str = "fastsam") -> dict:
53
  try:
54
  from rasterio.features import shapes
55
  from geopandas import GeoDataFrame
56
 
57
  if models_dict[model_name]["instance"] is None:
58
+ app_logger.info(f"missing instance model {model_name}, instantiating it now!")
59
  model_instance = SegmentAnythingONNX(
60
  encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
61
  decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
62
  )
63
  models_dict[model_name]["instance"] = model_instance
64
+ app_logger.debug(f"using a {model_name} instance model...")
65
  models_instance = models_dict[model_name]["instance"]
66
 
67
+ with tempfile.TemporaryDirectory() as input_tmp_dir:
68
+ img, matrix = convert(
69
+ bounding_box=bbox,
70
+ zoom=int(zoom)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  )
72
+ app_logger.debug(f"## img type {type(img)} with shape/size:{img.size}, matrix:{matrix}.")
73
+
74
+ pt0, pt1 = bbox
75
+ rio_output = str(Path(input_tmp_dir) / f"downloaded_rio_{pt0[0]}_{pt0[1]}_{pt1[0]}_{pt1[1]}.tif")
76
+ app_logger.debug(f"saving downloaded geotiff image to {rio_output}...")
77
+ save_geotiff_gdal(img, rio_output, matrix)
78
+ app_logger.info(f"saved downloaded geotiff image to {rio_output}...")
79
+
80
+ np_img = np.array(img)
81
+ app_logger.info(f"## img type {type(np_img)}, prompt:{prompt}.")
82
+
83
+ app_logger.info(f"onnxruntime input shape/size (shape if PIL) {np_img.size}.")
84
+ try:
85
+ app_logger.info(f"onnxruntime input shape (NUMPY) {np_img.shape}.")
86
+ except Exception as e_shape:
87
+ app_logger.error(f"e_shape:{e_shape}.")
88
+ app_logger.info(f"instantiated model {model_name}, ENCODER {MODEL_ENCODER_NAME}, "
89
+ f"DECODER {MODEL_DECODER_NAME} from {MODEL_FOLDER}: Creating embedding...")
90
+ embedding = models_instance.encode(np_img)
91
+ app_logger.info(f"embedding created, running predict_masks with prompt {prompt}...")
92
+ prediction_masks = models_instance.predict_masks(embedding, prompt)
93
+ app_logger.info(f"Created {len(prediction_masks)} prediction_masks,"
94
+ f"shape:{prediction_masks.shape}, dtype:{prediction_masks.dtype}.")
95
+
96
+ mask = np.zeros((prediction_masks.shape[2], prediction_masks.shape[3]), dtype=np.uint8)
97
+ for n, m in enumerate(prediction_masks[0, :, :, :]):
98
+ app_logger.debug(f"## {n} mask => m shape:{mask.shape}, {mask.dtype}.")
99
+ mask[m > 0.0] = 255
100
+
101
+ app_logger.info(f"read downloaded geotiff:{rio_output} to create the shapes_generator...")
102
+
103
+ with rasterio.open(rio_output, "r", driver="GTiff") as rio_src:
104
+ band = rio_src.read()
105
+ try:
106
+ app_logger.debug(f"geotiff band:{band.shape}, type: {type(band)}, dtype: {band.dtype}.")
107
+ app_logger.debug(f"rio_src crs:{rio_src.crs}.")
108
+ app_logger.debug(f"rio_src transform:{rio_src.transform}.")
109
+ except Exception as e_shape_band:
110
+ app_logger.error(f"e_shape_band:{e_shape_band}.")
111
+ raise e_shape_band
112
+ # mask_band = band != 0
113
+ shapes_generator = ({
114
+ 'properties': {'raster_val': v}, 'geometry': s}
115
+ for i, (s, v)
116
+ # in enumerate(shapes(mask, mask=(band != 0), transform=rio_src.transform))
117
+ # use mask=None to avoid using source
118
+ in enumerate(shapes(mask, mask=None, transform=rio_src.transform))
119
+ )
120
+ app_logger.info(f"created shapes_generator, transform it to a polygon list...")
121
+ shapes_list = list(shapes_generator)
122
+ app_logger.info(f"created {len(shapes_list)} polygons.")
123
+ gpd_polygonized_raster = GeoDataFrame.from_features(shapes_list, crs="EPSG:3857")
124
+ app_logger.info(f"created a GeoDataFrame, export to geojson...")
125
+ geojson = gpd_polygonized_raster.to_json(to_wgs84=True)
126
+ app_logger.info(f"created geojson...")
127
+
128
+ return {
129
+ "geojson": geojson,
130
+ "n_shapes_geojson": len(shapes_list),
131
+ "n_predictions": len(prediction_masks),
132
+ # "n_pixels_predictions": zip_arrays(mask_unique_values, mask_unique_values_count),
133
+ }
134
  except ImportError as e:
135
  app_logger.error(f"Error trying import module:{e}.")
src/prediction_api/sam_onnx.py CHANGED
@@ -1,4 +1,3 @@
1
- import json
2
  from copy import deepcopy
3
 
4
  import cv2
@@ -6,7 +5,6 @@ import numpy as np
6
  import onnxruntime
7
 
8
  from src import app_logger
9
- from src.utilities.serialize import serialize
10
 
11
 
12
  class SegmentAnythingONNX:
@@ -149,22 +147,10 @@ class SegmentAnythingONNX:
149
  mask = masks[batch, mask_id]
150
  try:
151
  try:
152
- app_logger.info(f"mask_shape transform_masks:{mask.shape}, dtype:{mask.dtype}.")
153
  except Exception as e_mask_shape_transform_masks:
154
  app_logger.error(f"e_mask_shape_transform_masks:{e_mask_shape_transform_masks}.")
155
  # raise e_mask_shape_transform_masks
156
- output_filename = f"2_cv2img_{'_'.join([str(s) for s in mask.shape])}.npy"
157
- np.save(output_filename, np.array(mask), allow_pickle=True, fix_imports=True)
158
- app_logger.info(f"written: /tmp/{output_filename} ...")
159
- with open("/tmp/2_args.json", "w") as jj_out_dst:
160
- json.dump({
161
- "transform_matrix": serialize(transform_matrix),
162
- "M": serialize(transform_matrix[:2]),
163
- "original_size": serialize(original_size),
164
- "dsize": serialize((original_size[1], original_size[0])),
165
- "flags": cv2.INTER_LINEAR
166
- }, jj_out_dst)
167
- app_logger.info(f"written: /tmp/jj_out.json")
168
  mask = cv2.warpAffine(
169
  mask,
170
  transform_matrix[:2],
@@ -198,22 +184,6 @@ class SegmentAnythingONNX:
198
  ]
199
  )
200
  try:
201
- np_cv_image = np.array(cv_image)
202
- try:
203
- app_logger.info(f"cv_image shape_encode:{np_cv_image.shape}, dtype:{np_cv_image.dtype}.")
204
- except Exception as e_cv_image_shape_encode:
205
- app_logger.error(f"e_cv_image_shape_encode:{e_cv_image_shape_encode}.")
206
- # raise e_cv_image_shape_encode
207
- output_filename = f"/tmp/1_cv2img_{'_'.join([str(s) for s in np_cv_image.shape])}.npy"
208
- np.save(output_filename, np_cv_image, allow_pickle=True, fix_imports=True)
209
- app_logger.info(f"written: /tmp/{output_filename} ...")
210
- with open("/tmp/1_args.json", "w") as jj_out_dst:
211
- json.dump({
212
- "transform_matrix": serialize(transform_matrix),
213
- "M": serialize(transform_matrix[:2]),
214
- "flags": cv2.INTER_LINEAR
215
- }, jj_out_dst)
216
- app_logger.info(f"written: /tmp/jj_out.json")
217
  cv_image = cv2.warpAffine(
218
  cv_image,
219
  transform_matrix[:2],
 
 
1
  from copy import deepcopy
2
 
3
  import cv2
 
5
  import onnxruntime
6
 
7
  from src import app_logger
 
8
 
9
 
10
  class SegmentAnythingONNX:
 
147
  mask = masks[batch, mask_id]
148
  try:
149
  try:
150
+ app_logger.debug(f"mask_shape transform_masks:{mask.shape}, dtype:{mask.dtype}.")
151
  except Exception as e_mask_shape_transform_masks:
152
  app_logger.error(f"e_mask_shape_transform_masks:{e_mask_shape_transform_masks}.")
153
  # raise e_mask_shape_transform_masks
 
 
 
 
 
 
 
 
 
 
 
 
154
  mask = cv2.warpAffine(
155
  mask,
156
  transform_matrix[:2],
 
184
  ]
185
  )
186
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  cv_image = cv2.warpAffine(
188
  cv_image,
189
  transform_matrix[:2],