Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,095 Bytes
6ecc7d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import glob
import os
from PIL import Image
import random
import numpy as np
from torch import nn
from torchvision import transforms
from torch.utils import data as data
import torch.nn.functional as F
from .realesrgan import RealESRGAN_degradation
class PairedCaptionDataset(data.Dataset):
    """Dataset of (degraded LQ, ground-truth) image pairs with a tokenized
    (empty) caption, for training real-world super-resolution models.

    Parameters
    ----------
    root_folders : str
        Path to a text file listing one ground-truth image path per line.
    tokenizer : callable
        HuggingFace-style tokenizer used to encode the caption; must expose
        ``model_max_length`` and return an object with ``input_ids``.
    gt_ratio : float
        Probability in [0, 1] of using the clean GT image as the LQ
        (conditioning) input instead of the degraded one.
    """

    def __init__(
        self,
        root_folders=None,
        tokenizer=None,
        gt_ratio=0,  # probability that the LQ input is the GT image itself
    ):
        super(PairedCaptionDataset, self).__init__()
        self.gt_ratio = gt_ratio

        # One GT image path per line in the list file.
        with open(root_folders, 'r') as f:
            self.gt_list = [line.strip() for line in f]

        # NOTE(review): RandomCrop raises on images smaller than 512x512 —
        # the subsequent Resize is then a no-op on the already-512 crop.
        self.img_preproc = transforms.Compose([
            transforms.RandomCrop((512, 512)),
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
        ])

        self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
        self.tokenizer = tokenizer

    def tokenize_caption(self, caption=""):
        """Encode *caption* with the stored tokenizer and return its input_ids."""
        inputs = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def __getitem__(self, index):
        """Return a dict with conditioning LQ pixels, GT pixels, and caption ids."""
        gt_path = self.gt_list[index]
        gt_img = Image.open(gt_path).convert('RGB')
        gt_img = self.img_preproc(gt_img)

        # degrade_process takes the image scaled to [0, 1] and returns
        # (gt tensor, degraded LQ tensor) — presumably both in [0, 1].
        gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)

        # With probability gt_ratio, condition on the clean GT instead of
        # the degraded image.
        if random.random() < self.gt_ratio:
            lq_img = gt_img
        else:
            lq_img = img_t

        # No caption is used for training; tokenize the empty string.
        lq_caption = ''

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0)  # [0, 1]
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0  # [-1, 1]
        example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)
        # (removed dead `lq_img = lq_img.squeeze()` — value was never used)
        return example

    def __len__(self):
        return len(self.gt_list)