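# FinLLaVA Gradio demo: streams responses from a fine-tuned LLaVA-LLaMA-3
# checkpoint for image + text chat. Example launch (the script name here is
# a placeholder, not from the repo):
#   python app.py --model-path /path/to/checkpoint --load-4bit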
import argparse
import os
import time
from threading import Thread

import gradio as gr
from transformers import TextIteratorStreamer

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava

# Keep Gradio's temporary upload directory next to this script so uploaded
# image paths stay local and predictable.
root_path = os.path.dirname(os.path.abspath(__file__))
print(f'\033[92m{root_path}\033[0m')  # print the root path in green
os.environ['GRADIO_TEMP_DIR'] = root_path

# Command-line options: checkpoint location, device, prompt template, and
# generation/quantization settings.
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="/mnt/nvme1n1/toby/LLaVA/checkpoints/0806_onlyllava_llava-finma-8B-v0.4-v8/checkpoint-2000")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load the tokenizer, model, and image processor once at startup.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device)
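# Note: --load-8bit / --load-4bit are assumed to route through bitsandbytes
# quantization inside load_pretrained_model (as in upstream LLaVA), trading
# some output quality for a much smaller GPU memory footprint.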

def bot_streaming(message, history):
    """Stream a LLaVA response for one multimodal chat turn."""
    print(message)
    # Prefer an image attached to the current message; otherwise fall back to
    # the most recent image in the conversation history.
    image_file = None
    if message["files"]:
        if isinstance(message["files"][-1], dict):
            image_file = message["files"][-1]["path"]
        else:
            image_file = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image_file = hist[0][0]

    if image_file is None:
        # gr.Error must be raised (not just constructed) to show in the UI.
        raise gr.Error("You need to upload an image for LLaVA to work.")
    
    # TextIteratorStreamer exposes generated tokens as an iterator, so the UI
    # can render partial output while generation is still running.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        print('\033[92mRunning chat\033[0m')
        return chat_llava(
            args=args,
            image_file=image_file,
            text=message['text'],
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
            streamer=streamer)

    # Run generation on a background thread; the loop below yields partial
    # text to Gradio as tokens arrive.
    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.06)  # small pause smooths the visible streaming
        yield buffer

# Chat UI: a multimodal textbox that accepts both text and image uploads.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA Demo",
        examples=[
            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

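# queue() serializes concurrent requests; api_open=False and show_api=False
# keep the auto-generated API endpoints hidden, and share=False avoids
# creating a public Gradio link.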
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)