Spaces:
Paused
Paused
alessandro trinca tornidor
[fix] rename app.py to main.py, use empty sys.argv to keep only defaults args
72ceb76
import argparse | |
import json | |
import logging | |
import os | |
import sys | |
from typing import Callable | |
import gradio as gr | |
import nh3 | |
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from utils import constants, session_logger | |
session_logger.change_logging(logging.DEBUG)

# Mount point of the Gradio UI on the FastAPI app.
CUSTOM_GRADIO_PATH = "/"
app = FastAPI(title="lisa_app", version="1.0")

# Directory served under /static; it must be provided through the environment.
FASTAPI_STATIC = os.getenv("FASTAPI_STATIC")
if not FASTAPI_STATIC:
    # Without this guard an unset variable makes os.makedirs(None) fail with an
    # opaque TypeError (and an empty one with FileNotFoundError) at import time.
    raise RuntimeError(
        "environment variable FASTAPI_STATIC must be set to the directory "
        "to serve under /static"
    )
os.makedirs(FASTAPI_STATIC, exist_ok=True)
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
# Template directory is resolved relative to the current working directory.
templates = Jinja2Templates(directory="templates")
def health() -> str:
    """Report service liveness as a JSON-encoded string.

    Returns:
        ``'{"msg": "ok"}'`` normally, or ``'{"msg": "request failed"}'`` if
        logging or serialization raises (the exception is logged).
    """
    status = "ok"
    try:
        logging.info("health check")
        return json.dumps({"msg": status})
    except Exception as exc:
        logging.error(f"exception:{exc}.")
        return json.dumps({"msg": "request failed"})
def parse_args(args_to_parse):
    """Parse the LISA chat command-line options.

    Args:
        args_to_parse: list of argument strings, forwarded unchanged to
            ``ArgumentParser.parse_args`` (pass ``[]`` to get only defaults).

    Returns:
        argparse.Namespace with one attribute per option. Options spelled with
        dashes (``--vision-tower``, ``--local-rank``) are exposed with
        underscores (``vision_tower``, ``local_rank``).
    """
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    # --use_mm_start_end defaults to True, so as a lone store_true flag it could
    # never be switched off; --no_use_mm_start_end writes False to the same dest.
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--no_use_mm_start_end",
        dest="use_mm_start_end",
        action="store_false",
        help="disable use_mm_start_end (enabled by default)",
    )
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)
def get_cleaned_input(input_str):
    """Sanitize HTML in *input_str* with ``nh3.clean`` and return the result.

    Only a small allow-list survives: inline/list tags, ``href``/``title`` on
    links and ``title`` on abbreviations, and http/https/mailto URL schemes.
    ``link_rel=None`` is passed so nh3 does not inject a ``rel`` attribute.

    NOTE(review): the raw, untrusted input is logged at INFO level.
    """
    logging.info(f"start cleaning of input_str: {input_str}.")
    allowed_tags = {
        "a", "abbr", "acronym", "b", "blockquote", "code",
        "em", "i", "li", "ol", "strong", "ul",
    }
    allowed_attributes = {
        "a": {"href", "title"},
        "abbr": {"title"},
        "acronym": {"title"},
    }
    cleaned = nh3.clean(
        input_str,
        tags=allowed_tags,
        attributes=allowed_attributes,
        url_schemes={"http", "https", "mailto"},
        link_rel=None,
    )
    logging.info(f"cleaned input_str: {cleaned}.")
    return cleaned
def get_inference_model_by_args(args_to_parse):
    """Build the inference callable handed to the Gradio interface.

    Args:
        args_to_parse: parsed options; currently only logged.

    Returns:
        ``inference(input_str, input_image)`` which sanitizes the text with
        ``get_cleaned_input`` and returns ``(input_image, cleaned_text)``, i.e.
        the image is echoed back unchanged.
    """
    logging.info(f"args_to_parse:{args_to_parse}.")

    def inference(input_str, input_image):
        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
        cleaned_text = get_cleaned_input(input_str)
        logging.info(f"cleaned output_str: {cleaned_text}, type {type(cleaned_text)}.")
        # Placeholder: no model runs yet, the input image is passed through.
        logging.info(f"output_image type: {type(input_image)}.")
        return input_image, cleaned_text

    return inference
def get_gradio_interface(fn_inference: Callable):
    """Wrap *fn_inference* in a ``gr.Interface``.

    Gradio calls ``fn_inference(text, image_filepath)`` and expects it to
    return ``(image, text)``, matching the component order below. Title,
    description, article and examples come from ``utils.constants``.

    NOTE(review): ``allow_flagging="auto"`` - confirm that flagging every
    submission is intended; newer Gradio releases rename this argument.
    """
    instruction_box = gr.Textbox(lines=1, placeholder=None, label="Text Instruction")
    image_upload = gr.Image(type="filepath", label="Input Image")
    segmentation_view = gr.Image(type="pil", label="Segmentation Output")
    answer_box = gr.Textbox(lines=1, placeholder=None, label="Text Output")
    return gr.Interface(
        fn_inference,
        inputs=[instruction_box, image_upload],
        outputs=[segmentation_view, answer_box],
        title=constants.title,
        description=constants.description,
        article=constants.article,
        examples=constants.examples,
        allow_flagging="auto",
    )
# Module-level wiring, executed on import; the resulting `app` is presumably
# what the ASGI server loads (e.g. `main:app`) - confirm against the deploy config.
logging.info(f"sys.argv:{sys.argv}.")
# Parse an empty list on purpose so only the parser defaults are used; the
# process's real sys.argv (logged above) is not parsed.
args = parse_args([])
inference_fn = get_inference_model_by_args(args)
# NOTE(review): `io` shadows the stdlib module name `io` at module scope.
io = get_gradio_interface(inference_fn)
# Mount the Gradio UI on the FastAPI app at CUSTOM_GRADIO_PATH and rebind the
# module-level `app` to the returned application.
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)