Guill-Bla committed on
Commit
d39874c
·
verified ·
1 Parent(s): bed8f16

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +20 -21
tasks/image.py CHANGED
@@ -38,7 +38,18 @@ model.eval()
38
 
39
  from torch.utils.data import Dataset
40
 
41
- class SmokeDataset(Dataset):
 
 
 
 
 
 
 
 
 
 
 
42
  def __init__(self, dataset):
43
  self.dataset = dataset
44
 
@@ -49,26 +60,14 @@ class SmokeDataset(Dataset):
49
  example = self.dataset[idx]
50
  image = example["image"]
51
  annotation = example.get("annotations", "").strip()
 
 
 
 
52
 
53
- # Resize and preprocess the image
54
- image = image.resize((512, 512))
55
- image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
56
- image = np.array(image, dtype=np.float32) / 255.0
57
-
58
- # Return both the preprocessed image tensor and annotation
59
- return torch.tensor(image, dtype=torch.float32).permute(2, 0, 1), annotation
60
-
61
-
62
- def preprocess(image):
63
- # Ensure input image is resized to a fixed size (512, 512)
64
- image = image.resize((512, 512))
65
 
66
- # Convert to NumPy and ensure BGR normalization
67
- image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
68
- image = np.array(image, dtype=np.float32) / 255.0
69
 
70
- # Return as a PIL Image for feature extractor compatibility
71
- return Image.fromarray((image * 255).astype(np.uint8))
72
 
73
  def preprocess_batch(images):
74
  """
@@ -185,12 +184,12 @@ async def evaluate_image(request: ImageEvaluationRequest):
185
  true_boxes_list = []
186
 
187
  for batch_images, batch_annotations in dataloader:
188
- # image_inputs = feature_extractor(images=batch_images, return_tensors="pt", padding=True).pixel_values
189
- image_inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in batch_images], return_tensors="pt", padding=True).pixel_values
190
 
191
  # Perform inference
192
  with torch.no_grad():
193
- outputs = model(pixel_values=image_inputs)
194
  logits = outputs.logits
195
 
196
  probabilities = torch.sigmoid(logits)
 
38
 
39
  from torch.utils.data import Dataset
40
 
41
+ def preprocess(image):
42
+ # Ensure input image is resized to a fixed size (512, 512)
43
+ image = image.resize((512, 512))
44
+
45
+ # Convert to NumPy and ensure BGR normalization
46
+ image = np.array(image)[:, :, ::-1] # Convert RGB to BGR
47
+ image = np.array(image, dtype=np.float32) / 255.0
48
+
49
+ # Return as a PIL Image for feature extractor compatibility
50
+ return Image.fromarray((image * 255).astype(np.uint8))
51
+
52
+ class SmokeDataset(torch.utils.data.Dataset):
53
  def __init__(self, dataset):
54
  self.dataset = dataset
55
 
 
60
  example = self.dataset[idx]
61
  image = example["image"]
62
  annotation = example.get("annotations", "").strip()
63
+
64
+ # Preprocess and extract features directly within the dataset
65
+ image = preprocess(image) # Apply resizing and other preprocessing
66
+ image_input = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze(0)
67
 
68
+ return image_input, annotation
 
 
 
 
 
 
 
 
 
 
 
69
 
 
 
 
70
 
 
 
71
 
72
  def preprocess_batch(images):
73
  """
 
184
  true_boxes_list = []
185
 
186
  for batch_images, batch_annotations in dataloader:
187
+
188
+ batch_images = batch_images.to(device) # Move to the correct device if using GPU
189
 
190
  # Perform inference
191
  with torch.no_grad():
192
+ outputs = model(pixel_values=batch_images)
193
  logits = outputs.logits
194
 
195
  probabilities = torch.sigmoid(logits)