|
import os |
|
from PIL import Image |
|
from torchvision import transforms |
|
from torch.utils.data import Dataset |
|
|
|
def find_mask_file(image_path, mask_dir, mask_extensions=('.png', '.jpg', '.jpeg')):
    """Locate the mask file that shares a base name with an image.

    The image's directory and final extension are stripped, and each
    candidate extension is then tried in order inside ``mask_dir``.

    Args:
        image_path: Path to the image. Only its base name (without the final
            extension) takes part in the lookup.
        mask_dir: Directory that is searched for the mask file.
        mask_extensions: Extensions to try, in priority order. Each entry is
            appended verbatim, so it must include the leading dot. A single
            string is treated as one extension.

    Returns:
        The path of the first candidate that exists, or ``None`` when none
        of the candidates exists.
    """
    # The default is a tuple (not a list) so the shared default object can
    # never be mutated between calls.
    if isinstance(mask_extensions, str):
        # Iterating a bare string would try one character at a time.
        mask_extensions = (mask_extensions,)

    base_name = os.path.splitext(os.path.basename(image_path))[0]
    for ext in mask_extensions:
        mask_path = os.path.join(mask_dir, base_name + ext)
        if os.path.exists(mask_path):
            return mask_path
    return None
|
|
|
class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs read from two directories.

    Every entry of ``image_dir`` is treated as an image. Its mask is looked
    up in ``mask_dir`` by base name via ``find_mask_file``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable applied to the RGB image. It is also
            applied to the mask when ``mask_transform`` is not given.
        mask_transform: Optional callable applied to the mask instead of
            ``transform``, for pipelines where the mask needs different
            handling from the image.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.image_filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Returns:
            A tuple ``(image, mask)``: the image converted to mode "RGB" and
            the mask converted to mode "L", each passed through its
            transform when one is configured.

        Raises:
            FileNotFoundError: If no mask matching the image's base name
                exists in ``mask_dir``.
        """
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = find_mask_file(img_path, self.mask_dir)
        if mask_path is None:
            # Fail with a clear message instead of handing None to Image.open.
            raise FileNotFoundError(
                f"No mask found in {self.mask_dir!r} for image {img_path!r}"
            )

        # convert() returns a new, fully loaded image, so the underlying
        # files can be closed as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        # Fall back to the image transform when no dedicated mask transform
        # was supplied.
        mask_transform = (
            self.mask_transform if self.mask_transform is not None else self.transform
        )
        if mask_transform is not None:
            mask = mask_transform(mask)

        return image, mask
|
|
|
def transform_img(size=(128, 128)):
    """Build the preprocessing pipeline: resize, then convert to a tensor.

    Args:
        size: Target size handed to ``transforms.Resize``. Defaults to
            ``(128, 128)``.

    Returns:
        A ``transforms.Compose`` that applies ``transforms.Resize(size)``
        followed by ``transforms.ToTensor()``.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
|
|
|
if __name__ == "__main__":
    # Running this module directly only prints a label; no dataset is built.
    print("Dataset class")