|
import json |
|
from copy import deepcopy |
|
|
|
import torch |
|
|
|
import base64 |
|
from io import BytesIO |
|
from typing import Any, List, Dict |
|
|
|
from PIL import Image |
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
def chat(
    model,
    image_list,
    msgs_list,
    tokenizer,
    vision_hidden_states=None,
    max_new_tokens=1024,
    sampling=True,
    max_inp_length=2048,
    system_prompt_list=None,
    **kwargs
):
    """Run batched multimodal generation with a MiniCPM-V style model.

    For each (image, msgs) pair, images embedded in the conversation are
    converted to pixel tensors plus text placeholders, the conversation is
    rendered with the tokenizer's chat template, and everything is passed to
    ``model.generate``.

    Args:
        model: MiniCPM-V model exposing ``config.slice_mode``,
            ``get_slice_image_placeholder``, ``transform``,
            ``reshape_by_patch`` and ``generate``.
        image_list: one ``PIL.Image`` (or ``None``) per conversation.
        msgs_list: one conversation per sample; either a list of
            ``{"role", "content"}`` dicts or a JSON string encoding one.
        tokenizer: tokenizer with ``apply_chat_template`` and the
            ``im_start``/``im_end``/``unk_token`` special tokens.
        vision_hidden_states: optional precomputed vision features,
            forwarded to ``model.generate``.
        max_new_tokens: generation length cap.
        sampling: if True use top-p/top-k sampling, else beam search.
        max_inp_length: input-length cap forwarded to ``model.generate``.
        system_prompt_list: optional per-sample system prompts.
        **kwargs: may override keys already present in the chosen
            generation config (other keys are ignored).

    Returns:
        The decoded generation results from ``model.generate``.
    """
    assert len(image_list) == len(msgs_list), \
        "image_list and msgs_list must have the same length"

    copy_msgs_lst = []
    images_list = []
    tgt_sizes_list = []
    for sample_idx, (msgs, image) in enumerate(zip(msgs_list, image_list)):
        system_prompt = system_prompt_list[sample_idx] if system_prompt_list else None
        if isinstance(msgs, str):
            msgs = json.loads(msgs)

        # Deep-copy so the caller's message structures are never mutated.
        copy_msgs = deepcopy(msgs)

        # Convenience: a bare string first turn plus a separate image becomes
        # a mixed [image, text] content list.
        if image is not None and isinstance(copy_msgs[0]['content'], str):
            copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]

        images = []
        tgt_sizes = []
        # NOTE: distinct loop variable (turn_idx) — the original shadowed the
        # outer per-sample index here.
        for turn_idx, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["user", "assistant"]
            if turn_idx == 0:
                assert role == "user", "The role of first msg should be user"
            if isinstance(content, str):
                content = [content]

            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    if model.config.slice_mode:
                        # Slice large images into patches; the placeholder text
                        # marks where the vision features go in the prompt.
                        slice_images, image_placeholder = model.get_slice_image_placeholder(
                            c, tokenizer
                        )
                        cur_msgs.append(image_placeholder)
                        for slice_image in slice_images:
                            slice_image = model.transform(slice_image)
                            H, W = slice_image.shape[1:]
                            images.append(model.reshape_by_patch(slice_image))
                            tgt_sizes.append(
                                torch.Tensor([H // model.config.patch_size,
                                              W // model.config.patch_size]).type(torch.int32))
                    else:
                        images.append(model.transform(c))
                        cur_msgs.append(
                            tokenizer.im_start
                            + tokenizer.unk_token * model.config.query_num
                            + tokenizer.im_end
                        )
                elif isinstance(c, str):
                    cur_msgs.append(c)

            # Collapse mixed content back into a single prompt string.
            msg['content'] = '\n'.join(cur_msgs)
        if tgt_sizes:
            # Stack per-slice sizes into one (num_slices, 2) int32 tensor;
            # an empty list is passed through unchanged, as model.generate expects.
            tgt_sizes = torch.vstack(tgt_sizes)

        if system_prompt:
            sys_msg = {'role': 'system', 'content': system_prompt}
            copy_msgs = [sys_msg] + copy_msgs

        copy_msgs_lst.append(copy_msgs)
        images_list.append(images)
        tgt_sizes_list.append(tgt_sizes)

    input_ids_list = tokenizer.apply_chat_template(
        copy_msgs_lst, tokenize=True, add_generation_prompt=False
    )

    if sampling:
        generation_config = {
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05
        }
    else:
        generation_config = {
            "num_beams": 3,
            "repetition_penalty": 1.2,
        }

    # Only keys already present in the config may be overridden by kwargs;
    # unrelated kwargs are deliberately ignored.
    generation_config.update(
        (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
    )

    with torch.inference_mode():
        res, vision_hidden_states = model.generate(
            input_id_list=input_ids_list,
            max_inp_length=max_inp_length,
            img_list=images_list,
            tgt_sizes=tgt_sizes_list,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            vision_hidden_states=vision_hidden_states,
            return_vision_hidden_states=True,
            stream=False,
            **generation_config
        )
    return res
|
|
|
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for MiniCPM-Llama3-V-2_5-int4.

    Loads the quantized model once at startup; each request carries a list of
    inputs, each with an optional base64-encoded image and either a plain
    ``question`` or a full ``msgs`` conversation.
    """

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints contract but unused here:
        # the model is always pulled from the Hub by name.
        model_name = "SwordElucidator/MiniCPM-Llama3-V-2_5-int4"
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model.eval()
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, data: Any) -> Any:
        """Handle one endpoint request.

        Args:
            data: request payload; ``data["inputs"]`` (or ``data`` itself) is a
                list of dicts with optional keys ``image`` (base64 string),
                ``question`` (str) and ``msgs`` (conversation list).

        Returns:
            The batched generation results from ``chat``.
        """
        inputs = data.pop("inputs", data)

        image_list = []
        msgs_list = []

        for input_ in inputs:
            image_b64 = input_.pop("image", None)
            question = input_.pop("question", None)
            msgs = input_.pop("msgs", None)

            # Fix: the original decoded unconditionally, so a request without
            # an "image" key crashed with TypeError in b64decode(None).
            # chat() explicitly supports image=None, so pass it through.
            if image_b64 is not None:
                image = Image.open(BytesIO(base64.b64decode(image_b64)))
            else:
                image = None

            if not msgs:
                msgs = [{'role': 'user', 'content': question}]

            image_list.append(image)
            msgs_list.append(msgs)

        return chat(
            self.model,
            image_list,
            msgs_list,
            self.tokenizer,
        )
|
|