File size: 3,605 Bytes
ea37c27
 
ca317b2
 
ea37c27
ca317b2
ea37c27
8c54553
ed5a7bf
ca30e4f
 
 
 
ee668ff
a7191f1
 
 
 
 
 
 
 
 
ca317b2
a7191f1
ca317b2
a7191f1
ca317b2
 
a7191f1
 
 
ee668ff
386e329
ea37c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec0b15
 
 
ca30e4f
 
cec0b15
 
 
1d234d5
cec0b15
 
 
 
 
 
 
 
 
 
 
 
 
ea37c27
 
 
cec0b15
 
 
 
 
 
 
 
 
 
ea37c27
 
 
 
 
 
 
 
 
f07fb5d
ea37c27
 
 
 
 
 
 
ee668ff
 
 
ea37c27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import time
from threading import Thread
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image
import argparse
import spaces
import os

# Absolute directory containing this script, echoed once at start-up.
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)


def _build_arg_parser():
    """Return the command-line parser for model-loading and generation options."""
    cli = argparse.ArgumentParser()
    # Which checkpoint to load and which device to place it on.
    cli.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
    cli.add_argument("--device", type=str, default="cuda:0")
    # Conversation/decoding settings; the parsed namespace is later passed to
    # chat_llava as ``args`` (how each field is used there is not visible here).
    cli.add_argument("--conv-mode", type=str, default="llama_3")
    cli.add_argument("--temperature", type=float, default=0)
    cli.add_argument("--max-new-tokens", type=int, default=512)
    # Optional quantized loading switches, forwarded to load_pretrained_model.
    for flag in ("--load-8bit", "--load-4bit"):
        cli.add_argument(flag, action="store_true")
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Load the tokenizer, model and image processor once at import time; every chat
# request in bot_streaming reuses these module-level objects.
# NOTE(review): positional order is model path, None, 'llava_llama3', 8-bit flag,
# 4-bit flag -- presumably (model_path, model_base, model_name, ...); confirm
# against load_pretrained_model's signature.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)

@spaces.GPU
def bot_streaming(message, history):
    """Stream FinLLaVA's reply to a multimodal chat message.

    Used as ``fn`` for the ``gr.ChatInterface`` built further down.

    Args:
        message: Dict with a ``"text"`` prompt and a ``"files"`` list. Each
            file entry is either a dict carrying a ``"path"`` key or a path
            string; the last entry is used as the image.
        history: Earlier chat turns. A turn whose first element is a tuple is
            treated as an upload, and that tuple's first item as its path.
            When the current message has no files, the most recent such turn
            supplies the image.

    Yields:
        The response text accumulated so far, extended by each item obtained
        by iterating the output of ``chat_llava``.

    Raises:
        gr.Error: If no image is found in the message or the history, or if
            ``chat_llava`` (or iterating its output) raises on the worker
            thread.
    """
    import queue

    print(message)

    image = None
    files = message["files"]
    if files:
        latest = files[-1]
        # The entry may be a {"path": ...} dict or a bare path string.
        image = latest["path"] if isinstance(latest, dict) else latest
    else:
        # No upload in this turn: fall back to the newest image in history.
        for hist in reversed(history):
            if isinstance(hist[0], tuple):
                image = hist[0][0]
                break

    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    image = Image.open(image)
    prompt = message['text']

    # Generation runs on a worker thread and hands its output back through a
    # thread-safe queue, so this generator blocks on it instead of polling.
    # ``done`` marks the end of output; an exception is recorded so it can be
    # surfaced to the user here rather than only printed by the thread.
    chunks = queue.Queue()
    done = object()
    failures = []

    def generate_output():
        try:
            # NOTE(review): the opened PIL image (not its path) is passed as
            # ``image_file``, exactly as before; confirm chat_llava accepts it.
            output = chat_llava(
                args=args,
                image_file=image,
                text=prompt,
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
            )
            for new_text in output:
                chunks.put(new_text)
        except Exception as err:
            failures.append(err)
        finally:
            # Always signal completion so the consumer loop below terminates.
            chunks.put(done)

    thread = Thread(target=generate_output)
    thread.start()

    buffer = ""
    while True:
        new_text = chunks.get()
        if new_text is done:
            break
        buffer += new_text
        yield buffer

    thread.join()
    if failures:
        raise gr.Error(f"Generation failed: {failures[0]}") from failures[0]


# Widgets are created up front and handed to ChatInterface so their options
# (scale, accepted file types, placeholder) can be customized.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        # NOTE(review): these example paths are relative, so they presumably
        # resolve against the current working directory rather than root_path.
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Enable the request queue with api_open=False, then launch with the API page
# hidden and no public share link.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)