RxnIM / mllm /demo /client.py
CYF200127's picture
Upload 235 files
3e1d9f3 verified
history blame
5.22 kB
2111# client-side API call example ---------------------------------
import os
import re
import base64
from io import BytesIO
from typing import Union
import torch
import requests
from PIL import Image
from torchvision.transforms import ToPILImage, PILToTensor
from torchvision.utils import draw_bounding_boxes as _draw_bounding_boxes
# helper
def pil_to_base64(pil_img):
    """Encode an image as PNG and return the bytes as a base64 text string.

    Args:
        pil_img: Object exposing ``save(fp, format=...)`` (a PIL image here).

    Returns:
        The base64 encoding of the PNG bytes, decoded to a UTF-8 ``str`` so it
        can be placed directly in a JSON payload.
    """
    with BytesIO() as buf:
        pil_img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode('utf-8')
def de_norm_box_xyxy(box, *, w, h):
    """Scale a normalized ``(x1, y1, x2, y2)`` box to absolute coordinates.

    Args:
        box: Four-element sequence of normalized corner coordinates.
        w: Image width; x coordinates are multiplied by it.
        h: Image height; y coordinates are multiplied by it.

    Returns:
        A ``(x1, y1, x2, y2)`` tuple in the image's coordinate units.
    """
    left, top, right, bottom = box
    return left * w, top * h, right * w, bottom * h
def draw_bounding_boxes(image: "Union[torch.Tensor, Image.Image]", boxes, **kwargs):
    """Draw ``boxes`` on ``image`` with torchvision, accepting PIL images and non-tensor boxes.

    Thin adapter around ``torchvision.utils.draw_bounding_boxes`` that normalizes
    the inputs before delegating.

    Args:
        image: A PIL image (converted with ``PILToTensor``) or an image tensor.
            NOTE(review): torchvision expects a (C, H, W) uint8 tensor; that is
            what ``PILToTensor`` produces, but tensors passed in directly are not
            checked for it here.
        boxes: Box coordinates in any form ``torch.as_tensor`` accepts, e.g. a list
            of ``(x1, y1, x2, y2)`` tuples from ``de_norm_box_xyxy``.
        **kwargs: Forwarded unchanged to torchvision (e.g. ``colors``, ``width``).

    Returns:
        The tensor returned by ``torchvision.utils.draw_bounding_boxes``.
    """
    if isinstance(image, Image.Image):
        image = PILToTensor()(image)
    assert isinstance(image, torch.Tensor), (
        f"image must be a PIL image or torch.Tensor, got {type(image).__name__}"
    )
    if not isinstance(boxes, torch.Tensor):
        # as_tensor either returns a tensor or raises, so no further check is needed.
        boxes = torch.as_tensor(boxes)
    return _draw_bounding_boxes(image, boxes, **kwargs)
