import base64
import io
import multiprocessing
import os
import random
from argparse import ArgumentParser
from multiprocessing import Process

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import label, find_objects, grey_dilation
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Blip2Processor, Blip2ForConditionalGeneration, \
    CLIPSegProcessor, CLIPSegForImageSegmentation

Image.MAX_IMAGE_PIXELS = 1000000000

INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "### Input:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
Extract all objects mentioned in the caption and separate them using commas. Exclude background elements (site, location, environment) and only include foreground objects. Ensure that only nouns are included and exclude adjectives entirely.
{input_key}
{input}
{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    input_key=INPUT_KEY,
    input="{input}",
    response_key=RESPONSE_KEY,
)
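
# Pipeline overview (one GPU process per image shard):
#   1. BLIP-2 generates a caption for the image.
#   2. MPT-7B-Instruct extracts foreground object nouns (tags) from the caption.
#   3. CLIPSeg segments each tag; the largest connected region is cropped to a square,
#      resized to 512x512, and paired with a background-masked copy.
#   4. The caption, tags, full image, and crops are written as base64 strings to TSV shards.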

@torch.no_grad()
def save_tsv(args, shard_id, shard, device):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.set_device(device)

    model_dtype = torch.float16

    # blip2
    blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=model_dtype)
    blip2_model.eval().to(device)

    # mpt
    mpt_config = AutoConfig.from_pretrained('mosaicml/mpt-7b-instruct', trust_remote_code=True)
    mpt_config.init_device = device
    mpt_config.attn_config['attn_impl'] = args.attn_impl
    mpt_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
    mpt_tokenizer.pad_token = mpt_tokenizer.eos_token
    mpt_tokenizer.padding_side = 'left'
    mpt_model = AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b-instruct', config=mpt_config,
                                                     torch_dtype=model_dtype, trust_remote_code=True)
    mpt_model.eval()
    mpt_generate_kwargs = {
        'max_new_tokens': args.max_new_tokens,
        'temperature': args.temperature,
        'top_p': args.top_p,
        'top_k': args.top_k,
        'repetition_penalty': args.repetition_penalty,
        'no_repeat_ngram_size': args.no_repeat_ngram_size,
        'use_cache': args.use_cache,
        'do_sample': False if args.temperature == 0 else args.do_sample,
        'eos_token_id': mpt_tokenizer.eos_token_id,
        'pad_token_id': mpt_tokenizer.pad_token_id,
    }

    # clipseg
    clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined", torch_dtype=model_dtype)
    clipseg_model.eval().to(device)

    cnt = 0
    f = None
    for image in tqdm(shard):
        if image is None:
            continue

        # rotate the output file every 1000 processed images
        if cnt % 1000 == 0:
            if f is not None:
                f.close()
            f = open(os.path.join(args.output_dir, f"cnt_{args.machine_id}_{shard_id}_{cnt // 1000}.tsv"),
                     "w", encoding='utf-8')
        cnt += 1

        # caption generation with BLIP-2
        blip2_input = blip2_processor(images=image, return_tensors="pt").to(device, model_dtype)
        blip2_gen = blip2_model.generate(**blip2_input)
        caption = blip2_processor.batch_decode(blip2_gen, skip_special_tokens=True)[0] \
            .replace('\t', '').replace('\n', '').strip()

        # tag extraction with MPT-7B-Instruct
        prompt = PROMPT_FOR_GENERATION_FORMAT.format(input=caption)
        mpt_input = mpt_tokenizer(prompt, return_tensors='pt', padding=True)
        for key, value in mpt_input.items():
            mpt_input[key] = value.to(device)
        mpt_gen = mpt_model.generate(
            input_ids=mpt_input['input_ids'],
            attention_mask=mpt_input['attention_mask'],
            **mpt_generate_kwargs,
        )
        tags = mpt_tokenizer.batch_decode(mpt_gen, skip_special_tokens=True)[0][len(prompt):]
        if '#' in tags:
            continue
        tags = tags.split(",")
        tags = [tag.replace('\t', '').replace('\n', '').strip() for tag in tags]
        # keep only non-empty tags that literally appear in the caption
        tags = [tag for tag in tags if len(tag) > 0 and tag in caption]
        if len(tags) == 0:
            continue

        # segmentation with CLIPSeg, one mask per tag
        clipseg_input = clipseg_processor(text=tags, images=[image] * len(tags), padding=True, return_tensors="pt")
        for key, value in clipseg_input.items():
            clipseg_input[key] = value.to(device)
            if value.dtype == torch.float32:
                clipseg_input[key] = value.to(device, model_dtype)
        clipseg_gen = clipseg_model(**clipseg_input).logits
        if len(tags) == 1:
            clipseg_gen = clipseg_gen.unsqueeze(0)

        image_size = image.height
        # interpolate the low-resolution logits back to the original (square) image size
        clipseg_gen = F.interpolate(clipseg_gen.unsqueeze(1), size=image_size, mode='bilinear')
        masks = torch.sigmoid(clipseg_gen).squeeze(1)
        masks = masks.cpu().numpy()

        sub_images = []
        tags_to_keep = []
        # for each mask, crop a square region around its largest connected component
        for mask_id, mask in enumerate(masks):
            image_array = np.array(image)
            thresholded_mask = mask > args.threshold
            if thresholded_mask.max() == 0:
                continue
            thresholded_mask = grey_dilation(thresholded_mask, size=(image_size // 100, image_size // 100))
            labeled_matrix, num_features = label(thresholded_mask)
            regions = find_objects(labeled_matrix)
            sizes = [np.sum(thresholded_mask[region]) for region in regions]
            max_index = np.argmax(sizes)
            max_region = regions[max_index]
            thresholded_mask[labeled_matrix != (max_index + 1)] = False
            tags_to_keep.append(tags[mask_id])

            # Determine the dimensions of the region
            y_start, y_stop = max_region[0].start, max_region[0].stop
            x_start, x_stop = max_region[1].start, max_region[1].stop
            height = y_stop - y_start
            width = x_stop - x_start

            # Calculate the desired side length for a square region
            side_length = max(height, width)

            # Calculate the center of the region
            center_y = (y_start + y_stop) // 2
            center_x = (x_start + x_stop) // 2

            # Calculate the new boundaries for the region
            new_y_start = center_y - (side_length // 2)
            new_y_stop = new_y_start + side_length
            new_x_start = center_x - (side_length // 2)
            new_x_stop = new_x_start + side_length

            # Adjust the boundaries if they exceed the image boundaries
            if new_y_start < 0:
                new_y_start = 0
                new_y_stop = side_length
            elif new_y_stop > image_array.shape[0]:
                new_y_start = image_array.shape[0] - side_length
                new_y_stop = image_array.shape[0]
            if new_x_start < 0:
                new_x_start = 0
                new_x_stop = side_length
            elif new_x_stop > image_array.shape[1]:
                new_x_start = image_array.shape[1] - side_length
                new_x_stop = image_array.shape[1]

            # Crop the square region and a background-masked copy of it
            object_image = image_array[new_y_start:new_y_stop, new_x_start:new_x_stop]
            max_region_mask = thresholded_mask[new_y_start:new_y_stop, new_x_start:new_x_stop]
            masked_image = object_image.copy()
            masked_image[~max_region_mask] = 255
            object_image = Image.fromarray(object_image).resize((512, 512))
            masked_image = Image.fromarray(masked_image).resize((512, 512))
            sub_images.extend([object_image, masked_image])

        if len(sub_images) == 0:
            continue
        image = image.resize((512, 512))

        # encode images using base64
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image = base64.b64encode(buffer.getvalue()).decode('utf-8')
        for j, im in enumerate(sub_images):
            buffer = io.BytesIO()
            im.save(buffer, format='PNG')
            sub_images[j] = base64.b64encode(buffer.getvalue()).decode('utf-8')

        # write to tsv file
        f.write('\t'.join([
            caption,
            ','.join(tags_to_keep),
            image,
            *sub_images
        ]) + '\n')

    if f is not None:
        f.close()
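
# Each TSV row written above holds: caption, comma-separated tags, the base64-encoded full image,
# followed by a (square crop, background-masked crop) pair of base64 strings for every kept tag.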

class OpenImageDataset(Dataset):
    def __init__(self, url_data):
        self.data = url_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            items = self.data[idx].split(',')
            # items[2] is the image URL column of the OpenImages csv
            image = Image.open(requests.get(items[2], stream=True).raw).convert('RGB')
            # center-crop to a square on the shortest side
            width, height = image.size
            shortest_side = min(width, height)
            left = (width - shortest_side) // 2
            top = (height - shortest_side) // 2
            right = left + shortest_side
            bottom = top + shortest_side
            image = image.crop((left, top, right, bottom))
            return image
        except Exception:
            return None


def collate_fn(batch):
    return batch[0] if batch is not None else None


def main():
    """Parse command-line arguments and launch one worker process per shard."""
    parser = ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='/path/to/image_ids_and_rotation.csv')
    parser.add_argument('--output-dir', type=str, default='/path/to/output-dir/')
    parser.add_argument('--num-process', type=int, default=8)
    parser.add_argument('--cuda-device', type=int, nargs='+', default=[0, 1, 2, 3, 4, 5, 6, 7])
    parser.add_argument('--num-machine', type=int, default=1)
    parser.add_argument('--machine-id', type=int, default=0)
    parser.add_argument('--max-seq-len', type=int, default=None)
    parser.add_argument('--max-new-tokens', type=int, default=10)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top-k', type=int, default=50)
    parser.add_argument('--top-p', type=float, default=0.95)
    parser.add_argument('--repetition-penalty', type=float, default=1.0)
    parser.add_argument('--no-repeat-ngram-size', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--do-sample', type=bool, default=True)
    parser.add_argument('--use-cache', type=bool, default=True)
    parser.add_argument('--trust-remote-code', type=bool, default=True)
    parser.add_argument('--attn-impl', type=str, default='torch')
    parser.add_argument('--threshold', type=float, default=0.3)
    args = parser.parse_args()

    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    with open(args.data_dir, 'r', encoding='utf8') as f:
        url_data = f.read().strip().split('\n')
    # split across machines and pick the part belonging to machine_id
    url_data = url_data[args.machine_id::args.num_machine]
    # split url data into one shard per process
    url_data = [url_data[i::args.num_process] for i in range(args.num_process)]

    dataloaders = [
        DataLoader(
            OpenImageDataset(url_data[i]),
            batch_size=1,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
            drop_last=False,
            prefetch_factor=4,
            collate_fn=collate_fn
        )
        for i in range(args.num_process)
    ]

    multiprocessing.set_start_method('spawn')
    processes = []
    for shard_id, shard in enumerate(dataloaders):
        p = Process(
            target=save_tsv,
            args=(
                args,
                shard_id,
                shard,
                torch.device('cuda:{}'.format(args.cuda_device[shard_id % len(args.cuda_device)]))
            )
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print('Done!')


if __name__ == '__main__':
    main()
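
# Example launch (the script filename is illustrative; adjust paths to your setup):
#   python prepare_openimages_tsv.py \
#       --data-dir /path/to/image_ids_and_rotation.csv \
#       --output-dir /path/to/output-dir/ \
#       --num-process 8 --num-machine 1 --machine-id 0 --threshold 0.3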