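"""Export submodules of a LLaVA-Mistral video model for AWS Inferentia2.

The checkpoint is loaded on CPU in fp32, the EVA-ViT vision tower and the
Q-Former BERT are compiled with torch_neuronx.trace, and the remaining small
components (vision layer norm, query tokens, frame position encoding, mm
projector, embed_tokens) are saved as plain state dicts under ./inf2_weights.
"""
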
import sys
import os
import random
import time
import logging
from io import BytesIO

import numpy as np
import torch
import torch_neuronx
from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

from llava.constants import (
    MM_TOKEN_INDEX,
    DEFAULT_VIDEO_START_TOKEN,
    DEFAULT_VIDEO_END_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2
from llava.model.builder import load_pretrained_model
from llava.model import LlavaMistralForCausalLM
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model.multimodal_encoder.eva_vit import create_eva_vit_g
from qformer_tian import BertConfig, BertModel


def select_frames(input_frames, num_segments=10):
    # Evenly sample `num_segments` frame indices across the input frame list.
    indices = np.linspace(start=0, stop=len(input_frames) - 1, num=num_segments).astype(int)
    frames = [input_frames[ind] for ind in indices]
    return frames

    
def generate_input_ids(tokenizer):
    conv = conv_templates['v1'].copy()
    qs = "Describe the following video in detail."
    qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
    return input_ids, conv

def uniform_sample(frames, num_segments):
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    frames = [frames[ind] for ind in indices]
    return frames

save_root = './inf2_weights'
os.makedirs(save_root, exist_ok=True)

EVITG_SAVE_PATH = os.path.join(save_root, 'neuron_eva_vit_batch7.pth')
LAYERNORM_SAVE_PATH = os.path.join(save_root, 'ln_state_dict.pth')
QUERYTOKEN_SAVE_PATH = os.path.join(save_root, 'query_tokens.pth')
BERT_SAVE_PATH = os.path.join(save_root, 'neuron_bert.pth')
POSITION_ENCODING_SAVE_PATH = os.path.join(save_root, 'frame_position_encoding.pth')
PROJECTOR_SAVE_PATH = os.path.join(save_root, 'projector.pth')
EMBED_TOKENS_SAVE_PATH = os.path.join(save_root, 'embed_tokens.pth')


model_path = './llava-mistral_videollava_ptv12_250k_samep_only_sopv2_mistralv2_scratch/'
disable_torch_init()

# Keep everything on CPU in fp32: torch_neuronx.trace takes a CPU model and CPU
# example inputs, which it compiles ahead of time for the NeuronCores.
device_map = {"": 'cpu'}
kwargs = {"device_map": device_map, "torch_dtype": torch.float32}
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    **kwargs
)
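# Register the special video/image delimiter tokens used in the prompt and
# resize the embedding table so the new token ids have embeddings.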
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

model.config.vit_precision = 'fp32'
vision_tower = model.get_vision_tower()
vision_tower.is_loaded = False
vision_tower.load_model(device_map=device_map)
vision_tower = vision_tower.to(torch.float32)

vision_tower = vision_tower.eval()
print('vision tower hidden size:', vision_tower.hidden_size)

# Compile the EVA-ViT vision tower for a fixed input shape; torch_neuronx.trace
# compiles for static shapes, so inference must feed batches of exactly this size.
batch_size = 7
img_size = 224
input_shape = (batch_size, 3, img_size, img_size)
input_data = torch.zeros(input_shape, dtype=torch.float32)
model_neuronx = torch_neuronx.trace(vision_tower, input_data, compiler_args=["--model-type=transformer"])
model_neuronx.save(EVITG_SAVE_PATH)

image_processor = Blip2ImageTrainProcessor(
    image_size=model.config.img_size,
    is_training=False)

input_ids, conv = generate_input_ids(tokenizer)
device = torch.device('cpu')
model = model.to(device)
conv_mode = 'v1'
NUM_SEGMENTS = 10

video_dir = './v12044gd0000cl5c6rfog65i2eoqcqig'
frames = [(int(os.path.splitext(item)[0]), os.path.join(video_dir, item)) for item in os.listdir(video_dir)]
frames = [item[1] for item in sorted(frames, key=lambda x: x[0])]
images = [Image.open(frame).convert('RGB') for frame in frames]
images = uniform_sample(images, NUM_SEGMENTS)
images = process_images_v2(images, image_processor, model.config)

#save layer norm
ln_vision = model.get_ln_vision()
ln_vision = ln_vision.eval()
ln_state_dict = ln_vision.state_dict()
torch.save(ln_state_dict, LAYERNORM_SAVE_PATH)


query_tokens = model.get_query_tokens()
#save query tokens
query_tokens_state_dict = {'query_tokens': query_tokens.data}
torch.save(query_tokens_state_dict, QUERYTOKEN_SAVE_PATH)

#save qformer
qformer = model.get_qformer()
bert_torch = qformer.bert
bert_torch = bert_torch.eval()
bert_torch = bert_torch.to(torch.float32)


vision_width = 1408
cross_attention_freq = 2
num_query_token = 32
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
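# Drop the text-embedding tables and the text feed-forward blocks below: only
# the query / cross-attention path of the Q-Former is exercised by this export.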
bert = BertModel(encoder_config, add_pooling_layer=False)
bert.embeddings.word_embeddings = None
bert.embeddings.position_embeddings = None

for layer in bert.encoder.layer:
    layer.output = None
    layer.intermediate = None


bert.load_state_dict(bert_torch.state_dict())
bert = bert.eval()

# Example inputs for tracing the Q-Former BERT. The shapes follow the original
# script; the leading 70 is presumably batch_size (7) * NUM_SEGMENTS (10) frames.
input_example = (
    torch.zeros(70, 32, 768, dtype=torch.float32),
    torch.zeros(70, 256, 1408, dtype=torch.float32),
    torch.zeros(70, 256, dtype=torch.int64),
)
neuron_bert = torch_neuronx.trace(bert, input_example)
neuron_bert.save(BERT_SAVE_PATH)

#save projector and frame position encoding
frame_position_encoding = model.get_frame_position_encoding()
projector = model.get_model().mm_projector

frame_position_encoding = frame_position_encoding.eval()
frame_position_encoding = frame_position_encoding.to(torch.float32)

projector = projector.eval()
projector = projector.to(torch.float32)

torch.save(frame_position_encoding.state_dict(), POSITION_ENCODING_SAVE_PATH)
torch.save(projector.state_dict(), PROJECTOR_SAVE_PATH)

#save embed_tokens
embed_tokens = model.get_model().embed_tokens
embed_tokens = embed_tokens.eval()
embed_tokens = embed_tokens.to(torch.float32)
torch.save(embed_tokens.state_dict(), EMBED_TOKENS_SAVE_PATH)
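
# Illustrative sketch only (kept commented out so this export script is unchanged):
# on an inf2 instance the compiled graphs reload as TorchScript modules and the
# state dicts reattach to the same submodules they were taken from. `pixel_batch`
# below is a placeholder tensor and must match the traced (7, 3, 224, 224) shape.
#
#   import torch
#   import torch_neuronx  # noqa: F401 -- needed so the Neuron runtime backs the loaded graphs
#
#   neuron_vit = torch.jit.load(EVITG_SAVE_PATH)
#   neuron_bert = torch.jit.load(BERT_SAVE_PATH)
#
#   ln_vision.load_state_dict(torch.load(LAYERNORM_SAVE_PATH))
#   query_tokens = torch.load(QUERYTOKEN_SAVE_PATH)['query_tokens']
#
#   feats = neuron_vit(pixel_batch)  # pixel_batch: fp32 CPU tensor of shape (7, 3, 224, 224)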