def expand2square(pil_img, background_color=(255, 255, 255)):
    """Pad an image to a square canvas, centring it along the shorter side.

    Args:
        pil_img: PIL image to pad.
        background_color: Fill colour for the padding area.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image of
        size ``max(width, height)`` on each side, filled with
        ``background_color``, with the original pasted in the middle of the
        padded axis.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Wide images are shifted down; tall images are shifted right.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def query(image: "Union[Image.Image, str]", text: str, boxes_value: list, boxes_seq: list, server_url='',
          timeout=None):
    """POST one image + prompt to the demo server and return its parsed JSON reply.

    Args:
        image: A PIL image, or a filesystem path that is opened with ``Image.open``.
        text: Prompt text, sent as the ``"text"`` field.
        boxes_value: Sent verbatim as the ``"boxes_value"`` field.
        boxes_seq: Sent verbatim as the ``"boxes_seq"`` field (presumably indexes
            into ``boxes_value`` per placeholder in ``text`` — confirm against the
            server).
        server_url: Full endpoint URL the JSON payload is posted to.
        timeout: Optional ``requests`` timeout in seconds; ``None`` (the default)
            waits indefinitely, as before.

    Returns:
        The response body decoded with ``resp.json()``.

    Raises:
        ValueError: If the server answers with any status other than 200; the
            message is the HTTP reason phrase.
    """
    if isinstance(image, str):
        # Close the file handle as soon as the PNG bytes have been produced
        # instead of leaving it open until garbage collection.
        with Image.open(image) as img:
            img_base64 = pil_to_base64(img)
    else:
        img_base64 = pil_to_base64(image)
    pload = {
        "img_base64": img_base64,
        "text": text,
        "boxes_value": boxes_value,
        "boxes_seq": boxes_seq,
    }
    resp = requests.post(server_url, json=pload, timeout=timeout)
    if resp.status_code != 200:
        raise ValueError(resp.reason)
    return resp.json()
def postprocess(text, image):
    """Draw the boxes mentioned in a model reply and colour them in the text.

    A box group in ``text`` is a bracketed list of boxes separated by ``;``. Each
    box is four comma-separated normalized coordinates, e.g.
    ``[0.1,0.2,0.5,0.6;0.3,0.3,0.4,0.4]``. Every box of the i-th group is drawn
    with ``colors[i % len(colors)]``, and the i-th group in the text is wrapped in
    an HTML ``<span>`` of the same colour.

    Args:
        text: Model response text.
        image: PIL image the response refers to, or ``None``.

    Returns:
        ``(text, None)`` unchanged when ``image`` is ``None`` or no box group is
        found. Otherwise ``(html_text, rendered)``, where ``rendered`` is the
        square-padded image with the boxes drawn, as a PIL image.
    """
    if image is None:
        return text, None

    # Boxes are de-normalized against the padded image's size below.
    image = expand2square(image)

    colors = ['#ed7d31', '#5b9bd5', '#70ad47', '#7030a0', '#c00000', '#ffff00', "olive", "brown", "cyan"]
    # One match per box group: "[" box (";" box)* "]", where each box is four
    # comma-separated numbers with a single-digit integer part.
    pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]')

    def extract_boxes(string):
        """Return, per box group in *string*, a list of ``[x1, y1, x2, y2]`` float boxes."""
        ret = []
        for bboxes_str in pat.findall(string):
            # A match starts with "[" and ends with "]" and contains no other
            # brackets, so trimming both ends leaves only ";"-separated boxes.
            bbox_strs = bboxes_str[1:-1].split(";")
            ret.append([list(map(float, bbox_str.split(','))) for bbox_str in bbox_strs])
        return ret

    extract_pred = extract_boxes(text)
    boxes_to_draw = []
    color_to_draw = []
    for idx, boxes in enumerate(extract_pred):
        color = colors[idx % len(colors)]
        for box in boxes:
            boxes_to_draw.append(de_norm_box_xyxy(box, w=image.width, h=image.height))
            # One colour per box, so the list lines up with boxes_to_draw.
            color_to_draw.append(color)
    if not boxes_to_draw:
        return text, None
    res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8)
    res = ToPILImage()(res)

    # Post-process the text: wrap each box group in a span of the colour it was
    # drawn with (the i-th match gets the i-th group's colour).
    pieces = []
    cursor = 0
    for idx, match in enumerate(pat.finditer(text)):
        color = colors[idx % len(colors)]
        pieces.append(text[cursor:match.start()])
        pieces.append(f'<span style="color:{color}; font-weight:bold;">{match.group(0)}</span>')
        cursor = match.end()
    pieces.append(text[cursor:])
    text = "".join(pieces)
    return text, res
if __name__ == '__main__':
server_url = '' + "/shikra"
def example1():
image_path = os.path.join(os.path.dirname(__file__), 'assets/rec_bear.png')
text = 'Can you point out a brown teddy bear with a blue bow in the image <image> and provide the coordinates of its location?'
boxes_value = []
boxes_seq = []
response = query(image_path, text, boxes_value, boxes_seq, server_url)
_, image = postprocess(response['response'], image=Image.open(image_path))
if image is not None:
def example2():
image_path = os.path.join(os.path.dirname(__file__), 'assets/man.jpg')
text = "What is the person<boxes> scared of?"
boxes_value = [[148, 99, 576, 497]]
boxes_seq = [[0]]
response = query(image_path, text, boxes_value, boxes_seq, server_url)
_, image = postprocess(response['response'], image=Image.open(image_path))
if image is not None: