""" | |
Base class for our zero-shot anomaly detection dataset | |
""" | |
import json
import os
import random

import numpy as np
import torch.utils.data as data
from PIL import Image
import cv2

from config import DATA_ROOT


class DataSolver:
    """Reads meta.json under `root` and keeps only the entries of the requested classes."""

    def __init__(self, root, clsnames):
        self.root = root
        self.clsnames = clsnames
        self.path = os.path.join(root, 'meta.json')

    def run(self):
        with open(self.path, 'r') as f:
            info = json.load(f)

        # Filter meta.json down to the requested class names, keeping both splits.
        info_required = dict(train={}, test={})
        for cls in self.clsnames:
            for k in info.keys():
                info_required[k][cls] = info[k][cls]
        return info_required
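

# For reference, DataSolver assumes meta.json maps each split ('train'/'test') to a
# per-class list of sample records. The field names below are the ones read by
# BaseDataset.__getitem__ ('img_path', 'mask_path', 'cls_name', 'anomaly'); the
# concrete paths and class name are only illustrative:
#
# {
#   "test": {
#     "bottle": [
#       {"img_path": "bottle/test/broken/000.png",
#        "mask_path": "bottle/ground_truth/broken/000_mask.png",
#        "cls_name": "bottle",
#        "anomaly": 1},
#       ...
#     ]
#   },
#   "train": { ... }
# }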


class BaseDataset(data.Dataset):
    def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.aug_rate = aug_rate
        self.training = training
        self.data_all = []
        self.cls_names = clsnames

        solver = DataSolver(root, clsnames)
        meta_info = solver.run()
        self.meta_info = meta_info['test']  # Only utilize the test split for both training and testing

        for cls_name in self.cls_names:
            self.data_all.extend(self.meta_info[cls_name])
        self.length = len(self.data_all)

    def __len__(self):
        return self.length

    def combine_img(self, cls_name):
        """
        From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN
        Combine four randomly sampled images of one class (and their masks)
        into a single 2x2 mosaic for data augmentation.
        """
        img_info = random.sample(self.meta_info[cls_name], 4)
        img_ls = []
        mask_ls = []

        for sample in img_info:
            img_path = os.path.join(self.root, sample['img_path'])
            mask_path = os.path.join(self.root, sample['mask_path'])
            img = Image.open(img_path).convert('RGB')
            img_ls.append(img)

            if not sample['anomaly']:
                # Normal sample: all-zero mask. PIL's img.size is (width, height),
                # so the numpy array shape must be (height, width).
                img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')
            else:
                # Binarize the ground-truth mask to {0, 255}.
                img_mask = np.array(Image.open(mask_path).convert('L')) > 0
                img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
            mask_ls.append(img_mask)

        # Paste the four images into a 2x2 grid.
        image_width, image_height = img_ls[0].size
        result_image = Image.new("RGB", (2 * image_width, 2 * image_height))
        for i, img in enumerate(img_ls):
            row = i // 2
            col = i % 2
            x = col * image_width
            y = row * image_height
            result_image.paste(img, (x, y))

        # Paste the four masks into the same 2x2 grid.
        result_mask = Image.new("L", (2 * image_width, 2 * image_height))
        for i, mask in enumerate(mask_ls):
            row = i // 2
            col = i % 2
            x = col * image_width
            y = row * image_height
            result_mask.paste(mask, (x, y))

        return result_image, result_mask

    def __getitem__(self, index):
        sample = self.data_all[index]
        img_path = os.path.join(self.root, sample['img_path'])
        mask_path = os.path.join(self.root, sample['mask_path'])
        cls_name = sample['cls_name']
        anomaly = sample['anomaly']
        random_number = random.random()

        if self.training and random_number < self.aug_rate:
            # With probability aug_rate, build a 2x2 mosaic of the same class.
            img, img_mask = self.combine_img(cls_name)
        else:
            if img_path.endswith('.tif'):
                # Read .tif images with OpenCV and convert BGR -> RGB before wrapping in PIL.
                img = cv2.imread(img_path)
                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            else:
                img = Image.open(img_path).convert('RGB')

            if anomaly == 0:
                # Normal sample: all-zero mask (PIL size is (width, height), numpy shape is (height, width)).
                img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')
            else:
                if sample['mask_path']:
                    # Binarize the ground-truth mask to {0, 255}.
                    img_mask = np.array(Image.open(mask_path).convert('L')) > 0
                    img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
                else:
                    # Anomalous sample without a pixel-level mask: fall back to an all-zero mask.
                    img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')

        # Transforms
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and img_mask is not None:
            img_mask = self.target_transform(img_mask)
        if img_mask is None:
            img_mask = []

        return {
            'img': img,
            'img_mask': img_mask,
            'cls_name': cls_name,
            'anomaly': anomaly,
            'img_path': img_path
        }
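

# Minimal usage sketch (illustrative only): the class name, image size, batch size,
# and transforms below are assumptions, not part of this module; any transform pair
# that maps PIL images/masks to tensors will do.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    img_transform = transforms.Compose([
        transforms.Resize((518, 518)),
        transforms.ToTensor(),
    ])
    mask_transform = transforms.Compose([
        transforms.Resize((518, 518)),
        transforms.ToTensor(),
    ])

    dataset = BaseDataset(
        clsnames=['bottle'],            # assumed class present in meta.json
        transform=img_transform,
        target_transform=mask_transform,
        root=DATA_ROOT,                 # directory containing meta.json
        aug_rate=0.2,
        training=True,
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    print(batch['img'].shape, batch['img_mask'].shape, batch['cls_name'])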