File size: 2,095 Bytes
6ecc7d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import glob
import os
from PIL import Image
import random
import numpy as np

from torch import nn
from torchvision import transforms
from torch.utils import data as data
import torch.nn.functional as F

from .realesrgan import RealESRGAN_degradation

class PairedCaptionDataset(data.Dataset):
    """Dataset producing (degraded, ground-truth) image pairs and a tokenized caption.

    Ground-truth image paths are read from a text file (one path per line).
    Each item is loaded, randomly cropped to 512x512 and randomly flipped,
    then passed through ``RealESRGAN_degradation.degrade_process`` to obtain
    the low-quality conditioning image.

    Args:
        root_folders: Path to a text file listing one ground-truth image path
            per line. Blank lines are ignored. Must be provided; the ``None``
            default cannot be opened.
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with ``input_ids`` (used by
            ``tokenize_caption``).
        gt_ratio: Probability in [0, 1] that the ground-truth image itself is
            used as the conditioning (low-quality) image instead of the
            degraded one.

    NOTE(review): ``RandomCrop((512, 512))`` is applied without padding, so
    images smaller than 512 px on either side are presumably rejected by
    torchvision — confirm the dataset only contains large enough images.
    """

    def __init__(
            self,
            root_folders=None,
            tokenizer=None,
            gt_ratio=0,  # probability of feeding the GT image as the LQ input
    ):
        super().__init__()

        self.gt_ratio = gt_ratio
        self.gt_list = self._read_path_list(root_folders)

        # The crop already yields 512x512, so the following Resize keeps
        # that size; it is retained to leave the pipeline unchanged.
        self.img_preproc = transforms.Compose([
            transforms.RandomCrop((512, 512)),
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
            ])

        self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
        self.tokenizer = tokenizer

    @staticmethod
    def _read_path_list(list_file):
        """Return the stripped, non-blank lines of ``list_file`` in order.

        Blank or whitespace-only lines are skipped so they do not become
        empty path entries that would fail later when an item is loaded.
        """
        with open(list_file, 'r') as f:
            stripped_lines = (line.strip() for line in f)
            return [path for path in stripped_lines if path]

    def tokenize_caption(self, caption=""):
        """Tokenize ``caption`` padded/truncated to the tokenizer's max length.

        Returns:
            The ``input_ids`` produced by the tokenizer with
            ``return_tensors="pt"``.
        """
        inputs = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )

        return inputs.input_ids

    def __getitem__(self, index):
        """Load, augment and degrade the image at ``index``.

        Returns:
            dict with keys:
                ``conditioning_pixel_values``: the low-quality image (or the
                    ground truth, with probability ``gt_ratio``); expected
                    range [0, 1].
                ``pixel_values``: the ground-truth image rescaled with
                    ``* 2.0 - 1.0``; expected range [-1, 1].
                ``input_caption``: token ids of the empty caption.
        """
        gt_path = self.gt_list[index]
        # Close the source file deterministically; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(gt_path) as src:
            gt_img = src.convert('RGB')
        gt_img = self.img_preproc(gt_img)

        # presumably both outputs carry a leading batch dimension of size 1
        # (they are squeezed on dim 0 below) — verify against degrade_process
        gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)

        # With probability gt_ratio, condition on the clean image itself.
        if random.random() < self.gt_ratio:
            lq_img = gt_img
        else:
            lq_img = img_t

        # no caption used
        lq_caption = ''

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0) # [0, 1]
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0 # [-1, 1]
        example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)

        return example

    def __len__(self):
        """Return the number of ground-truth image paths."""
        return len(self.gt_list)