# server side ---------------------------------------
import argparse
import sys
import base64
import logging
import time
from pathlib import Path
from io import BytesIO
import torch
import uvicorn
import transformers
from PIL import Image
from mmengine import Config
from transformers import BitsAndBytesConfig
from fastapi import FastAPI, Request, HTTPException
# make the repository root importable before loading mllm modules
sys.path.append(str(Path(__file__).parent.parent.parent))
from mllm.dataset.process_function import PlainBoxFormatter
from mllm.dataset.builder import prepare_interactive
from mllm.models.builder.build_shikra import load_pretrained_shikra
from mllm.dataset.utils.transform import expand2square, box_xyxy_expand2square
log_level = logging.DEBUG
transformers.logging.set_verbosity(log_level)
transformers.logging.enable_default_handler()
transformers.logging.enable_explicit_format()
#########################################
# mllm model init
#########################################
parser = argparse.ArgumentParser("Shikra Server Demo")
parser.add_argument('--model_path', required=True)
parser.add_argument('--load_in_8bit', action='store_true')
parser.add_argument('--server_name', default='127.0.0.1')
parser.add_argument('--server_port', type=int, default=12345)
args = parser.parse_args()
print(args)
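# Example invocation (the checkpoint path is a placeholder):
#   python mllm/demo/server.py --model_path /path/to/shikra-checkpoint --load_in_8bit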
model_name_or_path = args.model_path
model_args = Config(dict(
    type='shikra',
    version='v1',
    # checkpoint config
    cache_dir=None,
    model_name_or_path=model_name_or_path,
    vision_tower=r'vit-h',
    pretrain_mm_mlp_adapter=None,
    # model config
    mm_vision_select_layer=-2,
    model_max_length=3072,
    # finetune config
    freeze_backbone=False,
    tune_mm_mlp_adapter=False,
    freeze_mm_mlp_adapter=False,
    # data process config
    is_multimodal=True,
    sep_image_conv_front=False,
    image_token_len=256,
    mm_use_im_start_end=True,
    target_processor=dict(
        boxes=dict(type='PlainBoxFormatter'),
    ),
    process_func_args=dict(
        conv=dict(type='ShikraConvProcess'),
        target=dict(type='BoxFormatProcess'),
        text=dict(type='ShikraTextProcess'),
        image=dict(type='ShikraImageProcessor'),
    ),
    conv_args=dict(
        conv_template='vicuna_v1.1',
        transforms=dict(type='Expand2square'),
        tokenize_kwargs=dict(truncation_size=None),
    ),
    gen_kwargs_set_pad_token_id=True,
    gen_kwargs_set_bos_token_id=True,
    gen_kwargs_set_eos_token_id=True,
))
training_args = Config(dict(
    bf16=False,
    fp16=True,
    device='cuda',
    fsdp=None,
))
if args.load_in_8bit:
    quantization_kwargs = dict(
        quantization_config=BitsAndBytesConfig(
            load_in_8bit=True,
        )
    )
else:
    quantization_kwargs = dict()
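# (4-bit loading would work the same way via BitsAndBytesConfig(load_in_4bit=True);
#  the diagnostics below report is_loaded_in_4bit for that case.)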
model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs)
# cast unquantized modules to fp16 on GPU; quantized weights stay as loaded
if not getattr(model, 'is_quantized', False):
    model.to(dtype=torch.float16, device=torch.device('cuda'))
if not getattr(model.model.vision_tower[0], 'is_quantized', False):
    model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda'))
print(
    f"LLM device: {model.device}, is_quantized: {getattr(model, 'is_quantized', False)}, "
    f"is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, "
    f"is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}")
vision_tower = model.model.vision_tower[0]
print(
    f"vision device: {vision_tower.device}, is_quantized: {getattr(vision_tower, 'is_quantized', False)}, "
    f"is_loaded_in_4bit: {getattr(vision_tower, 'is_loaded_in_4bit', False)}, "
    f"is_loaded_in_8bit: {getattr(vision_tower, 'is_loaded_in_8bit', False)}")
# use plain-text box formatting for box coordinates in prompts and responses
preprocessor['target'] = {'boxes': PlainBoxFormatter()}
tokenizer = preprocessor['text']
#########################################
# fast api
#########################################
app = FastAPI()
@app.post("/shikra")
async def shikra(request: Request):
try:
# receive parameters
para = await request.json()
img_base64 = para["img_base64"]
user_input = para["text"]
boxes_value = para.get('boxes_value', [])
boxes_seq = para.get('boxes_seq', [])
do_sample = para.get('do_sample', False)
max_length = para.get('max_length', 512)
top_p = para.get('top_p', 1.0)
temperature = para.get('temperature', 1.0)
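        # Expected JSON payload (a sketch inferred from the fields read above;
        # boxes are xyxy in original-image space, and boxes_seq maps box
        # placeholders in `text` to indices into boxes_value):
        # {
        #   "img_base64": "<base64-encoded image>",
        #   "text": "What is shown in <boxes>?",
        #   "boxes_value": [[x1, y1, x2, y2]],
        #   "boxes_seq": [[0]],
        #   "do_sample": false,
        #   "max_length": 512,
        #   "top_p": 1.0,
        #   "temperature": 1.0
        # }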
        # parameters preprocess
        pil_image = Image.open(BytesIO(base64.b64decode(img_base64))).convert("RGB")
        ds = prepare_interactive(model_args, preprocessor)
        # pad the image to a square and remap the boxes into the padded frame
        image = expand2square(pil_image)
        boxes_value = [box_xyxy_expand2square(box, w=pil_image.width, h=pil_image.height) for box in boxes_value]
        ds.set_image(image)
        ds.append_message(role=ds.roles[0], message=user_input, boxes=boxes_value, boxes_seq=boxes_seq)
        model_inputs = ds.to_model_input()
        model_inputs['images'] = model_inputs['images'].to(torch.float16)
        print(f"model_inputs: {model_inputs}")
        # generate: build shared kwargs once, add sampling options only if requested
        gen_kwargs = dict(
            use_cache=True,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_length,
        )
        if do_sample:
            gen_kwargs.update(top_p=top_p, temperature=float(temperature))
        print(gen_kwargs)
        input_ids = model_inputs['input_ids']
        st_time = time.time()
        with torch.inference_mode():
            with torch.autocast(dtype=torch.float16, device_type='cuda'):
                output_ids = model.generate(**model_inputs, **gen_kwargs)
        print(f"generation finished in {time.time() - st_time:.2f} seconds")
        # decode only the newly generated tokens (everything after the prompt)
        input_token_len = input_ids.shape[-1]
        response = tokenizer.batch_decode(output_ids[:, input_token_len:])[0]
        print(f"response: {response}")
        input_text = tokenizer.batch_decode(input_ids)[0]
        return {
            "input": input_text,
            "response": response,
        }
    except Exception as e:
        logging.exception(str(e))
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    uvicorn.run(app, host=args.server_name, port=args.server_port, log_level="info")
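# ---------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the server; assumes the server
# is running on the default host/port and that `demo.png` exists locally):
#
#   import base64, requests
#   with open("demo.png", "rb") as f:
#       img_base64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://127.0.0.1:12345/shikra",
#       json={"img_base64": img_base64, "text": "Describe this image."},
#   )
#   print(resp.json()["response"])
# ---------------------------------------------------------------------------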