Guill-Bla commited on
Commit
ada3b45
·
verified ·
1 Parent(s): d39874c

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +10 -9
tasks/image.py CHANGED
@@ -61,12 +61,14 @@ class SmokeDataset(torch.utils.data.Dataset):
61
  image = example["image"]
62
  annotation = example.get("annotations", "").strip()
63
 
64
- # Preprocess and extract features directly within the dataset
65
- image = preprocess(image) # Apply resizing and other preprocessing
66
- image_input = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze(0)
67
-
68
- return image_input, annotation
69
 
 
 
 
 
 
70
 
71
 
72
  def preprocess_batch(images):
@@ -184,10 +186,9 @@ async def evaluate_image(request: ImageEvaluationRequest):
184
  true_boxes_list = []
185
 
186
  for batch_images, batch_annotations in dataloader:
187
-
188
- batch_images = batch_images.to(device) # Move to the correct device if using GPU
189
-
190
- # Perform inference
191
  with torch.no_grad():
192
  outputs = model(pixel_values=batch_images)
193
  logits = outputs.logits
 
61
  image = example["image"]
62
  annotation = example.get("annotations", "").strip()
63
 
64
+ # Resize image and preprocess
65
+ image = preprocess(image) # Apply resizing and preprocessing
 
 
 
66
 
67
+ # Extract features with padding set to True
68
+ features = feature_extractor(images=image, return_tensors="pt", padding=True)
69
+
70
+ # Return pixel values directly as tensors
71
+ return features.pixel_values.squeeze(0), annotation
72
 
73
 
74
  def preprocess_batch(images):
 
186
  true_boxes_list = []
187
 
188
  for batch_images, batch_annotations in dataloader:
189
+
190
+ batch_images = batch_images.to(device)
191
+
 
192
  with torch.no_grad():
193
  outputs = model(pixel_values=batch_images)
194
  logits = outputs.logits