Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import glob | |
| import os | |
| from multiprocessing.dummy import Pool as ThreadPool | |
| from PIL import Image | |
| from torchvision.transforms.functional import to_tensor | |
| from ..Models import * | |
| class ImageSplitter: | |
| # key points: | |
| # Boarder padding and over-lapping img splitting to avoid the instability of edge value | |
| # Thanks Waifu2x's autorh nagadomi for suggestions (https://github.com/nagadomi/waifu2x/issues/238) | |
| def __init__(self, seg_size=48, scale_factor=2, boarder_pad_size=3): | |
| self.seg_size = seg_size | |
| self.scale_factor = scale_factor | |
| self.pad_size = boarder_pad_size | |
| self.height = 0 | |
| self.width = 0 | |
| self.upsampler = nn.Upsample(scale_factor=scale_factor, mode="bilinear") | |
| def split_img_tensor(self, pil_img, scale_method=Image.BILINEAR, img_pad=0): | |
| # resize image and convert them into tensor | |
| img_tensor = to_tensor(pil_img).unsqueeze(0) | |
| img_tensor = nn.ReplicationPad2d(self.pad_size)(img_tensor) | |
| batch, channel, height, width = img_tensor.size() | |
| self.height = height | |
| self.width = width | |
| if scale_method is not None: | |
| img_up = pil_img.resize( | |
| (2 * pil_img.size[0], 2 * pil_img.size[1]), scale_method | |
| ) | |
| img_up = to_tensor(img_up).unsqueeze(0) | |
| img_up = nn.ReplicationPad2d(self.pad_size * self.scale_factor)(img_up) | |
| patch_box = [] | |
| # avoid the residual part is smaller than the padded size | |
| if ( | |
| height % self.seg_size < self.pad_size | |
| or width % self.seg_size < self.pad_size | |
| ): | |
| self.seg_size += self.scale_factor * self.pad_size | |
| # split image into over-lapping pieces | |
| for i in range(self.pad_size, height, self.seg_size): | |
| for j in range(self.pad_size, width, self.seg_size): | |
| part = img_tensor[ | |
| :, | |
| :, | |
| (i - self.pad_size) : min( | |
| i + self.pad_size + self.seg_size, height | |
| ), | |
| (j - self.pad_size) : min(j + self.pad_size + self.seg_size, width), | |
| ] | |
| if img_pad > 0: | |
| part = nn.ZeroPad2d(img_pad)(part) | |
| if scale_method is not None: | |
| # part_up = self.upsampler(part) | |
| part_up = img_up[ | |
| :, | |
| :, | |
| self.scale_factor | |
| * (i - self.pad_size) : min( | |
| i + self.pad_size + self.seg_size, height | |
| ) | |
| * self.scale_factor, | |
| self.scale_factor | |
| * (j - self.pad_size) : min( | |
| j + self.pad_size + self.seg_size, width | |
| ) | |
| * self.scale_factor, | |
| ] | |
| patch_box.append((part, part_up)) | |
| else: | |
| patch_box.append(part) | |
| return patch_box | |
| def merge_img_tensor(self, list_img_tensor): | |
| out = torch.zeros( | |
| (1, 3, self.height * self.scale_factor, self.width * self.scale_factor) | |
| ) | |
| img_tensors = copy.copy(list_img_tensor) | |
| rem = self.pad_size * 2 | |
| pad_size = self.scale_factor * self.pad_size | |
| seg_size = self.scale_factor * self.seg_size | |
| height = self.scale_factor * self.height | |
| width = self.scale_factor * self.width | |
| for i in range(pad_size, height, seg_size): | |
| for j in range(pad_size, width, seg_size): | |
| part = img_tensors.pop(0) | |
| part = part[:, :, rem:-rem, rem:-rem] | |
| # might have error | |
| if len(part.size()) > 3: | |
| _, _, p_h, p_w = part.size() | |
| out[:, :, i : i + p_h, j : j + p_w] = part | |
| # out[:,:, | |
| # self.scale_factor*i:self.scale_factor*i+p_h, | |
| # self.scale_factor*j:self.scale_factor*j+p_w] = part | |
| out = out[:, :, rem:-rem, rem:-rem] | |
| return out | |
| def load_single_image( | |
| img_file, | |
| up_scale=False, | |
| up_scale_factor=2, | |
| up_scale_method=Image.BILINEAR, | |
| zero_padding=False, | |
| ): | |
| img = Image.open(img_file).convert("RGB") | |
| out = to_tensor(img).unsqueeze(0) | |
| if zero_padding: | |
| out = nn.ZeroPad2d(zero_padding)(out) | |
| if up_scale: | |
| size = tuple(map(lambda x: x * up_scale_factor, img.size)) | |
| img_up = img.resize(size, up_scale_method) | |
| img_up = to_tensor(img_up).unsqueeze(0) | |
| out = (out, img_up) | |
| return out | |
| def standardize_img_format(img_folder): | |
| def process(img_file): | |
| img_path = os.path.dirname(img_file) | |
| img_name, _ = os.path.basename(img_file).split(".") | |
| out = os.path.join(img_path, img_name + ".JPEG") | |
| os.rename(img_file, out) | |
| list_imgs = [] | |
| for i in ["png", "jpeg", "jpg"]: | |
| list_imgs.extend(glob.glob(img_folder + "**/*." + i, recursive=True)) | |
| print("Found {} images.".format(len(list_imgs))) | |
| pool = ThreadPool(4) | |
| pool.map(process, list_imgs) | |
| pool.close() | |
| pool.join() | |