"""LISA demo server: a FastAPI app with a Gradio UI for reasoning segmentation."""
import argparse
import os
import re
import sys
import logging
from typing import Callable

from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

import cv2
import gradio as gr
import nh3
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.LISA import LISAForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)


CUSTOM_GRADIO_PATH = "/gradio"

app = FastAPI()

# Directory served under /static; it must be provided via the environment.
FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
if not FASTAPI_STATIC:
    # Without this check os.makedirs(None) fails with an opaque TypeError.
    raise RuntimeError(
        "The FASTAPI_STATIC environment variable must be set to the "
        "directory used for static files."
    )
os.makedirs(FASTAPI_STATIC, exist_ok=True)
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
templates = Jinja2Templates(directory="templates")

# Whitelist applied to the (already sanitized) user prompt: ASCII letters,
# spaces and a few punctuation marks. `+` also rejects the empty string.
_VALID_INPUT_RE = re.compile(r"^[A-Za-z ,.!?\'\"]+$")

# Images shown in place of a segmentation result.
ERROR_IMAGE_PATH = "./resources/error_happened.png"
NO_SEG_IMAGE_PATH = "./resources/no_seg_out.png"


def get_cleaned_input(input_str):
    """Sanitize ``input_str`` with nh3, keeping only a small set of safe HTML.

    Allowed tags are basic formatting/list tags plus ``a``/``abbr``/``acronym``
    with the listed attributes; link URLs are limited to http/https/mailto.
    """
    input_str = nh3.clean(
        input_str,
        tags={
            "a",
            "abbr",
            "acronym",
            "b",
            "blockquote",
            "code",
            "em",
            "i",
            "li",
            "ol",
            "strong",
            "ul",
        },
        attributes={
            "a": {"href", "title"},
            "abbr": {"title"},
            "acronym": {"title"},
        },
        url_schemes={"http", "https", "mailto"},
        link_rel=None,
    )
    return input_str


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``home.html`` with a sanitized string form of the request."""
    logging.info("Request raw: %s.", request)
    clean_request = get_cleaned_input(str(request))
    # Log the sanitized value (the original logged the raw request here).
    logging.info("clean_request: %s.", clean_request)
    # Starlette's TemplateResponse requires the request in the context.
    return templates.TemplateResponse(
        "home.html", {"request": request, "clean_request": clean_request}
    )


# Gradio
examples = [
    [
        "Where can the driver see the car speed in this image? Please output segmentation mask.",
        "./resources/imgs/example1.jpg",
    ],
    [
        "Can you segment the food that tastes spicy and hot?",
        "./resources/imgs/example2.jpg",
    ],
    [
        "Assuming you are an autonomous driving robot, what part of the diagram would you manipulate to control the direction of travel? Please output segmentation mask and explain why.",
        "./resources/imgs/example1.jpg",
    ],
    [
        "What can make the woman stand higher? Please output segmentation mask and explain why.",
        "./resources/imgs/example3.jpg",
    ],
]
output_labels = ["Segmentation Output"]

title = "LISA: Reasoning Segmentation via Large Language Model"

description = """ This is the online demo of LISA. \n If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n **Note**: **Different prompts can lead to significantly varied results**. \n **Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n **Note**: Current model is **LISA-13B-llama2-v0-explanatory**, and 4-bit quantization may impair text-generation quality. \n **Usage**:
 (1) To let LISA **segment something**, input prompt like: "Can you segment xxx in this image?", "What is xxx in this image? Please output segmentation mask.";
 (2) To let LISA **output an explanation**, input prompt like: "What is xxx in this image? Please output segmentation mask and explain why.";
 (3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA).
Hope you can enjoy our work!
"""

# NOTE(review): the article holds only link labels, no URLs — confirm the
# intended links were not lost.
article = """

Preprint Paper \n

Github Repo

"""


def parse_args(args_to_parse):
    """Parse the demo's command-line arguments from the list ``args_to_parse``."""
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    # NOTE(review): store_true with default=True means this flag is always
    # True and cannot be switched off from the command line.
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)


def set_image_precision_by_args(input_image, precision):
    """Cast a tensor to bf16/fp16 per ``precision``; any other value gives fp32."""
    if precision == "bf16":
        input_image = input_image.bfloat16()
    elif precision == "fp16":
        input_image = input_image.half()
    else:
        input_image = input_image.float()
    return input_image


