Guill-Bla committed on
Commit f7ad336 · verified · 1 Parent(s): 05104a4

Update tasks/image.py

Files changed (1)
  1. tasks/image.py +22 -4
tasks/image.py CHANGED
@@ -13,7 +13,6 @@ from PIL import Image
  from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
  import cv2
  from tqdm import tqdm
- from dataset import WildfireSmokeDataset
  from torch.utils.data import DataLoader
 
  from dotenv import load_dotenv
@@ -30,6 +29,19 @@ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobile
  model.load_state_dict(torch.load(model_path))
  model.eval()
 
+ def preprocess(image):
+     image = image.resize((512,512))
+
+     # Convert to BGR
+     image = np.array(image)[:, :, ::-1]  # Convert RGB to BGR
+     image = Image.fromarray(image)
+     image = image.resize(self.image_size)
+
+     # Normalize pixel values to [0, 1]
+     image = np.array(image, dtype=np.float32) / 255.0
+
+     return image
+
  def get_bounding_boxes_from_mask(mask):
      """Extract bounding boxes from a binary mask."""
      pred_boxes = []
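Review note: as committed, preprocess is a module-level function, so the line "image = image.resize(self.image_size)" raises a NameError on the first call (there is no self here), and the image is resized twice. A minimal corrected sketch, assuming the intent is a single fixed 512x512, BGR, [0, 1]-normalized input; TARGET_SIZE is a name introduced here for illustration, not from the commit:

    import numpy as np
    from PIL import Image

    TARGET_SIZE = (512, 512)  # assumed input size, from the hard-coded (512, 512) above

    def preprocess(image: Image.Image) -> np.ndarray:
        # Resize once to the model's expected input size
        image = image.resize(TARGET_SIZE)
        # Reverse the channel axis to convert RGB to BGR
        array = np.array(image)[:, :, ::-1]
        # Normalize pixel values to [0, 1]
        return array.astype(np.float32) / 255.0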
@@ -39,7 +51,7 @@ def get_bounding_boxes_from_mask(mask):
          x, y, w, h = cv2.boundingRect(contour)
          pred_boxes.append((x, y, x + w, y + h))
      return pred_boxes
-
+
  def parse_boxes(annotation_string):
      """Parse multiple boxes from a single annotation string.
      Each box has 5 values: class_id, x_center, y_center, width, height"""
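The diff shows only fragments of these two helpers. For orientation, a sketch of what the elided bodies plausibly look like, inferred from the visible cv2.boundingRect loop and from the five-values-per-box format the docstring describes; the exact bodies are not part of this commit:

    import cv2

    def get_bounding_boxes_from_mask(mask):
        """Extract bounding boxes from a binary mask."""
        pred_boxes = []
        # The visible loop implies contours obtained roughly like this
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            pred_boxes.append((x, y, x + w, y + h))
        return pred_boxes

    def parse_boxes(annotation_string):
        """Parse multiple boxes from a single annotation string.
        Each box has 5 values: class_id, x_center, y_center, width, height"""
        values = annotation_string.split()
        boxes = []
        # Group the flat token stream into 5-value records; class_id is dropped
        for i in range(0, len(values), 5):
            _, x_center, y_center, width, height = map(float, values[i:i + 5])
            boxes.append((x_center, y_center, width, height))
        return boxes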
@@ -130,6 +142,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
      for example in test_dataset:
          # Extract image and annotations
          image = example["image"]
+
+         original_shape = (len(image), len(image[0]))
+         image = preprocess(image)
+
          annotation = example.get("annotations", "").strip()
 
 
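Review note: example["image"] is presumably a PIL image (it is handed straight to preprocess, which calls .resize on it), and PIL images do not support len(), so "original_shape = (len(image), len(image[0]))" would raise a TypeError; even for a NumPy array it would give (height, width) rather than the (width, height) order cv2.resize expects. A hedged sketch of the intended capture, assuming a PIL input:

    # PIL's .size is already (width, height), the order cv2.resize expects later
    original_size = image.size
    image = preprocess(image)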
 
@@ -154,8 +170,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
 
          probabilities = torch.sigmoid(logits)
          predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
-         predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
-
+         # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
+         predicted_mask_resized = cv2.resize(predicted_mask, original_shape, interpolation=cv2.INTER_NEAREST)
+
+
          # Extract predicted bounding boxes
          predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
          pred_boxes.append(predicted_boxes)
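Review note: resizing the mask back to the source resolution instead of the hard-coded (512, 512) is the right direction, but cv2.resize takes dsize as (width, height); if original_shape is (height, width), every non-square image comes back transposed and the predicted boxes will not line up with the ground-truth boxes. Assuming original_size captured as image.size in the sketch above:

    # dsize is (width, height); INTER_NEAREST keeps the mask strictly binary
    predicted_mask_resized = cv2.resize(predicted_mask, original_size, interpolation=cv2.INTER_NEAREST)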
 