import argparse
import glob
import json
import os
from io import BytesIO

import numpy as np
import requests
import torch
import tqdm
from llava.conversation import SeparatorStyle, conv_templates
from llava.model import *
from llava.model.utils import KeywordsStoppingCriteria
from llava.utils import disable_torch_init
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPVisionModel,
    StoppingCriteria,
)

# Special tokens used by LLaVA's multimodal prompt format.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


# ADE20K-150 class names; annotation value N (N >= 1) maps to classes[N - 1],
# and 0 marks unlabeled pixels.
classes = [
    "wall", "building", "sky", "floor", "tree", "ceiling",
    "road", "bed", "windowpane", "grass", "cabinet", "sidewalk",
    "person", "earth", "door", "table", "mountain", "plant",
    "curtain", "chair", "car", "water", "painting", "sofa",
    "shelf", "house", "sea", "mirror", "rug", "field",
    "armchair", "seat", "fence", "desk", "rock", "wardrobe",
    "lamp", "bathtub", "railing", "cushion", "base", "box",
    "column", "signboard", "chest of drawers", "counter", "sand", "sink",
    "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs",
    "runway", "case", "pool table", "pillow", "screen door", "stairway",
    "river", "bridge", "bookcase", "blind", "coffee table", "toilet",
    "flower", "book", "hill", "bench", "countertop", "stove",
    "palm", "kitchen island", "computer", "swivel chair", "boat", "bar",
    "arcade machine", "hovel", "bus", "towel", "light", "truck",
    "tower", "chandelier", "awning", "streetlight", "booth", "television receiver",
    "airplane", "dirt track", "apparel", "pole", "land", "bannister",
    "escalator", "ottoman", "bottle", "buffet", "poster", "stage",
    "van", "ship", "fountain", "conveyer belt", "canopy", "washer",
    "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall",
    "tent", "bag", "minibike", "cradle", "oven", "ball",
    "food", "step", "tank", "trade name", "microwave", "pot",
    "animal", "bicycle", "lake", "dishwasher", "screen", "blanket",
    "sculpture", "hood", "sconce", "vase", "traffic light", "tray",
    "ashcan", "fan", "pier", "crt screen", "plate", "monitor",
    "bulletin board", "shower", "radiator", "glass", "clock", "flag",
]


def eval_model(args):
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if "mpt" in model_name.lower():
        model = LlavaMPTForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            use_cache=True,
        ).cuda()
    else:
        # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        )  # .cuda()
    image_processor = CLIPImageProcessor.from_pretrained(
        model.config.mm_vision_tower, torch_dtype=torch.float16
    )

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    # Load the CLIP vision tower explicitly if it was left on the meta device.
    vision_tower = model.get_model().vision_tower[0]
    if vision_tower.device.type == "meta":
        vision_tower = CLIPVisionModel.from_pretrained(
            vision_tower.config._name_or_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).cuda()
        model.get_model().vision_tower[0] = vision_tower
    else:
        vision_tower.to(device="cuda", dtype=torch.float16)
    vision_config = vision_tower.config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN]
    )[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    if mm_use_im_start_end:
        (
            vision_config.im_start_token,
            vision_config.im_end_token,
        ) = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
        )
    image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

    # paths for all images
    images = sorted(
        glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")
    )

    # --range is optional ("start,end"); without it, process every image.
    if args.range is not None:
        start, end = args.range.split(",")
        start, end = int(start), int(end)
        images = images[start:end]

    results = []
    for i, image_file in enumerate(tqdm.tqdm(images)):
        # if i == 2:
        #     break
        # if i % 100 == 0:
        #     print("i: {}, len(images): {}".format(i, len(images)))
        print("i: {}, len(images): {}".format(i, len(images)))

        image = load_image(image_file)
        image_tensor = image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        image_tensor = image_tensor.unsqueeze(0).half().cuda()

        prompt_list = []
        output_list = []

        # The segmentation annotation shares the image's directory layout and stem.
        label_file = image_file.replace("images", "annotations").replace(
            ".jpg", ".png"
        )
        label = np.array(Image.open(label_file))
        label_unique = np.unique(label)
        for label_id in label_unique:
            if label_id == 0:
                # 0 marks unlabeled pixels.
                continue
            class_id = label_id - 1
            class_label = classes[class_id]
            input_conv = "Can you describe the {} in this image?".format(class_label)
            qs = input_conv
            prompt_list.append(input_conv)

            # qs = args.query
            if mm_use_im_start_end:
                qs = (
                    qs
                    + "\n"
                    + DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                    + DEFAULT_IM_END_TOKEN
                )
            else:
                qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len

            if "v1" in model_name.lower():
                conv_mode = "llava_v1"
            elif "mpt" in model_name.lower():
                conv_mode = "mpt_multimodal"
            else:
                conv_mode = "multimodal"

            if args.conv_mode is not None and conv_mode != args.conv_mode:
                print(
                    "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                        conv_mode, args.conv_mode, args.conv_mode
                    )
                )
            else:
                args.conv_mode = conv_mode

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            inputs = tokenizer([prompt])

            # image = load_image(args.image_file)
            # image = load_image(image_file)
            # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

            input_ids = torch.as_tensor(inputs.input_ids).cuda()

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=0.2,
                    max_new_tokens=512,  # 1024,
                    stopping_criteria=[stopping_criteria],
                )

            input_token_len = input_ids.shape[1]
            n_diff_input_output = (
                (input_ids != output_ids[:, :input_token_len]).sum().item()
            )
            if n_diff_input_output > 0:
                print(
                    f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
                )
            outputs = tokenizer.batch_decode(
                output_ids[:, input_token_len:], skip_special_tokens=True
            )[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            # print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
            # image_id = image_file.split("/")[-1]
            # results.append({'image_id': image_id, 'input': input_conv, 'output': outputs})
            output_list.append(outputs)

        # Save one JSON file per image with all per-class prompts and answers.
        image_id = image_file.split("/")[-1].split(".")[0]
        with open(
            "/mnt/proj74/xinlai/LLM/LLaVA/generated/{}.json".format(image_id), "w+"
        ) as f:
            json.dump(
                {
                    "image_id": image_id,
                    "input_list": prompt_list,
                    "output_list": output_list,
                },
                f,
            )

    # with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
    #     json.dump(results, f)

    # print(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    # --image-file and --query are not used below (images are globbed from disk and
    # prompts are built per class), but the parser still requires them.
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--range", type=str, default=None)
    args = parser.parse_args()

    eval_model(args)
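
# Example invocation (illustrative only; the script name, checkpoint path, and
# range below are placeholders, not values taken from the repository):
#
#   python generate_ade20k_descriptions.py \
#       --model-name /path/to/llava-7b-v1 \
#       --image-file unused.jpg --query unused \
#       --conv-mode llava_v1 \
#       --range 0,1000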