# FinLLaVA / app.py
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image
import spaces
# Model configuration
model_id = "TheFinAI/FinLLaVA"
device = "cuda:0"
load_8bit = False
load_4bit = False
# Load the pretrained model (full precision, since both quantization flags are False)
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_id,
    None,             # model_base: no separate base checkpoint
    'llava_llama3',   # model_name
    load_8bit,
    load_4bit,
    device=device
)
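
# @spaces.GPU asks Hugging Face Spaces to attach a GPU for the duration of each
# call to the decorated function when running on ZeroGPU hardware.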
@spaces.GPU
def bot_streaming(message, history):
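    # Gradio's multimodal ChatInterface passes `message` as a dict:
    #   {"text": "<user text>", "files": [<filepath or {"path": ...}>, ...]}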
print(message)
image = None
# Check if there's an image in the current message
if message["files"]:
# message["files"][-1] could be a dictionary or a string
if isinstance(message["files"][-1], dict):
image = message["files"][-1]["path"]
else:
image = message["files"][-1]
else:
        # No image in the current message: scan the history and keep the most recent one
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
# Error handling if no image is found
if image is None:
raise gr.Error("You need to upload an image for LLaVA to work.")
    # Open the image from the resolved file path
    image = Image.open(image)

    # The prompt is simply the user's text from the multimodal message
    prompt = message['text']
# Call the chat_llava function to generate the output
output = chat_llava(
args=None,
image_file=image,
text=prompt,
tokenizer=tokenizer,
model=llava_model,
image_processor=image_processor,
context_len=context_len
)
    # Stream the output: yield a growing buffer so the UI renders the reply
    # incrementally (character by character if chat_llava returns a plain string)
    buffer = ""
    for new_text in output:
        buffer += new_text
        yield buffer
# UI components: a chat pane plus a multimodal textbox that accepts image uploads
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
gr.ChatInterface(
fn=bot_streaming,
        title="FinLLaVA (LLaVA Llama-3-8B)",
examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
{"text": "How to make this pastry?", "files": ["./baklava.png"]}],
stop_btn="Stop Generation",
multimodal=True,
textbox=chat_input,
chatbot=chatbot,
)
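
    # Note: the example prompts reference ./bee.jpg and ./baklava.png, which must
    # exist in the Space's root directory for the example buttons to work.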
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
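
# A minimal sketch of how bot_streaming could be exercised outside the UI (the
# image path ./example.png is a placeholder assumption; history is empty, so the
# image must travel with the message itself):
#
#   for partial in bot_streaming({"text": "Describe this chart.",
#                                 "files": ["./example.png"]}, history=[]):
#       print(partial)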