""" | |
A model worker executes the model. | |
""" | |
import argparse
import asyncio
import base64
import io
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from minigemini.constants import WORKER_HEART_BEAT_INTERVAL
from minigemini.utils import (build_logger, server_error_msg,
                              pretty_print_semaphore)
from minigemini.model.builder import load_pretrained_model
from minigemini.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from minigemini.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread

try:
    from diffusers import StableDiffusionXLPipeline
except ImportError:
    print('please install diffusers==0.26.3')

try:
    from paddleocr import PaddleOCR
except ImportError:
    print('please install paddleocr following https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/README_en.md')
GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None
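
# Background heartbeat: the worker pings the controller every
# WORKER_HEART_BEAT_INTERVAL seconds; if the controller no longer recognizes
# the worker (e.g. after a controller restart), send_heart_beat() re-registers it.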
def heart_beat_worker(controller):
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device, use_flash_attn=False):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
        # self.is_multimodal = 'llava' in self.model_name.lower()
        self.is_multimodal = True

        # If the model has an auxiliary (high-resolution) vision branch, reconfigure
        # the image processor so it produces crops at the auxiliary resolution.
        if hasattr(self.model.config, 'image_size_aux'):
            if not hasattr(self.image_processor, 'image_size_raw'):
                self.image_processor.image_size_raw = self.image_processor.crop_size.copy()
            self.image_processor.crop_size['height'] = self.model.config.image_size_aux
            self.image_processor.crop_size['width'] = self.model.config.image_size_aux
            self.image_processor.size['shortest_edge'] = self.model.config.image_size_aux

        # OCR model
        self.ocr_model = PaddleOCR(use_angle_cls=True, use_gpu=True, lang="ch")

        # Diffusion model: place the SDXL pipeline on the last visible GPU so it stays
        # off the LLM's primary device when multiple GPUs are available.
        max_gpu_index = torch.cuda.device_count() - 1
        device_last = torch.device(f'cuda:{max_gpu_index}')
        logger.info(f"GPUs available: {torch.cuda.device_count()}, diffusion model on {device_last}")
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
        ).to(device=device_last)

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=30)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def add_content(self, prompt, new_content):
        # Splice new_content onto the end of the last user message, just before the
        # assistant turn, handling the supported chat templates (LLaMA-2 [INST],
        # ChatML <|im_end|>, and the plain "###Assistant:" separator).
        if '[INST]' in prompt:
            split_index = prompt.rfind(' [/INST]')
        elif '<|im_end|>' in prompt:
            split_index = prompt.rfind('<|im_end|>')
        else:
            split_index = prompt.rfind('###Assistant:')
        left_prompt = prompt[:split_index]
        right_prompt = prompt[split_index:]
        prompt = left_prompt + new_content + right_prompt
        return prompt
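
    # Example of add_content (illustrative prompt): with a LLaMA-2 style prompt
    #   "[INST] <image>\nWhat is this? [/INST]"
    # add_content(prompt, ' <GEN>') yields
    #   "[INST] <image>\nWhat is this? <GEN> [/INST]"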

    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        gen_image = params.get("gen_image", False)
        use_ocr = params.get("use_ocr", False)
        num_image_tokens = 0
        if gen_image:
            # Append the <GEN> trigger; the text the model wraps in <h>...</h> is
            # later used as the Stable Diffusion prompt.
            prompt = self.add_content(prompt, ' <GEN>')
            print(prompt)

        if images is not None and len(images) > 0 and self.is_multimodal:  # len(images) = 1
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <image> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                # Add OCR tokens: run PaddleOCR over each image and append any
                # recognized text to the prompt as a reference for the model.
                if use_ocr:
                    str_in_image = ''
                    for image in images:
                        img_byte_arr = io.BytesIO()
                        image.save(img_byte_arr, format=image.format)
                        img_byte_arr = img_byte_arr.getvalue()
                        result = self.ocr_model.ocr(img_byte_arr, cls=True)
                        if result[0] is not None:
                            # Keep recognized strings with confidence above 0.1.
                            result = [res[1][0] for res in result[0] if res[1][1] > 0.1]
                            if len(result) > 0:
                                str_in_image += ', '.join(result)
                    # print('OCR Token: ' + str_in_image)
                    if len(str_in_image) > 0:
                        prompt = self.add_content(prompt, '\nReference OCR Token: ' + str_in_image + '\n')
                image_tensor = process_images(images, image_processor, model.config)

                image_grid = getattr(model.config, 'image_grid', 1)
                if hasattr(model.config, 'image_size_aux'):
                    # The high-resolution copy feeds the auxiliary vision branch; the
                    # main branch gets the image resized back to the raw resolution.
                    raw_shape = [image_processor.image_size_raw['height'] * image_grid,
                                 image_processor.image_size_raw['width'] * image_grid]
                    image_tensor_aux = image_tensor
                    image_tensor = torch.nn.functional.interpolate(image_tensor,
                                                                   size=raw_shape,
                                                                   mode='bilinear',
                                                                   align_corners=False)  # torch.Size([1, 3, 336, 336])
                else:
                    image_tensor_aux = []

                if image_grid >= 2:
                    # Split the image into an image_grid x image_grid set of crops.
                    raw_image = image_tensor.reshape(3,
                                                     image_grid,
                                                     image_processor.image_size_raw['height'],
                                                     image_grid,
                                                     image_processor.image_size_raw['width'])
                    raw_image = raw_image.permute(1, 3, 0, 2, 4)
                    raw_image = raw_image.reshape(-1, 3,
                                                  image_processor.image_size_raw['height'],
                                                  image_processor.image_size_raw['width'])

                    if getattr(model.config, 'image_global', False):
                        # Optionally append a downscaled copy of the full image as a global view.
                        global_image = image_tensor
                        if len(global_image.shape) == 3:
                            global_image = global_image[None]
                        global_image = torch.nn.functional.interpolate(global_image,
                                                                       size=[image_processor.image_size_raw['height'],
                                                                             image_processor.image_size_raw['width']],
                                                                       mode='bilinear',
                                                                       align_corners=False)
                        # [image_crops, image_global]
                        raw_image = torch.cat([raw_image, global_image], dim=0)
                    image_tensor = raw_image.contiguous()

                image_tensor = image_tensor.to(self.model.device, dtype=torch.float16).unsqueeze(0)
                if len(image_tensor_aux) > 0:
                    image_tensor_aux = image_tensor_aux.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                if getattr(self.model.config, 'mm_use_im_start_end', False):
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
            else:
                image_tensor = None
            image_args = {"images": image_tensor, "images_aux": image_tensor_aux}
        else:
            image_tensor = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30)

        # Leave room in the context window for the prompt and the expanded image tokens.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # Run generation in a background thread; TextIteratorStreamer yields decoded
        # text chunks as soon as they are produced.
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if stop_str is not None and generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
        torch.cuda.empty_cache()

        # If image generation was requested and the model produced a prompt wrapped
        # in <h>...</h>, run it through SDXL and return the image as base64.
        if gen_image and "<h>" in generated_text and "</h>" in generated_text:
            # common_neg_prompt = "blur, lowres, bad anatomy, bad hands, cropped, worst quality"
            common_neg_prompt = "out of frame, lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
            prompt = generated_text.split("<h>")[1].split("</h>")[0]
            # yield json.dumps({"text": prompt, "error_code": 0}).encode() + b"\0"
            output_img = self.pipe(prompt, negative_prompt=common_neg_prompt).images[0]
            buffered = io.BytesIO()
            output_img.save(buffered, format='JPEG')
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            torch.cuda.empty_cache()

            generated_text = generated_text.split("<h>")[0] + '\n' + 'Prompt: ' + prompt + '\n'
            yield json.dumps({"text": generated_text, "image": img_b64_str, "error_code": 0}).encode() + b"\0"
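
    # Wire format: every chunk yielded by generate_stream / generate_stream_gate is
    # a JSON object terminated by a null byte (b"\0"); clients split the response
    # stream on b"\0" and parse each piece as JSON.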

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


app = FastAPI()


def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


# HTTP endpoints exposed to the controller and the web server.
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()
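

# Example client call (illustrative; assumes the worker listens on localhost:21002
# and that the prompt/stop strings match the served model's chat template):
#
#   import json, requests
#   resp = requests.post("http://localhost:21002/worker_generate_stream",
#                        json={"prompt": "[INST] Hello [/INST]", "stop": "</s>",
#                              "temperature": 0.2, "max_new_tokens": 128},
#                        stream=True)
#   for chunk in resp.iter_lines(delimiter=b"\0"):
#       if chunk:
#           print(json.loads(chunk.decode())["text"])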


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--use-flash-attn", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device,
                         use_flash_attn=args.use_flash_attn)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")