Guill-Bla committed on
Commit
f82759b
·
verified ·
1 Parent(s): cccaa05

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +7 -4
tasks/image.py CHANGED
@@ -35,13 +35,15 @@ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
35
  model.eval()
36
 
37
  def preprocess(image):
38
- image = image.resize((512,512))
39
- image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
40
- # Normalize pixel values to [0, 1]
41
  image = np.array(image, dtype=np.float32) / 255.0
42
 
 
 
43
  return image
44
 
 
45
  def get_bounding_boxes_from_mask(mask):
46
  """Extract bounding boxes from a binary mask."""
47
  pred_boxes = []
@@ -165,6 +167,7 @@ async def evaluate_image(request: ImageEvaluationRequest):
165
  # Model Inference
166
  # image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
167
  image_input = feature_extractor(images=image, return_tensors="pt", padding=True).pixel_values
 
168
 
169
  with torch.no_grad():
170
  outputs = model(pixel_values=image_input)
@@ -173,7 +176,7 @@ async def evaluate_image(request: ImageEvaluationRequest):
173
  probabilities = torch.sigmoid(logits)
174
  predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
175
  # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
176
- predicted_mask_resized = cv2.resize(predicted_mask, original_shape, interpolation=cv2.INTER_NEAREST)
177
 
178
 
179
  # Extract predicted bounding boxes
 
35
  model.eval()
36
 
37
  def preprocess(image):
38
+ image = image.resize((512, 512))
39
+ image = np.array(image)[:, :, ::-1] # RGB to BGR
 
40
  image = np.array(image, dtype=np.float32) / 255.0
41
 
42
+ # Convert back to PIL Image to maintain compatibility with feature extractor
43
+ image = Image.fromarray((image * 255).astype(np.uint8))
44
  return image
45
 
46
+
47
  def get_bounding_boxes_from_mask(mask):
48
  """Extract bounding boxes from a binary mask."""
49
  pred_boxes = []
 
167
  # Model Inference
168
  # image_input = feature_extractor(images=image, return_tensors="pt").pixel_values
169
  image_input = feature_extractor(images=image, return_tensors="pt", padding=True).pixel_values
170
+ image_input = feature_extractor(images=[image], return_tensors="pt", padding=True).pixel_values
171
 
172
  with torch.no_grad():
173
  outputs = model(pixel_values=image_input)
 
176
  probabilities = torch.sigmoid(logits)
177
  predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
178
  # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
179
+ predicted_mask_resized = cv2.resize(predicted_mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)
180
 
181
 
182
  # Extract predicted bounding boxes