# lisa-on-cuda / main.py
import argparse
import json
import logging
import os
import sys
from typing import Callable

import gradio as gr
import nh3
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from utils import constants, session_logger
session_logger.change_logging(logging.DEBUG)

CUSTOM_GRADIO_PATH = "/"
app = FastAPI(title="lisa_app", version="1.0")

# Static assets live in the directory named by the FASTAPI_STATIC env var;
# fall back to "static" so os.makedirs() never receives None when it is unset.
FASTAPI_STATIC = os.getenv("FASTAPI_STATIC", "static")
os.makedirs(FASTAPI_STATIC, exist_ok=True)
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")

templates = Jinja2Templates(directory="templates")
@app.get("/health")
@session_logger.set_uuid_logging
def health() -> str:
    """Health-check endpoint: log the probe and report status as a JSON string."""
    try:
        logging.info("health check")
        return json.dumps({"msg": "ok"})
    except Exception as e:
        logging.error(f"exception:{e}.")
        return json.dumps({"msg": "request failed"})
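
# Probe sketch: a GET on /health returns the json.dumps string above, which
# FastAPI JSON-encodes once more, so the raw response body is "{\"msg\": \"ok\"}".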
@session_logger.set_uuid_logging
def parse_args(args_to_parse):
    """Parse LISA chat CLI arguments; pass an empty list to keep defaults only."""
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)
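
# Example (illustrative): parse_args([]) yields pure defaults, e.g.
# args.precision == "fp16" and args.load_in_8bit is False; overrides can be
# passed CLI-style, e.g. parse_args(["--precision", "bf16"]).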
@session_logger.set_uuid_logging
def get_cleaned_input(input_str):
    """Sanitise user-provided HTML with nh3, keeping only a small allowlist."""
    logging.info(f"start cleaning of input_str: {input_str}.")
    input_str = nh3.clean(
        input_str,
        tags={
            "a",
            "abbr",
            "acronym",
            "b",
            "blockquote",
            "code",
            "em",
            "i",
            "li",
            "ol",
            "strong",
            "ul",
        },
        attributes={
            "a": {"href", "title"},
            "abbr": {"title"},
            "acronym": {"title"},
        },
        url_schemes={"http", "https", "mailto"},
        link_rel=None,
    )
    logging.info(f"cleaned input_str: {input_str}.")
    return input_str
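
# Example (illustrative, not from the repo): nh3 drops disallowed markup and
# the content of dangerous tags, so
#   get_cleaned_input('<script>alert(1)</script><b>hi</b>')
# returns '<b>hi</b>' because "b" is on the allowlist and "script" is not.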
@session_logger.set_uuid_logging
def get_inference_model_by_args(args_to_parse):
    """Build the inference callable from parsed args (currently an echo stub)."""
    logging.info(f"args_to_parse:{args_to_parse}.")

    @session_logger.set_uuid_logging
    def inference(input_str, input_image):
        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
        output_str = get_cleaned_input(input_str)
        logging.info(f"cleaned output_str: {output_str}, type {type(output_str)}.")
        output_image = input_image
        logging.info(f"output_image type: {type(output_image)}.")
        return output_image, output_str

    return inference
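
# NOTE: `inference` is an echo stub: it sanitises the text and returns the
# input image unchanged. In a full deployment the LISA model selected by the
# parsed args (version, precision, ...) would presumably run here instead.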
@session_logger.set_uuid_logging
def get_gradio_interface(fn_inference: Callable):
    """Wrap the inference callable in a Gradio Interface (image + text I/O)."""
    return gr.Interface(
        fn_inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            gr.Image(type="filepath", label="Input Image"),
        ],
        outputs=[
            gr.Image(type="pil", label="Segmentation Output"),
            gr.Textbox(lines=1, placeholder=None, label="Text Output"),
        ],
        title=constants.title,
        description=constants.description,
        article=constants.article,
        examples=constants.examples,
        allow_flagging="auto",
    )
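
# For quick local debugging the interface could also be launched standalone
# with io.launch(); here it is mounted onto the FastAPI app below instead.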
logging.info(f"sys.argv:{sys.argv}.")
args = parse_args([])
inference_fn = get_inference_model_by_args(args)
io = get_gradio_interface(inference_fn)
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
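
# Usage sketch (assumes a standard ASGI setup; the module path "main:app"
# matches the rename noted in the commit message):
#
#   FASTAPI_STATIC=./static uvicorn main:app --host 0.0.0.0 --port 7860
#
# The Gradio UI is then served at "/" and the health check at "/health".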