Update tasks/image.py
tasks/image.py    +22 -4    CHANGED
@@ -13,7 +13,6 @@ from PIL import Image
 from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
 import cv2
 from tqdm import tqdm
-from dataset import WildfireSmokeDataset
 from torch.utils.data import DataLoader
 
 from dotenv import load_dotenv
@@ -30,6 +29,19 @@ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobile
 model.load_state_dict(torch.load(model_path))
 model.eval()
 
+def preprocess(image):
+    image = image.resize((512,512))
+
+    # Convert to BGR
+    image = np.array(image)[:, :, ::-1]  # Convert RGB to BGR
+    image = Image.fromarray(image)
+    image = image.resize(self.image_size)
+
+    # Normalize pixel values to [0, 1]
+    image = np.array(image, dtype=np.float32) / 255.0
+
+    return image
+
 def get_bounding_boxes_from_mask(mask):
     """Extract bounding boxes from a binary mask."""
     pred_boxes = []
@@ -39,7 +51,7 @@ def get_bounding_boxes_from_mask(mask):
         x, y, w, h = cv2.boundingRect(contour)
         pred_boxes.append((x, y, x + w, y + h))
     return pred_boxes
-
+
 def parse_boxes(annotation_string):
     """Parse multiple boxes from a single annotation string.
     Each box has 5 values: class_id, x_center, y_center, width, height"""
@@ -130,6 +142,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
     for example in test_dataset:
         # Extract image and annotations
         image = example["image"]
+
+        original_shape = (len(image), len(image[0]))
+        image = preprocess(image)
+
         annotation = example.get("annotations", "").strip()
 
 
@@ -154,8 +170,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
 
         probabilities = torch.sigmoid(logits)
         predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
-        predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
-
+        # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
+        predicted_mask_resized = cv2.resize(predicted_mask, original_shape, interpolation=cv2.INTER_NEAREST)
+
+
         # Extract predicted bounding boxes
         predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
         pred_boxes.append(predicted_boxes)
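A note on the new preprocess() helper: as committed it is a module-level function, so the self.image_size reference has no self in scope, and the image is resized twice. The sketch below is a hypothetical, self-contained variant (the names preprocess_image and target_size are illustrative, not from the commit), assuming the intended input size is the fixed 512x512 used by the first resize:

import numpy as np
from PIL import Image

def preprocess_image(image: Image.Image, target_size=(512, 512)) -> np.ndarray:
    # Resize once to the model's expected input size
    image = image.resize(target_size)
    # Convert RGB -> BGR by reversing the channel axis
    bgr = np.array(image)[:, :, ::-1]
    # Normalize pixel values to [0, 1]
    return bgr.astype(np.float32) / 255.0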
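A note on the resize-back step: original_shape is taken as (len(image), len(image[0])) before preprocessing, but at that point image is likely still a PIL image from the dataset, which does not support len(), and cv2.resize expects its size argument in (width, height) order rather than (rows, cols). A hedged sketch of one way to restore the mask to the original resolution, assuming image is a PIL.Image and predicted_mask is the 2-D uint8 mask at model resolution (the helper name resize_mask_to_original is illustrative):

import cv2
import numpy as np
from PIL import Image

def resize_mask_to_original(predicted_mask: np.ndarray, original_image: Image.Image) -> np.ndarray:
    # PIL's Image.size is already (width, height), matching the dsize order cv2.resize expects
    return cv2.resize(predicted_mask, original_image.size, interpolation=cv2.INTER_NEAREST)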