# clipsegmulticlass/src/data_processing.py
# Uploaded by BioMike ("Upload 3 files", commit 4875d48, 2.16 kB)
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import numpy as np
class SingleClassSegmentationDataset(Dataset):
    """Single-foreground-class segmentation dataset.

    Each element of *dataset* is a dict with keys:
      - ``"img_path"``:  path to an image, loaded as RGB
      - ``"mask_path"``: path to a binary mask (any non-zero pixel = foreground)
      - ``"label"``:     class name; must appear in *class_labels*

    ``__getitem__`` returns ``{"image": PIL.Image, "labels": LongTensor (H, W)}``
    where foreground pixels hold the label's index within *class_labels*
    and background pixels hold *background_index*.
    """

    def __init__(self, dataset, class_labels, image_size=352, transform=None,
                 background_index=0):
        """
        Args:
            dataset: sequence of item dicts (see class docstring).
            class_labels: ordered list of class names; a label's position in
                this list is its integer id in the output mask.
            image_size: side length of the square output resize (default 352;
                presumably the CLIPSeg input resolution — confirm with model).
            transform: optional callable ``(image, mask) -> (image, mask)``
                applied after resizing.
            background_index: integer id written to background pixels
                (default 0, preserving the previous hard-coded behavior).
        """
        self.items = dataset
        self.class_labels = class_labels
        self.image_size = image_size
        self.transform = transform
        self.background_index = background_index

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image = Image.open(item["img_path"]).convert("RGB")
        mask = Image.open(item["mask_path"]).convert("L")

        class_name = item["label"]
        try:
            class_index = self.class_labels.index(class_name)
        except ValueError:
            # Surface the offending label instead of list.index's terse message.
            raise ValueError(
                f"Unknown label {class_name!r}; expected one of {self.class_labels}"
            ) from None

        # NOTE(review): if class_index == background_index (label listed first
        # in class_labels), foreground pixels become indistinguishable from
        # background — callers should reserve that index for background.
        mask_np = np.array(mask) > 0  # any non-zero pixel counts as foreground
        final_mask = np.full(mask_np.shape, self.background_index, dtype=np.uint8)
        final_mask[mask_np] = class_index

        # BILINEAR for the image, NEAREST for the mask so class ids are never
        # interpolated into invalid intermediate values.
        image = image.resize((self.image_size, self.image_size), Image.BILINEAR)
        final_mask = Image.fromarray(final_mask).resize(
            (self.image_size, self.image_size), Image.NEAREST
        )

        if self.transform:
            image, final_mask = self.transform(image, final_mask)

        return {
            "image": image,
            "labels": torch.from_numpy(np.array(final_mask)).long(),
        }
class SegmentationCollator:
    """Batch collator that pairs every image with every class prompt.

    For a batch of B samples and C class labels, the processor receives
    B*C (image, text) pairs ordered image-major: sample 0 with each of the
    C prompts in order, then sample 1, and so on. Ground-truth masks are
    stacked into a single ``labels`` tensor of shape (B, H, W).
    """

    def __init__(self, processor, class_labels):
        self.processor = processor
        self.class_labels = class_labels

    def __call__(self, batch):
        images = [sample["image"] for sample in batch]
        masks = [sample["labels"] for sample in batch]

        num_classes = len(self.class_labels)
        # Repeat the full prompt list once per image...
        text_prompts = []
        for _ in images:
            text_prompts.extend(self.class_labels)
        # ...and duplicate each image once per prompt, so pair i of
        # text_prompts lines up with pair i of paired_images.
        paired_images = []
        for img in images:
            paired_images.extend([img] * num_classes)

        encoded = self.processor(
            images=paired_images,
            text=text_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        return {
            "pixel_values": encoded["pixel_values"],
            "input_ids": encoded["input_ids"],
            "labels": torch.stack(masks),
        }