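"""Standalone inference script for the MASP LLaVA-Thoth video captioning model.

Given a folder of extracted video frames, the script uniformly samples up to
NUM_SEGMENTS frames, encodes them through the vision tower / Q-Former / projector
pipeline, and asks the language model to describe the video in detail.
"""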
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model import LlavaThothForCausalLM
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
from llava.conversation import conv_templates
import torch
from llava.mm_utils import tokenizer_image_token, process_images_v2, KeywordsStoppingCriteria
import numpy as np
from PIL import Image
import os


NUM_SEGMENTS = 10  # maximum number of frames sampled uniformly from each video


def load_model(model_path, device_map):
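    """Load the LLaVA-Thoth checkpoint, its tokenizer, and a BLIP-2 style image processor.

    Also registers the image/video boundary tokens used in the prompt template and
    resizes the token embeddings accordingly.
    """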
    kwargs = {"device_map": device_map}
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaThothForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)

    image_processor = Blip2ImageTrainProcessor(
        image_size=model.config.img_size,
        is_training=False)
    model.to(torch.float16)
    return model, tokenizer, image_processor


def generate_input_ids(tokenizer):
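    """Build the 'thoth' conversation prompt with video placeholder tokens and tokenize it."""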
    conv = conv_templates['thoth'].copy()
    qs = "Describe the following video in detail."
    qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
    return input_ids, conv


def generate_images(frame_folder, image_processor, model_cfg):
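    """Load frames from a folder, sample at most NUM_SEGMENTS of them, and preprocess them."""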
    images = load_frames(frame_folder)
    if len(images) > NUM_SEGMENTS:
        images = uniform_sample(images, NUM_SEGMENTS)
    return process_images_v2(images, image_processor, model_cfg)
 

def uniform_sample(frames, num_segments):
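    """Pick num_segments frames at evenly spaced indices across the full frame list."""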
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    frames = [frames[ind] for ind in indices]
    return frames

def load_frames(frames_dir):
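    """Load all .jpg frames from frames_dir, sorted by their integer file names (e.g. 0.jpg, 1.jpg, ...)."""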
    results = []
    image_files = [(int(os.path.splitext(img)[0]), img) for img in os.listdir(frames_dir) if img.endswith('jpg')]
    image_files = sorted(image_files, key=lambda img: img[0])

    for frame_name in image_files:
        image_path = f"{frames_dir}/{frame_name[1]}"
        image = Image.open(image_path).convert('RGB')
        results.append(image)
    return results


class MASPVisionWrapper(torch.nn.Module):
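    """Bundles the vision tower, Q-Former, and projector into a single module that maps
    raw video frames to LLM-ready embeddings."""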

    def __init__(self, vision_tower, qformer, projector, query_tokens, frame_position_encoding, ln_vision):
        super().__init__()
        self.vision_tower = vision_tower
        self.qformer = qformer
        self.projector = projector
        self.query_tokens = query_tokens
        self.ln_vision = ln_vision
        self.frame_position_encoding = frame_position_encoding

    def forward(self, images):
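        """Encode frames: ViT features -> LayerNorm -> Q-Former with learned query tokens,
        then add per-frame position encodings and project into the LLM embedding space."""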
        # images: [num_frames, patches, 3, image_size, image_size]
        image_features = self.vision_tower(images.flatten(0, 1))
        image_features = self.ln_vision(image_features)
        attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device) # [num_frames * num_patches, 256]
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) # [num_frames * num_patches, 32, 768]
        dtype_ = self.vision_tower.dtype
        image_features = self.qformer.bert(
            query_embeds=query_tokens.to(dtype_),
            encoder_hidden_states=image_features.to(dtype_),
            encoder_attention_mask=attn_mask,
            return_dict=True
        ).last_hidden_state.to(dtype_)
        frame_ids = torch.arange(images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
        frame_ids = frame_ids.repeat(1, images.shape[1]).flatten(0, 1) # [num_frames * num_patches]
        image_features += self.frame_position_encoding(frame_ids).unsqueeze(-2)  # [num_frames * num_patches, 1, 768]
        return self.projector(image_features) 


def inference(model_path, frame_folder):
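    """Run a single forward pass: encode the sampled frames, splice the visual embeddings
    into the tokenized prompt at the video placeholder position, and generate a description."""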
    # prepare
    model, tokenizer, image_processor = load_model(model_path, device_map={"":0})
    input_ids, conv = generate_input_ids(tokenizer)
    images = generate_images(frame_folder, image_processor, model.config).to(model.device).half() # [num_frames, patches, 3, image_size, image_size]
    vision_module = MASPVisionWrapper(
        vision_tower=model.get_vision_tower(),
        qformer=model.get_qformer(),
        projector=model.get_model().mm_projector,
        query_tokens=model.get_query_tokens(),
        frame_position_encoding=model.get_frame_position_encoding(),
        ln_vision=model.get_ln_vision(),
    )
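    # Stop generation once the conversation separator string is produced.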
    stop_str = conv.sep if conv.sep2 is None else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    input_ids = input_ids[0].to(model.device) # [token_len]

    # inference
    with torch.inference_mode():
        # get image feature
        image_features = vision_module(images).flatten(0, 1) # [num_frames * num_patches * num_query_token, 4096]
        # concat with text features 
        vision_token_index = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]  # position of the video placeholder token
        pre_text_token = model.get_model().embed_tokens(input_ids[:vision_token_index])
        post_text_token = model.get_model().embed_tokens(input_ids[vision_token_index + 1:])
        inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(0) # [1, num_token, 4096]
        
        # llm.generate
        output_ids = model.generate_from_base_class(
            inputs_embeds=inputs_embeds,
            do_sample=True,
            temperature=0.01,
            top_p=None,
            num_beams=1,
            max_new_tokens=1024,
            pad_token_id=tokenizer.eos_token_id, 
            use_cache=True,
            stopping_criteria=[stopping_criteria]       
        )
    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    output = output.strip()        
    print(output)


if __name__ == '__main__':
    # Example paths from the authors' environment; replace with your own checkpoint and frame folder.
    model_path = '/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/llava-thothv2_mar_release_all_data'
    frame_folder = '/mnt/bn/yukunfeng-nasdrive/xiangchen/masp_data/20231110_ttp/video/v12044gd0000cl5c6rfog65i2eoqcqig'
    inference(model_path, frame_folder)