2111# client 端 api 调用案例--------------------------------- import os import re import base64 from io import BytesIO from typing import Union import torch import requests from PIL import Image from torchvision.transforms import ToPILImage, PILToTensor from torchvision.utils import draw_bounding_boxes as _draw_bounding_boxes ######################################## # helper ######################################## def pil_to_base64(pil_img): output_buffer = BytesIO() pil_img.save(output_buffer, format="PNG") byte_data = output_buffer.getvalue() encode_img = base64.b64encode(byte_data) return str(encode_img, encoding='utf-8') def de_norm_box_xyxy(box, *, w, h): x1, y1, x2, y2 = box x1 = x1 * w x2 = x2 * w y1 = y1 * h y2 = y2 * h box = x1, y1, x2, y2 return box def draw_bounding_boxes( image, boxes, **kwargs, ): if isinstance(image, Image.Image): image = PILToTensor()(image) assert isinstance(image, torch.Tensor), "" if not isinstance(boxes, torch.Tensor): boxes = torch.as_tensor(boxes) assert isinstance(boxes, torch.Tensor) return _draw_bounding_boxes(image, boxes, **kwargs) def expand2square(pil_img, background_color=(255, 255, 255)): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result ######################################## # ######################################## def query(image: Union[Image.Image, str], text: str, boxes_value: list, boxes_seq: list, server_url='http://127.0.0.1:12345/shikra'): if isinstance(image, str): image = Image.open(image) pload = { "img_base64": pil_to_base64(image), "text": text, "boxes_value": boxes_value, "boxes_seq": boxes_seq, } resp = requests.post(server_url, json=pload) if resp.status_code != 200: raise ValueError(resp.reason) ret = resp.json() return ret def postprocess(text, image): if image is None: return text, None image = expand2square(image) colors = ['#ed7d31', '#5b9bd5', '#70ad47', '#7030a0', '#c00000', '#ffff00', "olive", "brown", "cyan"] pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]') def extract_boxes(string): ret = [] for bboxes_str in pat.findall(string): bboxes = [] bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") for bbox_str in bbox_strs: bbox = list(map(float, bbox_str.split(','))) bboxes.append(bbox) ret.append(bboxes) return ret extract_pred = extract_boxes(text) boxes_to_draw = [] color_to_draw = [] for idx, boxes in enumerate(extract_pred): color = colors[idx % len(colors)] for box in boxes: boxes_to_draw.append(de_norm_box_xyxy(box, w=image.width, h=image.height)) color_to_draw.append(color) if not boxes_to_draw: return text, None res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8) res = ToPILImage()(res) # post process text color location_text = text edit_text = list(text) bboxes_str = pat.findall(text) for idx in range(len(bboxes_str) - 1, -1, -1): color = colors[idx % len(colors)] boxes = bboxes_str[idx] span = location_text.rfind(boxes), location_text.rfind(boxes) + len(boxes) location_text = location_text[:span[0]] edit_text[span[0]:span[1]] = f'{boxes}' text = "".join(edit_text) return text, res if __name__ == '__main__': server_url = 'http://127.0.0.1:12345' + "/shikra" def example1(): image_path = os.path.join(os.path.dirname(__file__), 'assets/rec_bear.png') text = 'Can you point out a brown teddy bear with a blue bow in the image and provide the coordinates of its location?' boxes_value = [] boxes_seq = [] response = query(image_path, text, boxes_value, boxes_seq, server_url) print(response) _, image = postprocess(response['response'], image=Image.open(image_path)) print(_) if image is not None: image.show() def example2(): image_path = os.path.join(os.path.dirname(__file__), 'assets/man.jpg') text = "What is the person scared of?" boxes_value = [[148, 99, 576, 497]] boxes_seq = [[0]] response = query(image_path, text, boxes_value, boxes_seq, server_url) print(response) _, image = postprocess(response['response'], image=Image.open(image_path)) print(_) if image is not None: image.show() example1() example2()