# quandao92's picture
# Upload 48 files
# 71d05bb verified
import torch.utils.data as data
import json
import random
from PIL import Image
import numpy as np
import torch
import os
def generate_class_info(dataset_name, mode='train'):
    """Return the object classes of a dataset and a class-name -> id map.

    Args:
        dataset_name: Dataset key; one of 'mvtec', '4inlab', 'task1',
            'task2' or 'smoke_cloud'.
        mode: Split name (e.g. 'train' or 'test'). A split without its own
            entry falls back to the dataset's 'train' class list, so every
            split of a dataset gets the same class ids.

    Returns:
        A tuple ``(obj_list, class_name_map_class_id)``. ``obj_list`` is the
        list of class names. ``class_name_map_class_id`` maps each name to its
        index in ``obj_list``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the known datasets.
    """
    # Class lists per dataset, keyed by split. Only splits whose list differs
    # from 'train' need an explicit entry; others fall back to 'train'.
    class_lists = {
        # The full MVTec AD list is: carpet, bottle, hazelnut, leather, cable,
        # capsule, grid, pill, transistor, metal_nut, screw, toothbrush,
        # zipper, tile, wood. Only 'bottle' is currently used.
        'mvtec': {'train': ['bottle']},
        '4inlab': {'train': ['shinpyung'], 'test': ['shinpyung']},
        'task1': {'train': ['cup']},
        'task2': {'train': ['fire']},
        'smoke_cloud': {'train': ['fire']},
    }
    try:
        splits = class_lists[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(class_lists)}"
        ) from None
    # Copy so callers can mutate the returned list safely.
    obj_list = list(splits.get(mode, splits['train']))
    class_name_map_class_id = {name: idx for idx, name in enumerate(obj_list)}
    return obj_list, class_name_map_class_id
