# Page header captured from the Hugging Face file view:
# model1 / inference_deployment /predict_v09.py
# multitensor's picture
# Upload folder using huggingface_hub
# bbfa6f6 verified
# raw
# history blame
# 6.66 kB
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model import LlavaThothForCausalLM
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates
import torch
from llava.mm_utils import tokenizer_image_token, process_images_v2, KeywordsStoppingCriteria
import numpy as np
from PIL import Image
import os
# Maximum number of frames fed to the model per video: generate_images()
# uniformly subsamples any frame folder that holds more frames than this.
NUM_SEGMENTS = 10
def load_model(model_path, device_map):
    """Load the LLaVA-Thoth checkpoint together with its tokenizer and frame processor.

    Args:
        model_path: Checkpoint location handed to both ``from_pretrained`` calls.
        device_map: Forwarded to the model's ``from_pretrained`` and, when the
            vision tower is not yet loaded, to ``vision_tower.load_model``.

    Returns:
        ``(model, tokenizer, image_processor)``; the model has been cast to float16.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaThothForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        device_map=device_map,
    )

    # Register the image/video boundary markers, then resize the embedding
    # table to the tokenizer's new vocabulary size.
    boundary_tokens = [
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_VIDEO_START_TOKEN,
        DEFAULT_VIDEO_END_TOKEN,
    ]
    tokenizer.add_tokens(boundary_tokens, special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    tower = model.get_vision_tower()
    if not tower.is_loaded:
        tower.load_model(device_map=device_map)

    processor = Blip2ImageTrainProcessor(image_size=model.config.img_size, is_training=False)
    model.to(torch.float16)
    return model, tokenizer, processor
def generate_input_ids(tokenizer):
    """Tokenize the fixed "describe this video" prompt using the 'thoth' template.

    Returns:
        ``(input_ids, conv)``: the prompt token ids with a leading batch
        dimension of size 1, and the populated conversation object (the caller
        reads its separators to build a stop string).
    """
    question = "Describe the following video in detail."
    video_prefix = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN

    conv = conv_templates['thoth'].copy()
    user_role = conv.roles[0]
    assistant_role = conv.roles[1]
    conv.append_message(user_role, video_prefix + '\n' + question)
    # None content: presumably leaves the assistant turn open for generation.
    conv.append_message(assistant_role, None)

    token_ids = tokenizer_image_token(conv.get_prompt(), tokenizer, MM_TOKEN_INDEX, return_tensors='pt')
    return token_ids.unsqueeze(0), conv
def generate_images(frame_folder, image_processor, model_cfg, num_segments=None):
    """Load the frames of one video and preprocess them for the model.

    Args:
        frame_folder: Directory of numerically named JPEG frames (read by ``load_frames``).
        image_processor: Processor forwarded to ``process_images_v2``.
        model_cfg: Model config forwarded to ``process_images_v2``.
        num_segments: Maximum number of frames to keep. Defaults to the
            module-level ``NUM_SEGMENTS``. Folders with more frames are
            uniformly subsampled down to this count.

    Returns:
        The result of ``process_images_v2`` for the selected frames.

    Raises:
        ValueError: if ``num_segments`` is below 1, or the folder holds no frames.
    """
    if num_segments is None:
        num_segments = NUM_SEGMENTS
    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")

    images = load_frames(frame_folder)
    if not images:
        # Fail here with a clear message instead of inside the image processor.
        raise ValueError(f"no frames found in {frame_folder!r}")
    if len(images) > num_segments:
        images = uniform_sample(images, num_segments)
    return process_images_v2(images, image_processor, model_cfg)
def uniform_sample(frames, num_segments):
    """Return ``num_segments`` evenly spaced items taken from ``frames``.

    Sample positions are spread linearly from the first to the last index
    (both included) and truncated to integers. If ``num_segments`` exceeds
    ``len(frames)``, some items are repeated.
    """
    last_index = len(frames) - 1
    positions = np.linspace(0, last_index, num=num_segments)
    return [frames[int(pos)] for pos in positions]
def load_frames(frames_dir):
    """Load every ``<int>.jpg`` frame in *frames_dir* as an RGB image, in numeric order.

    Frames are ordered by the integer value of their file stem, so ``2.jpg``
    comes before ``10.jpg``. Files that do not end in ``.jpg`` are ignored.

    Args:
        frames_dir: Directory containing the extracted video frames.

    Returns:
        A list of RGB ``PIL.Image.Image`` objects; empty if no frames are present.

    Raises:
        ValueError: if a ``.jpg`` file's stem is not an integer.
    """
    numbered = []
    for name in os.listdir(frames_dir):
        # Require the dot so names like "7jpg" are skipped rather than crashing int().
        if not name.endswith('.jpg'):
            continue
        numbered.append((int(os.path.splitext(name)[0]), name))
    numbered.sort(key=lambda item: item[0])

    frames = []
    for _, name in numbered:
        # Close the file right away. convert() returns a fully loaded copy
        # that no longer depends on the open file handle.
        with Image.open(os.path.join(frames_dir, name)) as img:
            frames.append(img.convert('RGB'))
    return frames
class MASPVisionWrapper(torch.nn.Module):
    """Visual front-end: vision tower -> ln_vision -> Q-Former -> frame position encoding -> projector.

    Bundles the visual sub-modules of the multimodal model so a batch of frame
    patches can be turned into projected query-token features in one call.
    """

    def __init__(self, vision_tower, qformer, projector, query_tokens, frame_position_encoding, ln_vision):
        """Store the visual sub-modules.

        Args:
            vision_tower: Patch encoder. Must expose a ``dtype`` attribute.
            qformer: Object whose ``bert`` attribute is called with query embeddings
                and encoder states and returns an object with ``last_hidden_state``.
            projector: Module mapping Q-Former outputs into the LLM embedding space.
            query_tokens: Learned query tensor, expanded along dim 0 per patch
                (presumably shaped ``[1, num_query_tokens, hidden]`` - TODO confirm).
            frame_position_encoding: Callable mapping integer frame ids to embeddings.
            ln_vision: Normalization applied to the vision tower output.
        """
        super().__init__()
        self.vision_tower = vision_tower
        self.qformer = qformer
        self.projector = projector
        self.query_tokens = query_tokens
        self.ln_vision = ln_vision
        self.frame_position_encoding = frame_position_encoding

    def forward(self, images):
        """Encode frame patches into projected query-token features.

        Args:
            images: ``[num_frames, num_patches, 3, image_size, image_size]`` tensor.

        Returns:
            The projector's output for the ``[num_frames * num_patches, num_query_tokens, hidden]``
            Q-Former features, after each patch receives its frame's position embedding.
        """
        num_frames, num_patches = images.shape[0], images.shape[1]

        # Encode every patch of every frame as one flat batch.
        image_features = self.ln_vision(self.vision_tower(images.flatten(0, 1)))
        # Every encoder position is attended to. Allocate the mask on the
        # target device directly instead of building it on CPU and copying.
        attn_mask = torch.ones(image_features.shape[:-1], dtype=torch.long, device=image_features.device)
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)

        dtype_ = self.vision_tower.dtype
        query_output = self.qformer.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True,
        ).last_hidden_state.to(dtype_)

        # Frame id for each flattened patch: [0]*P + [1]*P + ... + [F-1]*P.
        frame_ids = torch.arange(num_frames, dtype=torch.long, device=query_output.device).repeat_interleave(num_patches)
        frame_pe = self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames * num_patches, 1, hidden]

        # Out-of-place add so the Q-Former output tensor is not mutated (safe
        # under autograd). Casting back keeps the dtype an in-place add would have had.
        features = (query_output + frame_pe).to(query_output.dtype)
        return self.projector(features)
def inference(model_path, frame_folder):
    """Generate a description of the video whose frames are in *frame_folder*.

    Loads the model onto device 0, encodes the (possibly subsampled) frames
    with the visual stack, and splices the visual features into the prompt
    embeddings in place of the first multimodal placeholder token. It then
    generates, prints and returns the text.

    Args:
        model_path: Checkpoint location passed to ``load_model``.
        frame_folder: Directory of extracted frames passed to ``generate_images``.

    Returns:
        The decoded, whitespace-stripped model output (it is also printed).
    """
    # prepare
    model, tokenizer, image_processor = load_model(model_path, device_map={"": 0})
    input_ids, conv = generate_input_ids(tokenizer)
    # [num_frames, patches, 3, image_size, image_size]
    images = generate_images(frame_folder, image_processor, model.config).to(model.device).half()
    vision_module = MASPVisionWrapper(
        vision_tower=model.get_vision_tower(),
        qformer=model.get_qformer(),
        projector=model.get_model().mm_projector,
        query_tokens=model.get_query_tokens(),
        frame_position_encoding=model.get_frame_position_encoding(),
        ln_vision=model.get_ln_vision(),
    )
    # A freshly constructed nn.Module starts in training mode. eval() also
    # recurses into the shared visual sub-modules, so any dropout there is
    # disabled for generation.
    vision_module.eval()

    # The stop string is the template's second separator when defined, else the first.
    stop_str = conv.sep if conv.sep2 is None else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    input_ids = input_ids[0].to(model.device)  # [token_len]
    embed_tokens = model.get_model().embed_tokens

    # inference
    with torch.inference_mode():
        # [num_frames * num_patches * num_query_token, hidden]
        image_features = vision_module(images).flatten(0, 1)
        # Replace the first multimodal placeholder token with the visual features.
        vision_token_index = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
        pre_text_embeds = embed_tokens(input_ids[:vision_token_index])
        post_text_embeds = embed_tokens(input_ids[vision_token_index + 1:])
        # [1, num_token, hidden]
        inputs_embeds = torch.cat([pre_text_embeds, image_features, post_text_embeds]).unsqueeze(0)
        output_ids = model.generate_from_base_class(
            inputs_embeds=inputs_embeds,
            do_sample=True,
            temperature=0.01,
            top_p=None,
            num_beams=1,
            max_new_tokens=1024,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(output)
    return output
if __name__ == '__main__':
    # Script entry point. Both paths may be overridden on the command line;
    # with no arguments the original hard-coded locations are used.
    import argparse

    parser = argparse.ArgumentParser(description="Describe a video given a folder of its extracted frames.")
    parser.add_argument(
        'model_path',
        nargs='?',
        default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/llava-thothv2_mar_release_all_data',
        help="checkpoint directory",
    )
    parser.add_argument(
        'frame_folder',
        nargs='?',
        default='/mnt/bn/yukunfeng-nasdrive/xiangchen/masp_data/20231110_ttp/video/v12044gd0000cl5c6rfog65i2eoqcqig',
        help="directory of numerically named .jpg frames",
    )
    args = parser.parse_args()
    inference(args.model_path, args.frame_folder)