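# Gradio chat demo for TheFinAI/FinLLaVA (a LLaVA-style multimodal model on a
# LLaMA-3 backbone), written to run as a Hugging Face Space on ZeroGPU hardware.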
import argparse
import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
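# CLI flags: model path, device, conversation template, and generation settings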
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()
# Load the tokenizer, model, and image processor once at startup; they are
# shared by every request handled by bot_streaming below
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
args.model_path,
None,
'llava_llama3',
args.load_8bit,
args.load_4bit,
device=args.device)
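# spaces.GPU requests a ZeroGPU device for the duration of each call, so the
# model can use CUDA even though the Space holds no GPU while idle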
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image = None

    # Prefer an image attached to the current message
    if message["files"]:
        # message["files"][-1] can be a dict (with a "path" key) or a plain path string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        # Otherwise fall back to the most recent image in the chat history
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    # The model needs an image; fail loudly if none was found
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # `image` is a file path at this point; open it so the path stays
    # available via the PIL Image's .filename attribute
    image = Image.open(image)

    prompt = message["text"]

    # Plain list shared between threads as a minimal streaming buffer
    streamer = []
    image_file = image if isinstance(image, str) else image.filename

    # Run chat_llava in a worker thread and push its output into the buffer
    def generate_output():
        output = chat_llava(
            args=args,
            image_file=image_file,  # pass the file path, not the PIL image
            text=prompt,
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
        )
        for new_text in output:
            streamer.append(new_text)

    thread = Thread(target=generate_output)
    thread.start()

    # Drain the buffer while the worker runs, yielding the growing reply
    buffer = ""
    while thread.is_alive() or streamer:
        while streamer:
            buffer += streamer.pop(0)
            yield buffer
        time.sleep(0.1)

    # Flush anything that arrived after the last liveness check
    while streamer:
        buffer += streamer.pop(0)
        yield buffer
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
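# A rough local-run sketch, assuming this file is saved as app.py, a CUDA
# device is available, and the llava_llama3 package is installed; the flags
# map to the argparse options defined above:
#
#   python app.py --model-path TheFinAI/FinLLaVA --device cuda:0 --load-4bit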