import os

import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from onehot import onehot

# Normalize input images with ImageNet mean/std (the masks are handled separately).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
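
# The onehot helper is imported from onehot.py in this project. For reference,
# the sketch below shows the assumed behaviour (an assumption, not the
# project's actual implementation) under a different name so it does not
# shadow the real import: map an (H, W) integer class mask to an (H, W, n)
# one-hot float array, which is what __getitem__ below expects before it swaps
# the axes to (n, H, W).
import numpy as np

def _onehot_reference(data, n):
    # Index identity-matrix rows by class id: (H, W) int mask -> (H, W, n).
    return np.eye(n, dtype='float32')[data]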
class Dataset(Dataset):
    """Loads (image, one-hot mask) pairs for binary segmentation.

    Note: the class name shadows the torch.utils.data.Dataset base class
    imported above; it is kept to match the rest of the project.
    """

    def __init__(self, transform=None):
        self.transform = transform
        self.train_file = 'dataset/train/image'
        self.mask_file = 'dataset/train/mask'

    def __len__(self):
        return len(os.listdir(self.train_file))

    def __getitem__(self, idx):
        # Sorted listing gives a deterministic index -> filename mapping.
        img_name = sorted(os.listdir(self.train_file))[idx]
        imgA = cv2.imread(os.path.join(self.train_file, img_name))
        imgA = cv2.resize(imgA, (160, 160))

        # Load the matching grayscale mask; change the suffix to your mask format.
        imgB = cv2.imread(self.mask_file + '/' + img_name[0:7] + '_gt.png', 0)
        imgB = cv2.resize(imgB, (160, 160))
        imgB = imgB / 255                 # scale 0/255 mask values to 0/1
        imgB = imgB.astype('uint8')
        imgB = onehot(imgB, 2)            # (160, 160) -> (160, 160, 2) one-hot
        imgB = imgB.swapaxes(0, 2).swapaxes(1, 2)  # HWC -> CHW
        imgB = torch.FloatTensor(imgB)

        if self.transform:
            imgA = self.transform(imgA)
        return {'A': imgA, 'B': imgB}
training_data = Dataset(transform)
dataloader = DataLoader(training_data, batch_size=4, shuffle=True, num_workers=4)

if __name__ == '__main__':
    # Quick smoke test: pull one batch and report the number of batches.
    for batch in dataloader:
        print(len(dataloader))
        break
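    # Shape sanity check on the first batch (assumption: with batch_size=4 and
    # 160x160 inputs, 'A' should be [4, 3, 160, 160] and 'B' [4, 2, 160, 160]).
    print(batch['A'].shape, batch['B'].shape)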