"""
This script converts the original dataset layout to tsv format. The jpg images are resized so that the shorter
side is 512 pixels, re-encoded as PNG, and stored as base64 strings.
Each row contains the columns: input, edit, seed0_0, seed0_1, ..., seedN_0, seedN_1.
Each tsv file contains at most --max-image samples (1000 by default).
"""

import argparse
import base64
import io
import json
import os
from multiprocessing import Process

from PIL import Image
from tqdm import tqdm


def save_tsv(args, i, sub_seeds_list):
    with open(os.path.join(args.output_dir, f'{str(i).zfill(4)}.tsv'), 'w') as f:
        for name, seeds in tqdm(sub_seeds_list, desc=f'processing {i}th tsv file', leave=False):
            # load prompt (the context manager closes the file handle right away)
            with open(os.path.join(args.data_dir, name, 'prompt.json')) as prompt_file:
                prompt = json.load(prompt_file)
            # load images
            images = [Image.open(os.path.join(args.data_dir, name, f'{seed}_{j}.jpg')).convert('RGB') for seed in seeds
                      for j in range(2)]
            # resize the shorter side to 512, keeping the aspect ratio (im.size is (width, height))
            images = [im.resize((512, int(512 / im.size[0] * im.size[1])) if im.size[0] < im.size[1] else
                                (int(512 / im.size[1] * im.size[0]), 512)) for im in images]
            # re-encode each image as PNG in memory, then convert the bytes to a base64 string
            for j, im in enumerate(images):
                buffer = io.BytesIO()
                im.save(buffer, format='PNG')
                images[j] = base64.b64encode(buffer.getvalue()).decode('utf-8')

            # write to tsv file
            f.write('\t'.join([
                prompt['input'].replace('\t', '').replace('\n', '').strip(),
                prompt['edit'].replace('\t', '').replace('\n', '').strip(),
                *images
            ]) + '\n')
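

# A minimal reader sketch for the rows written by save_tsv above. The name decode_tsv_row is
# hypothetical (it is not part of the original script), and it assumes the row layout produced
# above: input, edit, then the base64-encoded PNG images. It relies on the base64 import at the
# top of this file.
def decode_tsv_row(line):
    """Split one tsv line into (input prompt, edit prompt, list of raw PNG bytes)."""
    input_prompt, edit_prompt, *encoded_images = line.rstrip('\n').split('\t')
    # each decoded entry could then be opened with Image.open(io.BytesIO(png_bytes))
    return input_prompt, edit_prompt, [base64.b64decode(s) for s in encoded_images]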


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='/path/to/clip-filtered-dataset/')
    parser.add_argument('--output-dir', type=str, default='/path/to/output-dir/')
    parser.add_argument('--max-image', type=int, default=1000)
    parser.add_argument('--num-process', type=int, default=64)

    args = parser.parse_args()

    # load seeds
    with open(os.path.join(args.data_dir, 'seeds.json')) as seeds_file:
        seeds_list = json.load(seeds_file)

    # split seeds into chunks of --max-image samples, one chunk per tsv file
    seeds_list = [seeds_list[i:i + args.max_image] for i in range(0, len(seeds_list), args.max_image)]

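    # make sure the output directory exists before the workers write to it
    os.makedirs(args.output_dir, exist_ok=True)

    # save tsv files, running at most --num-process worker processes at a time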
    processes = []
    for i, sub_seeds_list in enumerate(seeds_list):
        p = Process(target=save_tsv, args=(args, i, sub_seeds_list))
        p.start()
        processes.append(p)
        if len(processes) == args.num_process:
            for p in processes:
                p.join()
            processes = []

    for p in processes:
        p.join()

    print('done.')