Update tasks/image.py
tasks/image.py  CHANGED  (+22 -20)
@@ -35,13 +35,16 @@ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 model.eval()

 def preprocess(image):
+    # Ensure input image is resized to a fixed size (512, 512)
     image = image.resize((512, 512))
-
+
+    # Convert to NumPy and ensure BGR normalization
+    image = np.array(image)[:, :, ::-1]  # Convert RGB to BGR
     image = np.array(image, dtype=np.float32) / 255.0

-    #
-
-
+    # Return as a PIL Image for feature extractor compatibility
+    return Image.fromarray((image * 255).astype(np.uint8))
+


 def get_bounding_boxes_from_mask(mask):
@@ -145,16 +148,11 @@ async def evaluate_image(request: ImageEvaluationRequest):
         # Extract image and annotations
         image = example["image"]

-        original_shape = image.size
-        image = preprocess(image)
-
+        original_shape = image.size
         annotation = example.get("annotations", "").strip()
-
-
         has_smoke = len(annotation) > 0
         true_labels.append(1 if has_smoke else 0)

-
         if has_smoke:
             image_true_boxes = parse_boxes(annotation)
             if image_true_boxes:
@@ -165,26 +163,30 @@ async def evaluate_image(request: ImageEvaluationRequest):
             true_boxes_list.append([])

         # Model Inference
-
-
-
+
+        # Preprocess image
+        image = preprocess(image)
+
+        # Ensure correct feature extraction
+        image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
+
+        # Perform inference
         with torch.no_grad():
             outputs = model(pixel_values=image_input)
             logits = outputs.logits
-
+
+        # Threshold and process the segmentation mask
         probabilities = torch.sigmoid(logits)
         predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
-        # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
         predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)
-

-        # Extract
+        # Extract bounding boxes
         predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
         pred_boxes.append(predicted_boxes)
-
-        #
+
+        # Smoke prediction based on bounding box presence
+        predictions.append(1 if len(predicted_boxes) > 0 else 0)
         print(f"Prediction : {1 if len(predicted_boxes) > 0 else 0}")
-        predictions.append(1 if len(predicted_boxes) > 0 else 0)


     # Filter only valid box pairs
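The updated loop still relies on get_bounding_boxes_from_mask, which is defined outside the changed hunks. For orientation only, the sketch below shows a typical contour-based implementation with OpenCV, returning one (x_min, y_min, x_max, y_max) box per connected blob in the binary mask; it is a hypothetical stand-in, not the actual definition in tasks/image.py.

import cv2
import numpy as np

def get_bounding_boxes_from_mask(mask: np.ndarray) -> list:
    # Hypothetical sketch: one axis-aligned box per connected blob.
    # Contours are found on the uint8 mask; each contour yields a
    # (x_min, y_min, x_max, y_max) box in pixel coordinates.
    contours, _ = cv2.findContours(mask.astype(np.uint8),
                                   cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append((x, y, x + w, y + h))
    return boxes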
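Pulling the hunks together, the sketch below shows the new inference path as one function: preprocess, feature extraction, sigmoid plus a 0.30 threshold on logits channel 1 (assumed to be the smoke class), mask resize, then boxes via the contour sketch above. The predict_smoke name, the explicit model/feature_extractor arguments, and the threshold parameter are illustrative additions; a SegFormer-style processor/model pair is assumed, not confirmed by this diff.

import cv2
import numpy as np
import torch
from PIL import Image

def preprocess(image: Image.Image) -> Image.Image:
    # Fixed 512x512 input, RGB -> BGR, scaled to [0, 1], then back to PIL
    # so the feature extractor can consume it (mirrors the committed preprocess).
    image = image.resize((512, 512))
    image = np.array(image)[:, :, ::-1]
    image = np.array(image, dtype=np.float32) / 255.0
    return Image.fromarray((image * 255).astype(np.uint8))

def predict_smoke(image: Image.Image, model, feature_extractor, threshold: float = 0.30):
    original_shape = image.size  # PIL size is (width, height)
    pixel_values = feature_extractor(images=preprocess(image),
                                     return_tensors="pt").pixel_values
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    probabilities = torch.sigmoid(logits)
    # Channel 1 assumed to be the smoke class, thresholded at 0.30 as in the diff
    predicted_mask = (probabilities[0, 1] > threshold).cpu().numpy().astype(np.uint8)
    # Resize the mask back toward the original image size (same call as the diff)
    predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1],
                                        interpolation=cv2.INTER_NEAREST)
    boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
    return boxes, 1 if boxes else 0

In the evaluation loop this corresponds to appending the returned boxes to pred_boxes and the 0/1 label to predictions, which is what the committed code does per example.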