import sys
import torch
import os
import random
import base64
import msgpack
from io import BytesIO
import numpy as np
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2
from llava.model.builder import load_pretrained_model
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model import LlavaMistralForCausalLM

def load_model(model_path, device_map):
    """Load the LLaVA-Mistral checkpoint and its tokenizer in fp16."""
    kwargs = {"device_map": device_map}
    kwargs['torch_dtype'] = torch.float16  # load weights in fp16
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )
    # Register the image/video delimiter tokens and grow the embedding table to match
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
        special_tokens=True
    )
    model.resize_token_embeddings(len(tokenizer))
    # Load the vision tower weights if the builder has not done so already
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)
    return model, tokenizer

# Get the device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the model and tokenizer
model, tokenizer = load_model("./masp_094_v2", device_map={"": 0})

# Extract the vision tower
vitmodel = model.get_vision_tower()
vitmodel.to(device)  # Ensure the vision tower is on the correct device

# Create a dummy input for the vision tower: a batch of 10 RGB 224x224 frames in fp16
dummy_input = torch.randn(10, 3, 224, 224, device=device, dtype=torch.float16)

# Export the vision tower to ONNX
onnx_path = "vit_model.onnx"
with torch.no_grad():
    torch.onnx.export(
        vitmodel,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=12,  # Use a newer opset version for better compatibility
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        verbose=True
    )
exit()