import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class SingleClassSegmentationDataset(Dataset):
    """Wraps a list of {img_path, mask_path, label} records, where each image
    has a binary mask for exactly one class, as a segmentation dataset."""

    def __init__(self, dataset, class_labels, image_size=352, transform=None):
        self.items = dataset
        self.class_labels = class_labels
        self.image_size = image_size
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image = Image.open(item["img_path"]).convert("RGB")
        mask = Image.open(item["mask_path"]).convert("L")

        # Map the item's single class name to its integer index.
        class_name = item["label"]
        class_index = self.class_labels.index(class_name)
        background_index = 0

        # Binarize the grayscale mask, then write the class index into the
        # foreground pixels and the background index everywhere else.
        mask_np = np.array(mask) > 0
        final_mask = np.full(mask_np.shape, background_index, dtype=np.uint8)
        final_mask[mask_np] = class_index

        # Resize: bilinear for the image, nearest-neighbor for the label map
        # so class indices are never blended by interpolation.
        image = image.resize((self.image_size, self.image_size), Image.BILINEAR)
        final_mask = Image.fromarray(final_mask).resize(
            (self.image_size, self.image_size), Image.NEAREST
        )

        if self.transform:
            image, final_mask = self.transform(image, final_mask)

        return {
            "image": image,
            "labels": torch.from_numpy(np.array(final_mask)).long(),
        }
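
# A minimal usage sketch for the dataset above. The file paths, class names,
# and records here are hypothetical placeholders; the class list is assumed
# to start with a background entry so that real classes get non-zero indices.
#
#     class_labels = ["background", "cat", "dog"]
#     items = [
#         {"img_path": "data/img_001.jpg", "mask_path": "data/mask_001.png", "label": "cat"},
#         {"img_path": "data/img_002.jpg", "mask_path": "data/mask_002.png", "label": "dog"},
#     ]
#     train_dataset = SingleClassSegmentationDataset(items, class_labels, image_size=352)
#     sample = train_dataset[0]
#     print(sample["image"].size, sample["labels"].shape)  # (352, 352) torch.Size([352, 352])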
class SegmentationCollator:
    """Pairs every image in a batch with every class prompt, for a model that
    segments one text-specified class per forward pass."""

    def __init__(self, processor, class_labels):
        self.processor = processor
        self.class_labels = class_labels

    def __call__(self, batch):
        images = [item["image"] for item in batch]
        labels = [item["labels"] for item in batch]

        # Build one (image, prompt) pair per class per image: the prompt list
        # cycles through all class labels once per image, and each image is
        # repeated len(class_labels) times so the two lists stay aligned.
        prompts = self.class_labels * len(images)
        expanded_images = [img for img in images for _ in self.class_labels]

        inputs = self.processor(
            images=expanded_images,
            text=prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        return {
            "pixel_values": inputs["pixel_values"],
            "input_ids": inputs["input_ids"],
            "labels": torch.stack(labels),
        }
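
# A hedged wiring sketch showing how the two classes fit together with a
# DataLoader. The processor choice is an assumption: the 352-pixel input size
# suggests CLIPSeg, so this uses transformers' CLIPSegProcessor, but any
# processor that accepts images= and text= and returns pixel_values/input_ids
# would slot in the same way. class_labels and train_dataset are the
# hypothetical objects from the sketch above.
#
#     from torch.utils.data import DataLoader
#     from transformers import CLIPSegProcessor
#
#     processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
#     collator = SegmentationCollator(processor, class_labels)
#     loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collator)
#
#     batch = next(iter(loader))
#     # pixel_values/input_ids have batch_size * len(class_labels) rows;
#     # labels keeps one (H, W) index map per original image.
#     print(batch["pixel_values"].shape, batch["input_ids"].shape, batch["labels"].shape)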