model1 / inference_deployment /convert2inf2.py
multitensor's picture
Upload folder using huggingface_hub
bbfa6f6 verified
raw
history blame
6.6 kB
import sys
import torch
import os
import random
from io import BytesIO
import numpy as np
import time
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2
from llava.model.builder import load_pretrained_model
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model import LlavaMistralForCausalLM
from llava.model.multimodal_encoder.eva_vit import create_eva_vit_g
import torch_neuronx
import torch
import torch_neuronx
from llava.model import LlavaMistralForCausalLM
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import CLIPImageProcessor
from PIL import Image
import logging
from qformer_tian import BertConfig, BertModel
def select_frames(input_frames, num_segments=10):
    """Pick ``num_segments`` evenly spaced items from ``input_frames``.

    Positions come from ``np.linspace`` over ``[0, len(input_frames) - 1]``
    (both ends included), truncated to ints with ``astype(int)``. When
    ``num_segments >= 2`` the first and last items are always chosen. Items
    repeat when ``num_segments`` exceeds the input length. An empty input
    raises ``IndexError`` (for ``num_segments >= 1``).

    Args:
        input_frames: indexable sequence of frames.
        num_segments: number of items to return.

    Returns:
        list of the selected items, in index order.
    """
    last_index = len(input_frames) - 1
    positions = np.linspace(start=0, stop=last_index, num=num_segments).astype(int)
    return list(map(input_frames.__getitem__, positions))
def generate_input_ids(tokenizer, conv_mode='v1',
                       question="Describe the following video in detail."):
    """Build the tokenized single-turn prompt for a video question.

    The question is prefixed with
    ``DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN``
    and a newline. The result is appended as the first-role turn of a copy of
    ``conv_templates[conv_mode]``. A second-role turn with message ``None``
    follows, presumably leaving the prompt open for the model's reply.

    Args:
        tokenizer: tokenizer forwarded to ``tokenizer_image_token``.
        conv_mode: key into ``conv_templates``. Defaults to ``'v1'``, the
            template this function always used before it was parameterized.
        question: user question text placed after the video tokens.

    Returns:
        ``(input_ids, conv)``. ``input_ids`` is the tensor returned by
        ``tokenizer_image_token(..., MM_TOKEN_INDEX, return_tensors='pt')``
        with a leading batch dimension added by ``unsqueeze(0)``. ``conv`` is
        the populated conversation object.
    """
    # Copy so the shared module-level template is never mutated.
    conv = conv_templates[conv_mode].copy()
    qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + question
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
    return input_ids, conv
