# anonymous-upload-neurips-2025's picture
# Upload 1784 files
# 8c9048a verified
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import cv2
import pdb
from onehot import onehot
import torch
import matplotlib.pyplot as plt
# Preprocessing applied to every input image: convert to a tensor, then
# normalise each of the three channels with a fixed mean / std.
# NOTE(review): these constants match the commonly used ImageNet RGB
# statistics, while images loaded via cv2.imread are BGR-ordered — confirm
# the channel order is what the model was trained with.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class Dataset(Dataset):
    """Image / mask pairs for validation, loaded from disk with OpenCV.

    Each item is a dict with two entries:

    * ``'A'`` -- the image resized to 160x160, passed through ``transform``
      when one is given (otherwise the raw array returned by OpenCV).
    * ``'B'`` -- the mask, resized to 160x160, one-hot encoded into two
      classes and returned as a ``torch.FloatTensor``.

    Args:
        transform: Optional callable applied to the resized image.
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the masks. The mask for an image is
            looked up as ``<first 7 characters of the image name>_gt.png``.
    """

    def __init__(self, transform=None,
                 image_dir='dataset/validation/image',
                 mask_dir='dataset/validation/mask'):
        self.transform = transform
        self.train_file = image_dir
        self.mask_file = mask_dir
        # Filled on first use so that constructing the dataset performs no I/O.
        self._image_names = None

    def _list_images(self):
        """Return the sorted, cached list of file names in the image directory."""
        if self._image_names is None:
            # os.listdir order is unspecified; sorting keeps the idx -> file
            # mapping reproducible, and caching avoids re-listing the
            # directory for every sample.
            self._image_names = sorted(os.listdir(self.train_file))
        return self._image_names

    def __len__(self):
        """Return the number of entries in the image directory."""
        return len(self._list_images())

    def __getitem__(self, idx):
        """Load, resize and encode the image / mask pair at position ``idx``.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        img_name = self._list_images()[idx]

        img_path = os.path.join(self.train_file, img_name)
        imgA = cv2.imread(img_path)
        if imgA is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image: ' + img_path)
        imgA = cv2.resize(imgA, (160, 160))

        mask_path = self.mask_file + '/' + img_name[0:7] + '_gt.png'
        imgB = cv2.imread(mask_path, 0)  # flag 0 -> single-channel grayscale
        if imgB is None:
            raise FileNotFoundError('cannot read mask: ' + mask_path)
        imgB = cv2.resize(imgB, (160, 160))

        # Scale to [0, 1] and truncate to integers: only pixels that are
        # exactly 255 after resizing end up as class 1, everything else is 0.
        imgB = imgB / 255
        imgB = imgB.astype('uint8')
        imgB = onehot(imgB, 2)
        # Move the last axis to the front while keeping the other two in
        # order (presumably (H, W, 2) -> (2, H, W); verify against onehot).
        imgB = imgB.swapaxes(0, 2).swapaxes(1, 2)
        imgB = torch.FloatTensor(imgB)

        if self.transform:
            imgA = self.transform(imgA)
        return {'A': imgA, 'B': imgB}
# Module-level validation dataset and loader, importable by other scripts.
validation_data = Dataset(transform)
dataloader_val = DataLoader(validation_data, batch_size=4, shuffle=True, num_workers=4)

if __name__ == '__main__':
    # Smoke test: pull a single batch to check that loading works, then
    # report how many batches the loader yields per epoch.
    for batch in dataloader_val:
        print(len(dataloader_val))
        break