import argparse
import copy
import os
import time
from threading import Thread

import gradio as gr
from transformers import TextIteratorStreamer

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava
# Point Gradio's temp directory at the app root so uploaded files land beside the script
root_path = os.path.dirname(os.path.abspath(__file__))
print(f'\033[92m{root_path}\033[0m')
os.environ['GRADIO_TEMP_DIR'] = root_path
# Default model/runtime arguments, kept in a Namespace so per-request copies are cheap
default_args = argparse.Namespace(
    model_path="TheFinAI/FinLLaVA",
    device="cuda",
    conv_mode="llama_3",
    temperature=0.7,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
)
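
# A minimal sketch (hypothetical; this Space never parses CLI flags itself) of
# exposing the same defaults as real command-line options instead of a
# hard-coded Namespace:
def _parse_cli_args():
    parser = argparse.ArgumentParser(description="FinLLaVA Gradio demo")
    parser.add_argument("--model-path", default=default_args.model_path)
    parser.add_argument("--device", default=default_args.device)
    parser.add_argument("--conv-mode", default=default_args.conv_mode)
    parser.add_argument("--temperature", type=float, default=default_args.temperature)
    parser.add_argument("--max-new-tokens", type=int, default=default_args.max_new_tokens)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    return parser.parse_args()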
# Load the model once at startup
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    default_args.model_path,
    None,              # model_base: no separate base checkpoint
    'llava_llama3',    # model name used by the builder
    default_args.load_8bit,
    default_args.load_4bit,
    device=default_args.device
)
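
# Note: this demo runs as a "ZeroGPU" Space. On ZeroGPU, GPU-bound callables
# are normally wrapped with the `spaces` package, e.g. (a sketch assuming the
# standard ZeroGPU setup, not something this script does itself):
#
#   import spaces
#
#   @spaces.GPU
#   def bot_streaming(...):
#       ...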
def bot_streaming(message, history, temperature, max_new_tokens):
    # Prefer an image attached to the current message; otherwise fall back to
    # the most recent image found in the chat history.
    image_file = None
    if message["files"]:
        if isinstance(message["files"][-1], dict):
            image_file = message["files"][-1]["path"]
        else:
            image_file = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image_file = hist[0][0]

    if image_file is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Per-request copy of the defaults so slider values don't mutate global state
    args = copy.deepcopy(default_args)
    args.temperature = temperature
    args.max_new_tokens = max_new_tokens

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        print('\033[92mRunning chat\033[0m')
        return chat_llava(
            args=args,
            image_file=image_file,
            text=message['text'],
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
            streamer=streamer
        )

    # Run generation on a background thread and stream partial output to the UI
    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.06)
        yield buffer
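
# A minimal, optional sketch of driving bot_streaming outside the UI. The
# message dict mirrors what gr.MultimodalTextbox produces ("text" plus a
# "files" list); the image path below is a hypothetical placeholder.
def _debug_stream_once(image_path="example.jpg", prompt="What is in this picture?"):
    message = {"text": prompt, "files": [image_path]}
    for partial in bot_streaming(
        message,
        history=[],
        temperature=default_args.temperature,
        max_new_tokens=default_args.max_new_tokens,
    ):
        print(partial)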
# CSS styles for the page
css = """
body {
    font-family: Arial, sans-serif;
}
.gradio-container {
    max-width: 800px;
    margin: auto;
}
.chatbot {
    height: 400px;
    overflow-y: auto;
}
"""
# Build the interface using gr.Blocks
with gr.Blocks(css=css) as demo:
    gr.Markdown("# FinLLaVA Demo")
    chatbot = gr.Chatbot(scale=1)
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False
    )
    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=default_args.temperature
        )
        max_new_tokens = gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=1024,
            step=1,
            value=default_args.max_new_tokens
        )
    chat_interface = gr.ChatInterface(
        fn=bot_streaming,
        chatbot=chatbot,
        textbox=chat_input,
        multimodal=True,  # bot_streaming expects {"text": ..., "files": [...]} messages
        additional_inputs=[temperature, max_new_tokens],
        examples=[
            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        title="",
        description="",
        theme="soft",
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
    )
if __name__ == "__main__":
    demo.queue(api_open=False).launch(share=False, debug=True)