# Source: x-lai, commit 6144294 ("Release training script"), 9.82 kB.
import argparse
import json
import math
import os
import random
import shortuuid
import torch
from llava import LlavaLlamaForCausalLM
from llava.conversation import conv_templates
from llava.utils import disable_torch_init
from PIL import Image
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionModel,
StoppingCriteria,
)
def split_list(lst, n):
    """Split *lst* into exactly *n* contiguous chunks of (roughly) equal size.

    Each chunk holds ``ceil(len(lst) / n)`` items, except the last non-empty
    one, which may be shorter. When there are too few items to fill every
    chunk (e.g. 5 items into 4 chunks, or an empty sequence), the result is
    padded with empty slices of the same type as *lst*. This guarantees that
    every index ``0 .. n-1`` is valid for callers such as ``get_chunk``.

    Args:
        lst: a sliceable sequence (list, tuple, ...).
        n: number of chunks to produce; must be >= 1.

    Returns:
        A list of exactly *n* slices of *lst*.

    Raises:
        ValueError: if *n* is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # Ceiling division; clamp to 1 so an empty input doesn't produce a zero range step.
    chunk_size = max(1, math.ceil(len(lst) / n))
    chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # Ceil-sized chunks can run out before n (5 items / 4 chunks -> 3 chunks);
    # pad with empty slices so every requested chunk index exists.
    chunks.extend(lst[:0] for _ in range(n - len(chunks)))
    return chunks
def get_chunk(lst, n, k):
    """Return chunk number *k* out of the *n* chunks that ``split_list`` makes from *lst*."""
    return split_list(lst, n)[k]
# Placeholder strings used for image content in tokenizer vocab and prompts.
DEFAULT_IMAGE_TOKEN: str = "<image>"  # not referenced elsewhere in this script
DEFAULT_IMAGE_PATCH_TOKEN: str = "<im_patch>"  # repeated image_token_len times per image
DEFAULT_IM_START_TOKEN: str = "<im_start>"  # opens the patch run when mm_use_im_start_end
DEFAULT_IM_END_TOKEN: str = "<im_end>"  # closes the patch run when mm_use_im_start_end
def patch_config(config):
    """Add the multimodal fields to the model config at *config* if missing.

    The config is loaded with ``AutoConfig.from_pretrained``. If it already
    defines ``mm_vision_tower``, nothing is changed. Otherwise the default
    fields below are set on it, and the config is written back to *config*
    with ``save_pretrained``.
    """
    cfg = AutoConfig.from_pretrained(config)
    if hasattr(cfg, "mm_vision_tower"):
        # Already patched (or natively multimodal) -- leave it untouched.
        return
    print(
        f"`mm_vision_tower` not found in `{config}`, applying patch and save to disk."
    )
    defaults = {
        "use_mm_proj": True,
        "mm_vision_tower": "openai/clip-vit-large-patch14",
        "mm_hidden_size": 1024,
    }
    for field_name, field_value in defaults.items():
        setattr(cfg, field_name, field_value)
    cfg.save_pretrained(config)
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword appears in the newly generated text.

    Defined once at module level rather than being re-created for every
    question inside ``eval_model``'s loop.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        if self.start_len is None:
            # First call: record where the prompt ends; nothing is checked on this call.
            self.start_len = self.input_ids.shape[1]
        else:
            # Decode only the continuation after the prompt (first batch row).
            outputs = self.tokenizer.batch_decode(
                output_ids[:, self.start_len :], skip_special_tokens=True
            )[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


def _register_image_tokens(tokenizer, model_config, vision_config):
    """Add image placeholder tokens to *tokenizer* and record their ids on *vision_config*.

    Returns:
        ``(mm_use_im_start_end, image_token_len)``. The first value is read
        from *model_config* (default False). The second is the number of patch
        placeholder tokens to insert per image.
    """
    mm_use_im_start_end = getattr(model_config, "mm_use_im_start_end", False)
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN]
    )[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    if mm_use_im_start_end:
        (
            vision_config.im_start_token,
            vision_config.im_end_token,
        ) = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
        )
    # One placeholder token per (image_size // patch_size)^2 vision patches.
    image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
    return mm_use_im_start_end, image_token_len


def _strip_response(outputs, conv, conv_mode):
    """Trim role markers (simple modes only), then cut the text at the first ``conv.sep``."""
    if conv_mode == "simple_legacy" or conv_mode == "simple":
        # Keep stripping leading markers/whitespace until a pass changes nothing.
        while True:
            cur_len = len(outputs)
            outputs = outputs.strip()
            for pattern in ["###", "Assistant:", "Response:"]:
                if outputs.startswith(pattern):
                    outputs = outputs[len(pattern) :].strip()
            if len(outputs) == cur_len:
                break
    try:
        index = outputs.index(conv.sep)
    except ValueError:
        # No separator present: append one so index() succeeds; the slice drops it again.
        outputs += conv.sep
        index = outputs.index(conv.sep)
    return outputs[:index].strip()


def eval_model(args):
    """Generate an answer for every question in one chunk of a JSONL question file.

    Loads the LLaVA model in one of two ways. Without ``args.mm_projector``,
    the checkpoint at ``args.model_name`` must bundle the vision tower. With
    it, a separate CLIP vision tower is loaded from ``args.vision_tower``,
    plus the MLP projector weights from ``args.mm_projector``.

    For each question in chunk ``args.chunk_idx`` of ``args.num_chunks``, the
    question text is combined with image placeholder tokens and the image from
    ``args.image_folder``, and an answer is sampled. One JSON object per line
    is written to ``args.answers_file``.

    Args:
        args: parsed CLI namespace providing ``model_name``, ``image_folder``,
            ``question_file``, ``answers_file``, ``mm_projector``,
            ``vision_tower``, ``conv_mode``, ``num_chunks`` and ``chunk_idx``.

    Raises:
        ValueError: if ``args.mm_projector`` is set but ``args.vision_tower`` is not.
    """
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if args.mm_projector is None:
        patch_config(model_name)
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).cuda()
        image_processor = CLIPImageProcessor.from_pretrained(
            model.config.mm_vision_tower, torch_dtype=torch.float16
        )
        vision_tower = model.model.vision_tower[0]
        vision_tower.to(device="cuda", dtype=torch.float16)
        mm_use_im_start_end, image_token_len = _register_image_tokens(
            tokenizer, model.config, vision_tower.config
        )
    else:
        # in case of using a pretrained model with only a MLP projector weights
        if args.vision_tower is None:
            raise ValueError("--vision-tower is required when --mm-projector is given")
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).cuda()
        vision_tower = CLIPVisionModel.from_pretrained(
            args.vision_tower, torch_dtype=torch.float16
        ).cuda()
        image_processor = CLIPImageProcessor.from_pretrained(
            args.vision_tower, torch_dtype=torch.float16
        )
        mm_use_im_start_end, image_token_len = _register_image_tokens(
            tokenizer, model.config, vision_tower.config
        )
        mm_projector = torch.nn.Linear(
            vision_tower.config.hidden_size, model.config.hidden_size
        )
        # NOTE(review): torch.load unpickles the file -- only load projector
        # checkpoints from trusted sources.
        mm_projector_weights = torch.load(args.mm_projector, map_location="cpu")
        # Keep only the last dotted component of each saved key so it matches
        # the bare nn.Linear's own parameter names.
        mm_projector.load_state_dict(
            {k.split(".")[-1]: v for k, v in mm_projector_weights.items()}
        )
        model.model.mm_projector = mm_projector.cuda().half()
        model.model.vision_tower = [vision_tower]

    question_path = os.path.expanduser(args.question_file)
    with open(question_path, "r", encoding="utf-8") as question_fp:
        # Skip blank lines, which json.loads would reject.
        questions = [json.loads(q) for q in question_fp if q.strip()]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    # dirname() is "" for a bare filename (e.g. the default "answer.jsonl"),
    # and os.makedirs("") raises.
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    # NOTE(review): the stop keyword is fixed to "###" regardless of conv_mode.
    keywords = ["###"]
    with open(answers_file, "w") as ans_file:
        for i, line in enumerate(tqdm(questions)):
            idx = line["question_id"]
            image_file = line["image"]
            qs = line["text"]
            cur_prompt = qs
            if mm_use_im_start_end:
                qs = (
                    qs
                    + "\n"
                    + DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                    + DEFAULT_IM_END_TOKEN
                )
            else:
                qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            if args.conv_mode == "simple_legacy":
                qs += "\n\n### Response:"
            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            prompt = conv.get_prompt()
            inputs = tokenizer([prompt])

            # Close the image file handle as soon as the pixels are extracted.
            with Image.open(os.path.join(args.image_folder, image_file)) as image:
                image_tensor = image_processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
            input_ids = torch.as_tensor(inputs.input_ids).cuda()

            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).half().cuda(),
                    do_sample=True,
                    temperature=0.7,
                    max_new_tokens=1024,
                    stopping_criteria=[stopping_criteria],
                )

            # generate() returns prompt + continuation; sanity-check the prompt prefix.
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (
                (input_ids != output_ids[:, :input_token_len]).sum().item()
            )
            if n_diff_input_output > 0:
                print(
                    f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids"
                )
            outputs = tokenizer.batch_decode(
                output_ids[:, input_token_len:], skip_special_tokens=True
            )[0]
            outputs = _strip_response(outputs, conv, args.conv_mode)

            ans_id = shortuuid.uuid()
            ans_file.write(
                json.dumps(
                    {
                        "question_id": idx,
                        "prompt": cur_prompt,
                        "text": outputs,
                        "answer_id": ans_id,
                        "model_id": model_name,
                        "metadata": {},
                    }
                )
                + "\n"
            )
            # Flush per answer so completed results are on disk if the run is interrupted.
            ans_file.flush()
if __name__ == "__main__":
    # (flag, type, default) for every command-line option that eval_model reads.
    cli_options = (
        ("--model-name", str, "facebook/opt-350m"),
        ("--image-folder", str, ""),
        ("--question-file", str, "tables/question.jsonl"),
        ("--answers-file", str, "answer.jsonl"),
        ("--mm-projector", str, None),
        ("--vision-tower", str, None),
        ("--conv-mode", str, "simple"),
        ("--num-chunks", int, 1),
        ("--chunk-idx", int, 0),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in cli_options:
        arg_parser.add_argument(flag, type=flag_type, default=flag_default)
    eval_model(arg_parser.parse_args())