class Dataset_test(data.Dataset):
    """Image / mask dataset for the test split.

    Samples are read from ``<root>/meta_train.json``. The JSON object under
    the ``mode`` key maps each class name to a list of sample records. Each
    record provides ``img_path`` and ``mask_path`` (relative to ``root``),
    ``cls_name`` and ``anomaly``. Class ids are always taken from
    ``generate_class_info(dataset_name, mode='test')``.
    """

    def __init__(self, root, transform, target_transform, dataset_name, mode="test"):
        """Load the sample index for split ``mode`` from the metadata file.

        Args:
            root: Dataset root directory; also the base for sample paths.
            transform: Optional callable applied to each image.
            target_transform: Optional callable applied to each mask.
            dataset_name: Name passed to ``generate_class_info``.
            mode: Top-level key of the metadata JSON to read samples from.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        # NOTE(review): the test split is also read from meta_train.json;
        # confirm that this file holds the test entries as well.
        with open(f'{self.root}/meta_train.json', 'r', encoding='utf-8') as f:
            meta_info = json.load(f)[mode]
        self.cls_names = list(meta_info.keys())
        self.data_all = []
        for cls_name in self.cls_names:
            self.data_all.extend(meta_info[cls_name])
        self.length = len(self.data_all)
        self.obj_list, self.class_name_map_class_id = generate_class_info(dataset_name, mode='test')

    def __len__(self):
        """Return the number of samples across all classes."""
        return self.length

    def __getitem__(self, index):
        """Return sample ``index`` as a dict.

        The dict has the keys 'img' (image, after ``transform``), 'img_mask'
        (mask, after ``target_transform``; ``[]`` if that transform returns
        None), 'cls_name', 'anomaly', 'img_path' (joined with ``root``) and
        'cls_id'.
        """
        sample = self.data_all[index]
        full_img_path = os.path.join(self.root, sample['img_path'])
        cls_name = sample['cls_name']
        anomaly = sample['anomaly']

        img = Image.open(full_img_path)
        img_mask = self._load_mask(sample['mask_path'], anomaly, img.size)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img_mask = self.target_transform(img_mask)
        if img_mask is None:
            img_mask = []
        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': full_img_path, "cls_id": self.class_name_map_class_id[cls_name]}

    def _load_mask(self, mask_path, anomaly, size):
        """Build a sample's ground-truth mask as a mode-'L' PIL image.

        Normal samples (``anomaly == 0``) get an all-zero mask of ``size``,
        given as (width, height) like ``Image.size``. So do anomalous samples
        whose ``mask_path`` is a directory (classification-only data).
        Otherwise the mask file is loaded and binarised to {0, 255}.
        """
        if anomaly == 0:
            # Image.new takes (width, height), the same order as Image.size,
            # so the blank mask matches the image even when it is not square.
            return Image.new('L', size)
        full_mask_path = os.path.join(self.root, mask_path)
        if os.path.isdir(full_mask_path):
            # Just for classification: no pixel mask, but do not report an error.
            return Image.new('L', size)
        mask = np.array(Image.open(full_mask_path).convert('L')) > 0
        # A 2-D uint8 array is converted to a mode-'L' image automatically.
        return Image.fromarray(mask.astype(np.uint8) * 255)
class Dataset_train(data.Dataset):
    """Image / mask dataset for the training split.

    Samples are read from ``<root>/meta_train.json``. The JSON object under
    the ``mode`` key maps each class name to a list of sample records. Each
    record provides ``img_path`` and ``mask_path`` (relative to ``root``),
    ``cls_name`` and ``anomaly``. Class ids are always taken from
    ``generate_class_info(dataset_name, mode='train')``.
    """

    def __init__(self, root, transform, target_transform, dataset_name, mode="train"):
        """Load the sample index for split ``mode`` from the metadata file.

        Args:
            root: Dataset root directory; also the base for sample paths.
            transform: Optional callable applied to each image.
            target_transform: Optional callable applied to each mask.
            dataset_name: Name passed to ``generate_class_info``.
            mode: Top-level key of the metadata JSON to read samples from.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        # Use a context manager so the metadata file is closed deterministically.
        with open(f'{self.root}/meta_train.json', 'r', encoding='utf-8') as f:
            meta_info = json.load(f)[mode]
        self.cls_names = list(meta_info.keys())
        self.data_all = []
        for cls_name in self.cls_names:
            self.data_all.extend(meta_info[cls_name])
        self.length = len(self.data_all)
        self.obj_list, self.class_name_map_class_id = generate_class_info(dataset_name, mode='train')

    def __len__(self):
        """Return the number of samples across all classes."""
        return self.length

    def __getitem__(self, index):
        """Return sample ``index`` as a dict.

        The dict has the keys 'img' (image, after ``transform``), 'img_mask'
        (mask, after ``target_transform``; ``[]`` if that transform returns
        None), 'cls_name', 'anomaly', 'img_path' (joined with ``root``) and
        'cls_id'.
        """
        sample = self.data_all[index]
        full_img_path = os.path.join(self.root, sample['img_path'])
        cls_name = sample['cls_name']
        anomaly = sample['anomaly']

        img = Image.open(full_img_path)
        img_mask = self._load_mask(sample['mask_path'], anomaly, img.size)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img_mask = self.target_transform(img_mask)
        if img_mask is None:
            img_mask = []
        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': full_img_path, "cls_id": self.class_name_map_class_id[cls_name]}

    def _load_mask(self, mask_path, anomaly, size):
        """Build a sample's ground-truth mask as a mode-'L' PIL image.

        Normal samples (``anomaly == 0``) get an all-zero mask of ``size``,
        given as (width, height) like ``Image.size``. So do anomalous samples
        whose ``mask_path`` is a directory (classification-only data).
        Otherwise the mask file is loaded and binarised to {0, 255}.
        """
        if anomaly == 0:
            # Image.new takes (width, height), the same order as Image.size,
            # so the blank mask matches the image even when it is not square.
            return Image.new('L', size)
        full_mask_path = os.path.join(self.root, mask_path)
        if os.path.isdir(full_mask_path):
            # Just for classification: no pixel mask, but do not report an error.
            return Image.new('L', size)
        mask = np.array(Image.open(full_mask_path).convert('L')) > 0
        # A 2-D uint8 array is converted to a mode-'L' image automatically.
        return Image.fromarray(mask.astype(np.uint8) * 255)