import base64
import io
import multiprocessing
import os
import random
from argparse import ArgumentParser, BooleanOptionalAction
from multiprocessing import Process

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import label, find_objects, grey_dilation
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Blip2Processor, Blip2ForConditionalGeneration, \
    CLIPSegProcessor, CLIPSegForImageSegmentation

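# raise PIL's decompression-bomb limit so very large downloaded images can be opened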
Image.MAX_IMAGE_PIXELS = 1000000000

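# instruction-following prompt scaffold used to extract object tags with MPT-7B-Instruct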
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "### Input:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
PROMPT_FOR_GENERATION_FORMAT = """{intro}

{instruction_key}
Extract all objects mentioned in the caption and separate them using commas. Exclude background elements (site, location, environment) and only include foreground objects. Ensure that only nouns are included and exclude adjectives entirely.

{input_key}
{input}

{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    input_key=INPUT_KEY,
    input="{input}",
    response_key=RESPONSE_KEY,
)


@torch.no_grad()
def save_tsv(args, shard_id, shard, device):
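    """Worker loop: caption each image with BLIP-2, extract object tags with
    MPT-7B-Instruct, localize each tag with CLIPSeg, and write base64-encoded
    crops to tsv shards.
    """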
    # must be set before CUDA is initialized in this process; entries must not contain spaces
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.set_device(device)
    model_dtype = torch.float16
    # blip2 captioner; processor and model must come from the same checkpoint
    blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=model_dtype)
    blip2_model.eval().to(device)

    # mpt
    mpt_config = AutoConfig.from_pretrained('mosaicml/mpt-7b-instruct', trust_remote_code=True)
    mpt_config.init_device = device
    mpt_config.attn_config['attn_impl'] = args.attn_impl

    mpt_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
    mpt_tokenizer.pad_token = mpt_tokenizer.eos_token
    mpt_tokenizer.padding_side = 'left'
    mpt_model = AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-instruct', config=mpt_config,
                                                     torch_dtype=model_dtype, trust_remote_code=True)
    mpt_model.eval()

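    # generation settings for tag extraction (falls back to greedy decoding when temperature == 0)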
    mpt_generate_kwargs = {
        'max_new_tokens': args.max_new_tokens,
        'temperature': args.temperature,
        'top_p': args.top_p,
        'top_k': args.top_k,
        'repetition_penalty': args.repetition_penalty,
        'no_repeat_ngram_size': args.no_repeat_ngram_size,
        'use_cache': args.use_cache,
        'do_sample': False if args.temperature == 0 else args.do_sample,
        'eos_token_id': mpt_tokenizer.eos_token_id,
        'pad_token_id': mpt_tokenizer.pad_token_id,
    }

    # clipseg
    clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined", torch_dtype=model_dtype)
    clipseg_model.eval().to(device)

    cnt = 0
    f = None

    for image in tqdm(shard):
        if image is None:
            continue
        if cnt % 1000 == 0:
            # roll over to a new shard file every 1000 images
            if f is not None:
                f.close()
            f = open(os.path.join(args.output_dir, f"cnt_{args.machine_id}_{shard_id}_{cnt // 1000}.tsv"), "w",
                     encoding='utf-8')
        cnt += 1

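        # caption the image with BLIP-2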
        blip2_input = blip2_processor(images=image, return_tensors="pt").to(device, model_dtype)

        blip2_gen = blip2_model.generate(**blip2_input)
        caption = blip2_processor.batch_decode(blip2_gen, skip_special_tokens=True)[0] \
            .replace('\t', '').replace('\n', '').strip()

        # tag extraction
        prompt = PROMPT_FOR_GENERATION_FORMAT.format(input=caption)

        # Run HF generate
        mpt_input = mpt_tokenizer(prompt, return_tensors='pt', padding=True)
        for key, value in mpt_input.items():
            mpt_input[key] = value.to(device)
        mpt_gen = mpt_model.generate(
            input_ids=mpt_input['input_ids'],
            attention_mask=mpt_input['attention_mask'],
            **mpt_generate_kwargs,
        )
        # slice off the prompt tokens and decode only the newly generated text
        tags = mpt_tokenizer.decode(mpt_gen[0, mpt_input['input_ids'].shape[1]:], skip_special_tokens=True)

        if '#' in tags:
            continue
        tags = tags.split(",")

        tags = [tag.replace('\t', '').replace('\n', '').strip() for tag in tags]
        tags = [tag for tag in tags if len(tag) > 0 and tag in caption]

        if len(tags) == 0:
            continue

        clipseg_input = clipseg_processor(text=tags, images=[image] * len(tags), padding=True, return_tensors="pt")
        for key, value in clipseg_input.items():
            # move to the gpu; cast float tensors to the model dtype (fp16)
            if value.dtype == torch.float32:
                clipseg_input[key] = value.to(device, model_dtype)
            else:
                clipseg_input[key] = value.to(device)

        # predict
        clipseg_gen = clipseg_model(**clipseg_input).logits

        if len(tags) == 1:
            # CLIPSeg squeezes the batch dimension when there is a single prompt; restore it
            clipseg_gen = clipseg_gen.unsqueeze(0)

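        # the dataset center-crops to a square, so height == width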
        image_size = image.height

        # interpolate to original size
        clipseg_gen = F.interpolate(clipseg_gen.unsqueeze(1), size=image_size, mode='bilinear')
        masks = torch.sigmoid(clipseg_gen).squeeze(1)
        masks = masks.cpu().numpy()

        sub_images = []
        tags_to_keep = []

        # crop a square around the largest region of each mask; keep object and masked versions
        image_array = np.array(image)

        for mask_id, mask in enumerate(masks):
            thresholded_mask = mask > args.threshold

            if thresholded_mask.max() == 0:
                continue

            # dilate slightly to merge nearby fragments (guard against a zero-sized kernel)
            dilation = max(image_size // 100, 1)
            thresholded_mask = grey_dilation(thresholded_mask, size=(dilation, dilation))
            # keep only the largest connected component
            labeled_matrix, num_features = label(thresholded_mask)
            regions = find_objects(labeled_matrix)
            sizes = [np.sum(labeled_matrix[region] == i + 1) for i, region in enumerate(regions)]
            max_index = np.argmax(sizes)
            max_region = regions[max_index]
            thresholded_mask[labeled_matrix != (max_index + 1)] = False

            tags_to_keep.append(tags[mask_id])

            # Determine the dimensions of the region
            y_start, y_stop = max_region[0].start, max_region[0].stop
            x_start, x_stop = max_region[1].start, max_region[1].stop
            height = y_stop - y_start
            width = x_stop - x_start

            # Calculate the desired side length for a square region
            side_length = max(height, width)

            # Calculate the center of the region
            center_y = (y_start + y_stop) // 2
            center_x = (x_start + x_stop) // 2

            # Calculate the new boundaries for the region
            new_y_start = center_y - (side_length // 2)
            new_y_stop = new_y_start + side_length
            new_x_start = center_x - (side_length // 2)
            new_x_stop = new_x_start + side_length

            # Adjust the boundaries if they exceed the image boundaries
            if new_y_start < 0:
                new_y_start = 0
                new_y_stop = side_length
            elif new_y_stop > image_array.shape[0]:
                new_y_start = image_array.shape[0] - side_length
                new_y_stop = image_array.shape[0]

            if new_x_start < 0:
                new_x_start = 0
                new_x_stop = side_length
            elif new_x_stop > image_array.shape[1]:
                new_x_start = image_array.shape[1] - side_length
                new_x_stop = image_array.shape[1]

            # Create a new mask with the adjusted boundaries
            object_image = image_array[new_y_start:new_y_stop, new_x_start:new_x_stop]
            max_region_mask = thresholded_mask[new_y_start:new_y_stop, new_x_start:new_x_stop]

            masked_image = object_image.copy()
            masked_image[~max_region_mask] = 255

            object_image = Image.fromarray(object_image).resize((512, 512))
            masked_image = Image.fromarray(masked_image).resize((512, 512))
            sub_images.extend([object_image, masked_image])

        if len(sub_images) == 0:
            continue

        image = image.resize((512, 512))

        # encode image using base64
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image = base64.b64encode(buffer.getvalue()).decode('utf-8')

        for j, im in enumerate(sub_images):
            buffer = io.BytesIO()
            im.save(buffer, format='PNG')
            sub_images[j] = base64.b64encode(buffer.getvalue()).decode('utf-8')

        # write to tsv file
        # one row per image: caption, kept tags, full image, then (object, masked) crop pairs
        f.write('\t'.join([
            caption,
            ','.join(tags_to_keep),
            image,
            *sub_images
        ]) + '\n')

    if f is not None:
        f.close()


class OpenImageDataset(Dataset):
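    """Loads images from the urls in the csv (column 2) and center-crops each to a square."""
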
    def __init__(self, url_data):
        self.data = url_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            items = self.data[idx].split(',')
            # column 2 of the csv holds the image url; the timeout is an added safeguard (value is arbitrary)
            image = Image.open(requests.get(items[2], stream=True, timeout=30).raw).convert('RGB')
            # center-crop to a square on the shortest side
            width, height = image.size
            shortest_side = min(width, height)
            left = (width - shortest_side) // 2
            top = (height - shortest_side) // 2
            right = left + shortest_side
            bottom = top + shortest_side
            image = image.crop((left, top, right, bottom))
            return image
        except Exception:
            # failed downloads/decodes are skipped by the worker loop
            return None


def collate_fn(batch):
    # batch_size is 1: unwrap the single item (may be None if the download failed)
    return batch[0]


def main():
    """Parse arguments, shard the url list, and launch one captioning worker per shard."""
    parser = ArgumentParser()
    parser.add_argument('--data-dir', type=str,
                        default='/path/to/image_ids_and_rotation.csv')
    parser.add_argument('--output-dir', type=str, default='/path/to/output-dir/')
    parser.add_argument('--num-process', type=int, default=8)
    parser.add_argument('--cuda-device', type=int, nargs='+', default=[0, 1, 2, 3, 4, 5, 6, 7])
    parser.add_argument('--num-machine', type=int, default=1)
    parser.add_argument('--machine-id', type=int, default=0)

    parser.add_argument('--max-seq-len', type=int, default=None)
    parser.add_argument('--max-new-tokens', type=int, default=10)

    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top-k', type=int, default=50)
    parser.add_argument('--top-p', type=float, default=0.95)
    parser.add_argument('--repetition-penalty', type=float, default=1.0)
    parser.add_argument('--no-repeat-ngram-size', type=int, default=0)

    parser.add_argument('--seed', type=int, default=0)
    # type=bool treats any non-empty string as True; BooleanOptionalAction parses --flag / --no-flag
    parser.add_argument('--do-sample', action=BooleanOptionalAction, default=True)
    parser.add_argument('--use-cache', action=BooleanOptionalAction, default=True)
    parser.add_argument('--trust-remote-code', action=BooleanOptionalAction, default=True)
    parser.add_argument('--attn-impl', type=str, default='torch')
    parser.add_argument('--threshold', type=float, default=0.3)
    args = parser.parse_args()

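    # avoid tokenizer fork warnings and deadlocks in the worker processes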
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    with open(args.data_dir, 'r', encoding='utf8') as f:
        url_data = f.read().strip().split('\n')

    # split urls across machines and keep this machine's slice
    url_data = url_data[args.machine_id::args.num_machine]

    # split this machine's urls into one shard per worker process
    url_data = [url_data[i::args.num_process] for i in range(args.num_process)]

    dataloaders = [
        DataLoader(
            OpenImageDataset(url_data[i]),
            batch_size=1,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
            drop_last=False,
            prefetch_factor=4,
            collate_fn=collate_fn
        )
        for i in range(args.num_process)
    ]

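    # spawn (not fork) so each worker process initializes its own CUDA context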
    multiprocessing.set_start_method('spawn')
    processes = []

    for shard_id, shard in enumerate(dataloaders):
        p = Process(
            target=save_tsv,
            args=(
                args,
                shard_id,
                shard,
                torch.device('cuda:{}'.format(args.cuda_device[shard_id % len(args.cuda_device)]))
            )
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    print('Done!')


if __name__ == '__main__':
    main()
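
# Example invocation (script name and paths are placeholders):
#   python build_tsv.py --data-dir /path/to/image_ids_and_rotation.csv \
#       --output-dir /path/to/output-dir/ --num-machine 1 --machine-id 0 --num-process 8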