import torch

from transformers import AutoTokenizer
from llava.constants import DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.model import LlavaMistralForCausalLM

def load_model(model_path, device_map):
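    """Load the LLaVA-Mistral checkpoint in fp16, register the extra image/video
    special tokens, resize the token embeddings to match, and make sure the
    vision tower weights are loaded."""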
    kwargs = {"device_map": device_map, "torch_dtype": torch.float16}  # load weights in fp16
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
        special_tokens=True
    )
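    # The added special tokens grow the vocabulary, so the embedding matrix has
    # to be resized before the model can be used.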
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)

    return model, tokenizer

# Get the device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the model (device_map={"": 0} places every weight on GPU 0; adjust for other setups)
model, tokenizer = load_model("./masp_094_v2", device_map={"": 0})

# Extract the vision tower
vitmodel = model.get_vision_tower()
vitmodel.to(device)  # Ensure the vision tower is on the correct device

# Dummy input for tracing: a batch of 10 RGB images at 224x224, in fp16 to match the model weights
dummy_input = torch.randn(10, 3, 224, 224, device=device, dtype=torch.float16)

# Export the vision tower to ONNX
onnx_path = "vit_model.onnx"
with torch.no_grad():
    torch.onnx.export(
        vitmodel,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=12,  # Use a newer opset version for better compatibility
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
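        # mark the batch dimension as dynamic so the exported graph accepts any batch size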
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        verbose=True
    )
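
# Optional sanity check: a minimal sketch, assuming onnxruntime is installed
# (it is not needed for the export itself). It runs the exported graph on the
# same dummy batch and compares the result against the PyTorch vision tower.
try:
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=ort.get_available_providers())
    onnx_out = session.run(None, {"input": dummy_input.cpu().numpy()})[0]
    with torch.no_grad():
        torch_out = vitmodel(dummy_input).cpu().numpy()
    print("max abs diff vs. PyTorch:", np.abs(onnx_out.astype(np.float32) - torch_out.astype(np.float32)).max())
except ImportError:
    print("onnxruntime not installed; skipping the ONNX sanity check")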

exit()