def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    ``x`` is expected to be channel-first (C, H, W) with H, W <= ``img_size``;
    padding is added on the bottom/right only.
    """
    # Normalize colors (the default tensors are read-only here, never mutated).
    x = (x - pixel_mean) / pixel_std
    # Pad
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x


def get_model(args_to_parse):
    """Load the LISA model and its preprocessing helpers.

    Side effects: creates ``args_to_parse.vis_save_path`` and stores the
    tokenizer id of ``[SEG]`` in ``args_to_parse.seg_token_idx``.

    Returns:
        ``(model, clip_image_processor, tokenizer, transform)``.
    """
    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)

    # Create model
    _tokenizer = AutoTokenizer.from_pretrained(
        args_to_parse.version,
        cache_dir=None,
        model_max_length=args_to_parse.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    _tokenizer.pad_token = _tokenizer.unk_token
    args_to_parse.seg_token_idx = _tokenizer(
        "[SEG]", add_special_tokens=False
    ).input_ids[0]

    torch_dtype = torch.float32
    if args_to_parse.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args_to_parse.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args_to_parse.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args_to_parse.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    _model = LISAForCausalLM.from_pretrained(
        args_to_parse.version,
        low_cpu_mem_usage=True,
        vision_tower=args_to_parse.vision_tower,
        seg_token_idx=args_to_parse.seg_token_idx,
        **kwargs
    )

    _model.config.eos_token_id = _tokenizer.eos_token_id
    _model.config.bos_token_id = _tokenizer.bos_token_id
    _model.config.pad_token_id = _tokenizer.pad_token_id

    _model.get_model().initialize_vision_modules(_model.get_model().config)
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args_to_parse.precision == "bf16":
        _model = _model.bfloat16().cuda()
    elif (
        args_to_parse.precision == "fp16"
        and (not args_to_parse.load_in_4bit)
        and (not args_to_parse.load_in_8bit)
    ):
        # Detach the vision tower while DeepSpeed wraps the model, then
        # reattach it in fp16 on the GPU (presumably so kernel injection
        # leaves the CLIP tower untouched — confirm).
        vision_tower = _model.get_model().get_vision_tower()
        _model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=_model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        _model = model_engine.module
        _model.model.vision_tower = vision_tower.half().cuda()
    elif args_to_parse.precision == "fp32":
        _model = _model.float().cuda()

    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(device=args_to_parse.local_rank)

    _clip_image_processor = CLIPImageProcessor.from_pretrained(
        _model.config.vision_tower
    )
    _transform = ResizeLongestSide(args_to_parse.image_size)

    _model.eval()
    return _model, _clip_image_processor, _tokenizer, _transform


def _load_placeholder_image(path, fallback_shape=(128, 128, 3)):
    """Return the image at ``path`` in RGB order, or a black image if unreadable."""
    image = cv2.imread(path)
    if image is None:
        logging.warning(
            "Placeholder image %s could not be read; using a blank image.", path
        )
        return np.zeros(fallback_shape, dtype=np.uint8)
    # cv2 loads BGR; reverse the channel axis to get RGB.
    return image[:, :, ::-1]


def get_inference_model_by_args(args_to_parse):
    """Load the model once and return an ``inference(text, image_path)`` closure.

    The closure returns ``(output_image, output_str)`` as expected by the Gradio
    interface built in ``get_gradio_interface``.
    """
    model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)

    def inference(input_str, input_image):
        """Run LISA on a text prompt and an image file path."""
        # Filter out special chars / HTML. Gradio may hand us None for an
        # empty textbox; treat that as an empty (hence invalid) prompt.
        input_str = get_cleaned_input(input_str or "")
        logging.info(
            "input_str type: %s, input_image type: %s.",
            type(input_str),
            type(input_image),
        )
        logging.info("input_str: %s.", input_str)

        # Input validity check.
        if not _VALID_INPUT_RE.match(input_str):
            output_str = "[Error] Invalid input: " + input_str
            return _load_placeholder_image(ERROR_IMAGE_PATH), output_str

        # Gradio passes None when no image was uploaded; cv2.imread returns
        # None for files it cannot decode.
        image_np = cv2.imread(input_image) if input_image else None
        if image_np is None:
            output_str = "[Error] Invalid input image."
            return _load_placeholder_image(ERROR_IMAGE_PATH), output_str
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        # Build the conversation prompt.
        conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
        conv.messages = []
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + input_str
        if args_to_parse.use_mm_start_end:
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        # CLIP branch input.
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        logging.info("image_clip type: %s.", type(image_clip))
        image_clip = set_image_precision_by_args(image_clip, args_to_parse.precision)

        # Segmentation branch input: resize longest side, then normalize and
        # pad to the same size the transform resized to.
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = (
            preprocess(
                torch.from_numpy(image).permute(2, 0, 1).contiguous(),
                img_size=args_to_parse.image_size,
            )
            .unsqueeze(0)
            .cuda()
        )
        logging.info("image type: %s.", type(image))
        image = set_image_precision_by_args(image, args_to_parse.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        # Drop newlines and collapse double spaces.
        text_output = text_output.replace("\n", "").replace("  ", " ")
        text_output = text_output.split("ASSISTANT: ")[-1]
        logging.info(
            "text_output type: %s, text_output: %s.", type(text_output), text_output
        )

        # Blend each predicted mask 50/50 with red. NOTE(review): save_img is
        # re-copied from image_np per mask, so only the last non-empty mask is
        # shown — confirm this is intended.
        save_img = None
        red = np.array([255, 0, 0])
        for pred_mask in pred_masks:
            if pred_mask.shape[0] == 0:
                continue
            mask = pred_mask.detach().cpu().numpy()[0] > 0
            save_img = image_np.copy()
            save_img[mask] = image_np[mask] * 0.5 + red * 0.5

        output_str = "ASSISTANT: " + text_output
        if save_img is not None:
            output_image = save_img
        else:
            # No segmentation output.
            output_image = _load_placeholder_image(NO_SEG_IMAGE_PATH)
        return output_image, output_str

    return inference


def get_gradio_interface(fn_inference: Callable):
    """Wrap ``fn_inference`` in a Gradio Interface (text + image in, image + text out)."""
    return gr.Interface(
        fn_inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image"),
        ],
        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
            gr.Textbox(lines=1, placeholder=None, label="Text Output"),
        ],
        title=title,
        description=description,
        article=article,
        examples=examples,
        allow_flagging="auto",
    )


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    inference_fn = get_inference_model_by_args(args)
    io = get_gradio_interface(inference_fn)
    app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
    # NOTE(review): nothing here starts an ASGI server, and when the module is
    # served via an external runner this branch does not execute — confirm how
    # the mounted `app` is meant to be served.