# Hugging Face Space app for FinLLaVA, running on ZeroGPU hardware.
import argparse
import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()
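# Example invocation (illustrative only; the flag values are arbitrary):
#   python app.py --load-4bit --temperature 0.2 --max-new-tokens 256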
# Load the tokenizer, model, image processor, and context length
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)
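# On ZeroGPU Spaces, `@spaces.GPU` requests a GPU slot for the duration of
# each call, so the model only holds a GPU while a request is being served.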
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image_path = None

    # Check if there's an image in the current message
    if message["files"]:
        # message["files"][-1] can be a dictionary or a plain path string
        if isinstance(message["files"][-1], dict):
            image_path = message["files"][-1]["path"]
        else:
            image_path = message["files"][-1]
    else:
        # No image in the current message; fall back to the last image path in the history
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]

    # Error out if no image path was found anywhere
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # image_path is passed along as-is; there is no need to load it into a
    # PIL image here
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")

    # Build the prompt and a streamer that yields text as it is generated
    prompt = message['text']
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up the generation arguments, including the streamer
    generation_kwargs = dict(
        args=args,
        image_file=image_path,
        text=prompt,
        tokenizer=tokenizer,
        model=llava_model,
        streamer=streamer,
        image_processor=image_processor,  # todo: input model name or path
        context_len=context_len,
    )

    # Call `chat_llava` with the given arguments
    def generate_output(generation_kwargs):
        chat_llava(**generation_kwargs)

    # Run the generation in a separate thread so this one can consume the stream
    thread = Thread(target=generate_output, args=(generation_kwargs,))
    thread.start()

    # Buffer that accumulates the generated text
    buffer = ""
    # Give the generation a moment to start
    time.sleep(0.5)

    # Iterate over the streamer to handle the incoming text in chunks
    for new_text in streamer:
        # Strip the end-of-text token if it appears
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        # Add the new text to the buffer; skip_prompt=True already excludes
        # the prompt, so the buffer can be yielded as-is
        buffer += new_text
        # Small pause to smooth the streamed output (optional)
        time.sleep(0.06)
        # Yield the text generated so far for display
        yield buffer
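# Streaming flow, in brief: chat_llava pushes decoded tokens into the
# TextIteratorStreamer from the worker thread, while the generator above
# drains it and yields a growing transcript that Gradio re-renders live.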
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
            {"text": "What is this?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )
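# queue() serializes concurrent requests, which streaming generators need on
# Spaces; api_open=False keeps the queue's API endpoint closed to external
# callers, and show_api=False hides the API docs on the page.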
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
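# Launched locally, this serves on Gradio's default http://127.0.0.1:7860;
# setting share=True instead would create a temporary public link.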