Spaces:
Build error
Build error
import hydra | |
import pyrootutils | |
import torch | |
import re | |
import time | |
from omegaconf import OmegaConf | |
from flask import Flask, request | |
from typing import Optional | |
import transformers | |
from dataclasses import dataclass, field | |
import io | |
import base64 | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler | |
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) | |
from src.data.any_res import process_anyres_image | |
# Special tokens that wrap the image slots inside the LLM prompt.
# <img>...</img> wraps the final block of an image, <patch>...</patch>
# wraps the additional blocks emitted in multi-resolution mode.
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
BOP_TOKEN = '<patch>'
EOP_TOKEN = '</patch>'

# Template of a single image slot token; the index is zero-padded to 5 digits.
IMG_TOKEN = '<img_{:05d}>'

# Marker used in request/response text to show where an image belongs.
IMG_FLAG = '<image>'

# Number of slot tokens per image block on the input and output side.
num_img_in_tokens = 64
num_img_out_tokens = 64

# 'AxB' grid layouts; in multi-resolution mode each factor is multiplied by
# base_resolution to build the candidate sizes.
resolution_grids = [
    '1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10',
    '2x1', '3x1', '4x1', '5x1', '6x1', '10x1',
    '2x2', '2x3', '3x2', '2x4', '4x2',
]
base_resolution = 448
# Flask application object; it is started by app.run() under the __main__ guard.
app = Flask(__name__)
def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Args:
        encoded_image: Base64 text holding the bytes of an image file.

    Returns:
        The image opened by ``Image.open`` from an in-memory buffer.
    """
    # The annotation uses Image.Image (the class); ``Image`` alone is the
    # PIL module and is not a valid return type.
    raw_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    return Image.open(io.BytesIO(raw_bytes))
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    """Serialise a PIL image and return it as base64 text.

    Args:
        image: Image to serialise.
        format: File format handed to ``image.save`` (``'PNG'`` by default).

    Returns:
        The saved image bytes, base64-encoded and decoded as UTF-8 text.
    """
    with io.BytesIO() as sink:
        image.save(sink, format=format)
        raw_bytes = sink.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
@dataclass
class Arguments:
    """Command-line options of the server, parsed by ``HfArgumentParser``.

    The class must be a dataclass: the fields are declared with
    ``dataclasses.field`` and the parser builds its options from them.
    """

    # Paths of the config files loaded with OmegaConf when the service starts.
    image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
    sd_adapter: Optional[str] = field(default=None, metadata={"help": "config path of sd adapter"})
    agent: Optional[str] = field(default=None, metadata={"help": "config path of agent model"})
    diffusion_path: Optional[str] = field(default=None, metadata={"help": "diffusion model path"})
    has_bbox: Optional[bool] = field(default=False, metadata={"help": "visualize the box"})
    # Declared as int (the default already was one) so a value given on the
    # command line reaches app.run() as an integer rather than a string.
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
    # Only 'fp16' and 'bf16' are accepted by LLMService.
    dtype: Optional[str] = field(default='fp16', metadata={"help": "mix percision"})
    multi_resolution: Optional[bool] = field(default=False, metadata={"help": "multi resolution"})
# Parse the command line once at import time; the parser returns one object
# per dataclass it was given, so exactly one Arguments instance is unpacked.
parser = transformers.HfArgumentParser(Arguments)
(args,) = parser.parse_args_into_dataclasses()
def extract_box(output_str):
    """Collect the ``<loc-N>`` numbers preceding each ``<box_end>`` marker.

    Args:
        output_str: Text to scan.

    Returns:
        ``None`` when the text contains no ``<box_end>``; otherwise one list
        of ints per ``<box_end>``, holding every ``<loc-N>`` value found in
        the text leading up to that marker (possibly an empty list).
    """
    # Raw strings keep ``\d`` a regex class instead of an invalid str escape.
    segments = re.findall(r'(.*?)<box_end>', output_str)
    if not segments:
        return None
    return [
        [int(num) for num in re.findall(r'<loc-(\d+)>', segment)]
        for segment in segments
    ]
def visualize_bbox(image, bboxes, loc_scale=224):
    """Draw green rectangles for centre-format boxes on a copy of ``image``.

    Args:
        image: PIL image to annotate; it is converted to RGB first so that
            other modes (for example 'L' or 'P') can be drawn on as well.
        bboxes: Iterable of ``[x_center, y_center, width, height]`` boxes
            whose values are expressed relative to ``loc_scale``.
        loc_scale: Value the box numbers are divided by before being scaled
            to the image size (224, the previously hard-coded constant).

    Returns:
        A new RGB PIL image with one rectangle per well-formed box.
    """
    img_width, img_height = image.size
    canvas = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)

    for bbox in bboxes:
        # Boxes parsed from generated text may hold any number of values;
        # skip the ones that cannot be unpacked instead of failing.
        if len(bbox) != 4:
            continue
        x_center, y_center, box_width, box_height = bbox

        x_center = x_center / loc_scale * img_width
        y_center = y_center / loc_scale * img_height
        box_width = box_width / loc_scale * img_width
        box_height = box_height / loc_scale * img_height

        # Convert centre/size to the two corner points cv2.rectangle expects.
        x1 = int(x_center - box_width / 2)
        y1 = int(y_center - box_height / 2)
        x2 = int(x_center + box_width / 2)
        y2 = int(y_center + box_height / 2)

        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 4)

    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
class LLMService:
    """Holds every model and tokenizer object used to serve requests.

    Built from the parsed ``Arguments``: the constructor loads the image
    transform, tokenizer, visual encoder, LLM/agent and the SD adapter
    pipeline, and records the special-token ids needed to build prompts.
    """

    def __init__(self, args) -> None:
        """Instantiate all components from the config paths in ``args``.

        Args:
            args: Parsed ``Arguments`` instance.

        Raises:
            ValueError: If ``args.dtype`` is neither ``'fp16'`` nor ``'bf16'``.
        """
        self.llm_device = args.llm_device
        self.vit_sd_device = args.vit_sd_device

        dtype = args.dtype
        if dtype == 'fp16':
            self.dtype = torch.float16
        elif dtype == 'bf16':
            self.dtype = torch.bfloat16
        else:
            raise ValueError(f"Unsupported dtype {dtype!r}; expected 'fp16' or 'bf16'")

        image_transform_cfg = OmegaConf.load(args.image_transform)
        self.image_transform = hydra.utils.instantiate(image_transform_cfg)

        tokenizer_cfg = OmegaConf.load(args.tokenizer)
        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)

        visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
        self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
        self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
        print('Init visual encoder done')

        llm_cfg = OmegaConf.load(args.llm)
        llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
        print('Init llm done.')

        agent_cfg = OmegaConf.load(args.agent)
        self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
        self.agent.eval().to(self.llm_device, dtype=self.dtype)
        print('Init agent mdoel Done')

        noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype)
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(dtype=self.dtype)

        sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
        self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(dtype=self.dtype)

        # The pipe is initialised with device="cpu" and a CPU visual encoder;
        # generate() moves the adapter to the GPU only while it draws images.
        self.sd_adapter.init_pipe(vae=vae,
                                  scheduler=noise_scheduler,
                                  visual_encoder=self.visual_encoder.to("cpu"),
                                  image_transform=self.image_transform,
                                  discrete_model=None,
                                  dtype=self.dtype,
                                  device="cpu")
        print('Init sd adapter pipe done.')

        # Bring the visual encoder back to its serving device after init_pipe.
        self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)

        # Ids of the image/patch delimiter tokens (first id of each encoding).
        self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
        self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
        self.bop_token_id = self.tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
        self.eop_token_id = self.tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]

        self.multi_resolution = args.multi_resolution
        if self.multi_resolution:
            self.base_resolution = base_resolution
            # Each 'AxB' grid becomes [A * base_resolution, B * base_resolution].
            self.grid_pinpoints = [
                [int(first) * base_resolution, int(second) * base_resolution]
                for first, second in (scale.split('x') for scale in resolution_grids)
            ]
# Build the single shared service at import time; generate() reads its models and tokenizer.
service = LLMService(args)
def _image_placeholder(num_patches=1):
    """Return the prompt placeholder text for one image of ``num_patches`` blocks."""
    slots = ''.join(IMG_TOKEN.format(slot) for slot in range(num_img_in_tokens))
    patch_block = BOP_TOKEN + slots + EOP_TOKEN
    image_block = BOI_TOKEN + slots + EOI_TOKEN
    # All blocks but the last use the patch delimiters; the last one uses the
    # image delimiters.
    return patch_block * (num_patches - 1) + image_block


# NOTE(review): no route registration is visible in this file, so the handler
# was unreachable. The path '/generate' is assumed from the handler name;
# confirm it against the clients of this server.
@app.route('/generate', methods=['GET', 'POST'])
def generate():
    """Handle one generation request sent as a JSON body.

    Request fields:
        text: Prompt text; every ``<image>`` marker is replaced by the
            placeholder tokens of the matching entry of ``images``.
        images: Optional list of base64-encoded images, one per marker.
        max_new_tokens: Optional generation length limit (default 256).
        force_boi: Optional flag; appends ``<img>`` to the prompt.
        force_bbox: Optional flag; appends ``[[ <box_start>`` to the prompt.

    Returns:
        A dict with ``text`` (generated text), ``images`` (base64 images:
        generated ones and, when enabled, the bounding-box visualisation)
        and ``error_msg`` (currently always an empty list).

    Raises:
        ValueError: If the number of images does not match the number of
            ``<image>`` markers, or an image entry is not a string.
    """
    with torch.no_grad():
        request_info = request.get_json()

        text_list = request_info['text'].split(IMG_FLAG)
        # A missing or null "images" field is treated as a text-only request.
        image_list = request_info.get('images') or []
        max_new_tokens = request_info.get('max_new_tokens', 256)
        top_p = 0.5
        force_boi = request_info.get('force_boi', False)
        force_bbox = request_info.get('force_bbox', False)

        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if len(text_list) != len(image_list) + 1:
            raise ValueError(
                f'expected {len(text_list) - 1} image(s) for the {IMG_FLAG} '
                f'markers in "text", got {len(image_list)}'
            )

        input_images = []
        image_tokens_list = []  # one placeholder string per input image
        patch_position = 0

        if image_list:
            image_tensor_list = []
            embeds_cmp_mask = []
            patch_pos = []
            image_patch_length = []

            for image_item in image_list:
                if not isinstance(image_item, str):
                    raise ValueError('every entry of "images" must be a base64 string')

                image = decode_image(image_item)
                print('after decode image size:', image.size)
                input_images.append(image)

                if service.multi_resolution:
                    print('image size:', image.size)
                    image_tensor, patch_pos_tensor = process_anyres_image(
                        image, service.image_transform, service.grid_pinpoints, service.base_resolution)
                    num_patches = image_tensor.shape[0]
                    image_tensor_list.append(image_tensor)
                    patch_pos.append(patch_pos_tensor)
                    image_patch_length.append(num_patches)
                    print('image_patch_length', image_patch_length)
                    image_tokens_list.append(_image_placeholder(num_patches))
                    embeds_cmp_mask.extend([True] * num_patches)
                else:
                    image_tensor_list.append(service.image_transform(image))
                    image_tokens_list.append(_image_placeholder())
                    embeds_cmp_mask.append(True)

            if service.multi_resolution:
                pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
                patch_position = torch.cat(patch_pos, dim=0)
            else:
                pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)

            image_embeds = service.visual_encoder(pixel_values)
            image_embeds = image_embeds.to(service.llm_device)
            embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
        else:
            image_embeds = None
            embeds_cmp_mask = None

        # Interleave the text pieces with the placeholder of each image.
        input_text = ''.join(
            piece + tokens for piece, tokens in zip(text_list, image_tokens_list)
        ) + text_list[-1]

        if force_boi:
            input_text = input_text + BOI_TOKEN

        if force_bbox:
            input_text = input_text + '[[ <box_start>'
        print('input_text:', input_text)

        input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
        input_ids = [service.tokenizer.bos_token_id] + input_ids
        input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
        ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)

        if service.multi_resolution:
            boi_indices = torch.where(torch.logical_or(
                input_ids == service.boi_token_id, input_ids == service.bop_token_id))[0].tolist()
            eoi_indices = torch.where(torch.logical_or(
                input_ids == service.eoi_token_id, input_ids == service.eop_token_id))[0].tolist()
        else:
            boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
            eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()

        # Mark the slot positions strictly between each opening/closing pair.
        for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
            ids_cmp_mask[boi_idx + 1:eoi_idx] = True

        input_ids = input_ids.unsqueeze(0)
        ids_cmp_mask = ids_cmp_mask.unsqueeze(0)

        error_msg = []

        generate_kwargs = dict(
            tokenizer=service.tokenizer,
            input_ids=input_ids,
            image_embeds=image_embeds,
            embeds_cmp_mask=embeds_cmp_mask,
            ids_cmp_mask=ids_cmp_mask,
            num_img_gen_tokens=num_img_out_tokens,
            max_new_tokens=max_new_tokens,
            dtype=service.dtype,
            device=service.llm_device,
            top_p=top_p,
        )
        if service.multi_resolution:
            # Only the multi-resolution call passes patch positions.
            generate_kwargs['patch_positions'] = patch_position
        output = service.agent.generate(**generate_kwargs)

        gen_imgs_base64_list = []
        generated_text = output['text']
        generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')

        if output['has_img_output']:
            print('loading visual encoder and llm to CPU, and sd to GPU')
            a = time.time()
            service.agent = service.agent.to("cpu")
            try:
                service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
                print("Loading finished: ", time.time() - a)

                img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)

                for img_idx in range(output['num_gen_imgs']):
                    img_feat = img_gen_feat[img_idx:img_idx + 1]
                    generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
                    image_base64 = encode_image(generated_image)
                    gen_imgs_base64_list.append(image_base64)
            finally:
                # Always restore the serving layout, even if image generation
                # fails, so later requests find the models where they expect.
                print('loading visual encoder and llm to GPU, and sd to CPU')
                a = time.time()
                service.sd_adapter = service.sd_adapter.to("cpu")
                service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
                # The agent goes back to the LLM device it was loaded on; the
                # request tensors are created on that device.
                service.agent = service.agent.to(service.llm_device, dtype=service.dtype)
                print("Loading finished: ", time.time() - a)

        if args.has_bbox:
            bboxes = extract_box(generated_text)
            if bboxes is not None and len(input_images) > 0:
                image_viz = visualize_bbox(input_images[0], bboxes)
                image_base64 = encode_image(image_viz)
                gen_imgs_base64_list.append(image_base64)
                generated_text = re.sub(r'\[\[ <box_start>.*?<box_end>.*?\]\]', 'the green bounding box', generated_text)
                generated_text += IMG_FLAG
        print(input_text + generated_text)
        return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
if __name__ == '__main__':
    # The development server requires an integer port; coerce it in case the
    # value was parsed from the command line as a string.
    app.run(host='0.0.0.0', port=int(args.port))