def uniform_sample(frames, num_segments):
    """Return ``num_segments`` items evenly spaced across ``frames``.

    Indices come from ``np.linspace(0, len(frames) - 1, num_segments)``
    truncated to ints. Both ends are included when ``num_segments >= 2``, and
    items repeat when ``num_segments`` exceeds ``len(frames)``.

    Args:
        frames: indexable sequence (list, tuple, numpy array, ...).
        num_segments: number of items to return.

    Returns:
        list of the sampled items. Empty ``frames`` yields ``[]``; previously
        ``linspace(0, -1, ...)`` indexed into the empty sequence and raised
        ``IndexError``.
    """
    # len() rather than truthiness: `not frames` is ambiguous for numpy arrays.
    if len(frames) == 0:
        return []
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[ind] for ind in indices]
# --- Output locations --------------------------------------------------------
# Every artifact produced by this script is written under save_root.
save_root = './inf2_weights'
# exist_ok=True replaces the racy isdir()-then-makedirs() check. It still
# raises FileExistsError if save_root exists as a non-directory.
os.makedirs(save_root, exist_ok=True)
# Neuron-traced vision tower. "batch7" matches the batch_size used for tracing.
EVITG_SAVE_PATH = os.path.join(save_root, 'neuron_eva_vit_batch7.pth')
# state_dict of the vision LayerNorm (model.get_ln_vision()).
LAYERNORM_SAVE_PATH = os.path.join(save_root, 'ln_state_dict.pth')
# {'query_tokens': tensor} taken from model.get_query_tokens().
QUERYTOKEN_SAVE_PATH = os.path.join(save_root, 'query_tokens.pth')
# Neuron-traced Q-Former BERT.
BERT_SAVE_PATH = os.path.join(save_root, 'neuron_bert.pth')
# state_dict of model.get_frame_position_encoding().
POSITION_ENCODING_SAVE_PATH = os.path.join(save_root, 'frame_position_encoding.pth')
# state_dict of model.get_model().mm_projector.
PROJECTOR_SAVE_PATH = os.path.join(save_root, 'projector.pth')
# state_dict of model.get_model().embed_tokens.
EMBED_TOKENS_SAVE_PATH = os.path.join(save_root, 'embed_tokens.pth')
# Local checkpoint directory read by AutoTokenizer / LlavaMistralForCausalLM.
model_path = './llava-mistral_videollava_ptv12_250k_samep_only_sopv2_mistralv2_scratch/'
# --- Load tokenizer and model on CPU in float32 ------------------------------
disable_torch_init()
# Place every module on the CPU; device_map is reused when loading the vision tower.
device_map = {"": 'cpu'}
kwargs = dict(device_map=device_map, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

# Register the image/video delimiter tokens as special tokens. Then resize the
# embedding matrix to len(tokenizer) so the newly added ids have rows.
extra_special_tokens = [
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_VIDEO_START_TOKEN,
    DEFAULT_VIDEO_END_TOKEN,
]
tokenizer.add_tokens(extra_special_tokens, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
# --- Vision tower -> Neuron --------------------------------------------------
# Request an fp32 vision tower before (re)loading it. This line used to be
# `== 'fp32'`, a comparison whose result was discarded, so it had no effect.
# NOTE(review): assumes load_model() reads model.config.vit_precision. If the
# tower were built in fp16, the later .to(torch.float32) could not recover
# the precision already lost. Confirm against the vision tower implementation.
model.config.vit_precision = 'fp32'
vision_tower = model.get_vision_tower()
# Presumably forces load_model() to rebuild the tower instead of returning early.
vision_tower.is_loaded = False
vision_tower.load_model(device_map=device_map)
vision_tower = vision_tower.to(torch.float32)
vision_tower = vision_tower.eval()
print('vision tower hidden size')
print(vision_tower.hidden_size)
# Trace with a fixed example input. Neuron-traced graphs are shape-specialized,
# so inference must feed exactly batch_size 3-channel img_size x img_size images.
batch_size = 7
img_size = 224
input_shape = (batch_size, 3, img_size, img_size)
input_data = torch.zeros(input_shape, dtype=torch.float32)
model_neuronx = torch_neuronx.trace(vision_tower, input_data, compiler_args=["--model-type=transformer"])
model_neuronx.save(EVITG_SAVE_PATH)
# --- Preprocessing objects and prompt ----------------------------------------
device = torch.device('cpu')
model = model.to(device)
# Evaluation-mode (is_training=False) processor sized by the model's img_size.
image_processor = Blip2ImageTrainProcessor(image_size=model.config.img_size, is_training=False)
# NOTE(review): input_ids, conv and conv_mode are not referenced again in the
# visible code. They look like leftovers from an end-to-end check.
input_ids, conv = generate_input_ids(tokenizer)
conv_mode = 'v1'
# --- Sample frames from a sample video directory -----------------------------
NUM_SEGMENTS = 10
video_dir = './v12044gd0000cl5c6rfog65i2eoqcqig'


def _load_rgb(path):
    """Open an image file, return an RGB copy and release the file handle."""
    with Image.open(path) as img:
        return img.convert('RGB')


# Frame files are expected to be named by their integer frame index
# (e.g. "12.jpg"). Other entries (hidden files, metadata) are skipped instead
# of crashing int(). isdecimal() matches exactly the strings int() accepts as digits.
indexed_frames = []
for item in os.listdir(video_dir):
    stem = os.path.splitext(item)[0]
    if stem.isdecimal():
        indexed_frames.append((int(stem), os.path.join(video_dir, item)))
frames = [path for _, path in sorted(indexed_frames, key=lambda pair: pair[0])]
if not frames:
    raise FileNotFoundError(f"no numbered frame images found in {video_dir}")
# Sample the paths first. The selection depends only on the list length, so
# this picks the same frames while decoding NUM_SEGMENTS images instead of all.
sampled_frames = uniform_sample(frames, NUM_SEGMENTS)
images = [_load_rgb(frame) for frame in sampled_frames]
images = process_images_v2(images, image_processor, model.config)
# --- CPU-side weights: vision LayerNorm and query tokens ----------------------
# These pieces are not traced for Neuron; only their tensors are saved.
ln_vision = model.get_ln_vision().eval()
ln_state_dict = ln_vision.state_dict()
torch.save(ln_state_dict, LAYERNORM_SAVE_PATH)

query_tokens = model.get_query_tokens()
# Saved under the key 'query_tokens', detached from autograd via .data.
query_tokens_state_dict = dict(query_tokens=query_tokens.data)
torch.save(query_tokens_state_dict, QUERYTOKEN_SAVE_PATH)
# --- Q-Former BERT -> Neuron -----------------------------------------------
# Take the trained BERT from the model's Q-Former and rebuild an equivalent
# standalone BertModel (from qformer_tian) to trace.
qformer = model.get_qformer()
bert_torch = qformer.bert
bert_torch = bert_torch.eval()
bert_torch = bert_torch.to(torch.float32)
vision_width = 1408  # feature width of the visual embeddings fed as cross-attention input
cross_attention_freq = 2
num_query_token = 32
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# Insert a cross-attention layer every cross_attention_freq-th block (every other block).
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
bert = BertModel(encoder_config, add_pooling_layer=False)
# Remove the same submodules that are absent from the trained Q-Former, so
# the strict load_state_dict below sees matching keys. This presumably mirrors
# how the Q-Former was constructed for training (text embeddings and the
# text-branch FFN removed) — confirm against the model builder.
bert.embeddings.word_embeddings = None
bert.embeddings.position_embeddings = None
for layer in bert.encoder.layer:
    layer.output = None
    layer.intermediate = None
bert.load_state_dict(bert_torch.state_dict())
bert = bert.eval()
# Static example shapes for tracing; the Neuron graph is specialized to them.
# Derived from the constants above rather than repeated as magic numbers.
# NOTE(review): qformer_batch=70 and num_visual_tokens=256 are the values the
# runtime must feed; their derivation (e.g. frames x batch) is not visible here.
qformer_batch = 70
num_visual_tokens = 256
input_example = (
    # query embeddings: (batch, num_query_token, hidden_size); 768 for bert-base-uncased
    torch.zeros(qformer_batch, num_query_token, encoder_config.hidden_size, dtype=torch.float32),
    # visual features: (batch, num_visual_tokens, vision_width)
    torch.zeros(qformer_batch, num_visual_tokens, vision_width, dtype=torch.float32),
    # integer tensor, presumably the visual attention mask
    torch.zeros(qformer_batch, num_visual_tokens, dtype=torch.int64),
)
neuron_bert = torch_neuronx.trace(bert, input_example)
neuron_bert.save(BERT_SAVE_PATH)
# --- CPU-side weights: frame position encoding and multimodal projector -------
# Each module is switched to eval mode, cast to float32, and its state_dict is saved.
frame_position_encoding = model.get_frame_position_encoding().eval().to(torch.float32)
torch.save(frame_position_encoding.state_dict(), POSITION_ENCODING_SAVE_PATH)

projector = model.get_model().mm_projector.eval().to(torch.float32)
torch.save(projector.state_dict(), PROJECTOR_SAVE_PATH)
# --- CPU-side weights: language-model token embeddings -----------------------
# Saved after resize_token_embeddings, so the table includes the added special tokens.
embed_tokens = model.get_model().embed_tokens.eval().to(torch.float32)
torch.save(embed_tokens.state_dict(), EMBED_TOKENS_SAVE_PATH)