import argparse
import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava

root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)

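# CLI options: model checkpoint, target device, conversation template, and
# generation/quantization settings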
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# load model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path, 
    None, 
    'llava_llama3', 
    args.load_8bit, 
    args.load_4bit, 
    device=args.device
)
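# load_pretrained_model returns the tokenizer, the LLaVA model, the image
# preprocessor, and the model's maximum context length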

@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image_path = None
    
    # Check if there's an image in the current message
    if message["files"]:
        # message["files"][-1] could be a dictionary or a string
        if isinstance(message["files"][-1], dict):
            image_path = message["files"][-1]["path"]
        else:
            image_path = message["files"][-1]
    else:
        # If no image in the current message, look in the history for the last image path
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]
    
    # Error handling if no image path is found
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")
    
    # chat_llava accepts a file path directly, so there is no need to load
    # the image into a PIL Image here
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")  # debug: resolved image path
    
    # The user's typed text is the prompt; TextIteratorStreamer lets us iterate
    # over decoded output chunks as generation proceeds
    prompt = message['text']
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Set up the generation arguments, including the streamer
    generation_kwargs = dict(
        args=args,
        image_file=image_path,
        text=prompt,
        tokenizer=tokenizer,
        model=llava_model,
        streamer=streamer,
        image_processor=image_processor,  # todo: input model name or path
        context_len=context_len,
    )
    
    # Run chat_llava in a background thread so the streamer can be consumed
    # below while generation is still in progress
    thread = Thread(target=chat_llava, kwargs=generation_kwargs)
    thread.start()
    
    # Initialize a buffer to accumulate the generated text
    buffer = ""
    
    # Allow the generation to start
    time.sleep(0.5)
    
    # Iterate over the streamer to handle the incoming text in chunks
    for new_text in streamer:
        # Look for the end of text token and remove it
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        
        # Append the new chunk; with skip_prompt=True the prompt is never
        # streamed, so the buffer holds only generated text
        buffer += new_text

        # Small delay to smooth the streaming display (optional)
        time.sleep(0.06)

        # Yield the accumulated text so Gradio updates the chat window
        yield buffer

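# Build the Gradio UI: a multimodal chat interface that streams bot_streaming's
# output into the chat window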
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
                  {"text": "How to make this pastry?", "files": ["./baklava.png"]},
                  {"text":"What is this?","files":["http://images.cocodataset.org/val2017/000000039769.jpg"]}],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
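
# Example launch command (assuming this file is saved as app.py; the flags
# mirror the argparse options above):
#   python app.py --model-path TheFinAI/FinLLaVA --device cuda:0 --load-4bit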