File size: 4,363 Bytes
a84a5a1
326115c
 
a84a5a1
 
719ecfd
 
 
 
 
a84a5a1
 
e5c9ee0
719ecfd
acec8bf
326115c
719ecfd
326115c
c5fe4a2
 
3bd20e4
a84a5a1
 
 
 
 
3bd20e4
c5fe4a2
 
 
 
 
 
 
 
 
3bd20e4
 
719ecfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5fe4a2
a84a5a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719ecfd
 
 
 
 
a84a5a1
 
 
 
 
 
 
 
 
719ecfd
a84a5a1
 
719ecfd
 
a84a5a1
719ecfd
 
 
 
 
 
 
 
a84a5a1
 
72ceb76
 
a84a5a1
 
f623930
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import json
import logging
import os
import sys
from typing import Callable

import gradio as gr
import nh3
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from utils import constants, session_logger


# Configure per-session UUID logging for the whole process at DEBUG level.
session_logger.change_logging(logging.DEBUG)

# The Gradio UI is mounted at the site root; static assets live under /static.
CUSTOM_GRADIO_PATH = "/"
app = FastAPI(title="lisa_app", version="1.0")

# Fall back to a local "static" directory so import does not crash with
# TypeError in os.makedirs(None) when the FASTAPI_STATIC env var is unset.
FASTAPI_STATIC = os.getenv("FASTAPI_STATIC", "static")
os.makedirs(FASTAPI_STATIC, exist_ok=True)
app.mount("/static", StaticFiles(directory=FASTAPI_STATIC), name="static")
templates = Jinja2Templates(directory="templates")


@app.get("/health")
@session_logger.set_uuid_logging
def health() -> str:
    """Liveness probe: return a JSON-encoded status string.

    NOTE(review): returning json.dumps(...) from a FastAPI handler gets
    JSON-encoded again by the framework, so clients receive a quoted
    string rather than an object — confirm this is intentional.
    """
    try:
        logging.info("health check")
    except Exception as e:
        # Extremely defensive: only a failure in logging itself lands here.
        logging.error(f"exception:{e}.")
        return json.dumps({"msg": "request failed"})
    return json.dumps({"msg": "ok"})


@session_logger.set_uuid_logging
def parse_args(args_to_parse):
    """Build the LISA chat CLI parser and parse *args_to_parse*.

    Returns the resulting argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description="LISA chat")

    # Model checkpoint and output location.
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
    parser.add_argument("--vis_save_path", type=str, default="./vis_output")
    parser.add_argument(
        "--precision",
        type=str,
        default="fp16",
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )

    # Simple integer knobs, declared data-driven to cut repetition.
    for flag, default, help_text in (
        ("--image_size", 1024, "image size"),
        ("--model_max_length", 512, None),
        ("--lora_r", 8, None),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)

    parser.add_argument(
        "--vision-tower", type=str, default="openai/clip-vit-large-patch14"
    )
    parser.add_argument("--local-rank", type=int, default=0, help="node rank")

    # Quantization toggles.
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the command line — confirm that is intended.
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)

    parser.add_argument(
        "--conv_type",
        type=str,
        default="llava_v1",
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args_to_parse)


@session_logger.set_uuid_logging
def get_cleaned_input(input_str):
    """Sanitize *input_str* with nh3, keeping only a small HTML allowlist.

    Permits basic formatting tags, a handful of link/title attributes, and
    http/https/mailto URL schemes; everything else is stripped.
    """
    logging.info(f"start cleaning of input_str: {input_str}.")
    allowed_tags = {
        "a", "abbr", "acronym", "b", "blockquote", "code",
        "em", "i", "li", "ol", "strong", "ul",
    }
    allowed_attributes = {
        "a": {"href", "title"},
        "abbr": {"title"},
        "acronym": {"title"},
    }
    cleaned = nh3.clean(
        input_str,
        tags=allowed_tags,
        attributes=allowed_attributes,
        url_schemes={"http", "https", "mailto"},
        link_rel=None,
    )
    logging.info(f"cleaned input_str: {cleaned}.")
    return cleaned


@session_logger.set_uuid_logging
def get_inference_model_by_args(args_to_parse):
    """Return a placeholder inference callable closed over *args_to_parse*.

    The returned function sanitizes the text prompt and echoes the input
    image back unchanged — model inference is not wired in here.
    """
    logging.info(f"args_to_parse:{args_to_parse}.")

    @session_logger.set_uuid_logging
    def inference(input_str, input_image):
        logging.info(f"start cleaning input_str: {input_str}, type {type(input_str)}.")
        cleaned = get_cleaned_input(input_str)
        logging.info(f"cleaned output_str: {cleaned}, type {type(cleaned)}.")
        logging.info(f"output_image type: {type(input_image)}.")
        # Image passes through untouched; text comes back sanitized.
        return input_image, cleaned

    return inference


@session_logger.set_uuid_logging
def get_gradio_interface(fn_inference: Callable):
    """Wrap *fn_inference* in a Gradio Interface (text + image in, image + text out)."""
    input_widgets = [
        gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        gr.Image(type="filepath", label="Input Image"),
    ]
    output_widgets = [
        gr.Image(type="pil", label="Segmentation Output"),
        gr.Textbox(lines=1, placeholder=None, label="Text Output"),
    ]
    return gr.Interface(
        fn_inference,
        inputs=input_widgets,
        outputs=output_widgets,
        title=constants.title,
        description=constants.description,
        article=constants.article,
        examples=constants.examples,
        allow_flagging="auto",
    )


# Module-level bootstrap: build the placeholder inference function and mount
# the Gradio UI onto the FastAPI app at the root path.
logging.info(f"sys.argv:{sys.argv}.")
# NOTE(review): sys.argv is logged above but parse_args is called with an
# empty list, so command-line arguments are ignored and defaults always
# apply — confirm this is intentional (e.g. server owns argv).
args = parse_args([])
inference_fn = get_inference_model_by_args(args)
io = get_gradio_interface(inference_fn)
app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)