import spaces  # ZeroGPU requires importing `spaces` before CUDA is initialized

import gradio as gr
from PIL import Image

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava
# Model configuration
model_id = "TheFinAI/FinLLaVA"
device = "cuda:0"
load_8bit = False
load_4bit = False

# Load the pretrained model; returns the tokenizer, the LLaVA model,
# the image processor, and the model's context length.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_id,
    None,            # model_base: no separate base checkpoint
    'llava_llama3',  # model_name
    load_8bit,
    load_4bit,
    device=device,
)
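
# On ZeroGPU Spaces, the inference function is typically wrapped with the
# @spaces.GPU decorator so a GPU is attached for the duration of each call.
# The decorator is an assumption here (it is the usual reason `spaces` is
# imported) and is a no-op when running outside of Spaces.
@spaces.GPU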
def bot_streaming(message, history):
    print(message)
    image = None

    # Prefer an image attached to the current message.
    if message["files"]:
        # message["files"][-1] can be a dict (with a "path" key) or a plain path string.
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        # Otherwise fall back to the most recent image in the chat history.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    # LLaVA cannot answer without an image to condition on.
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Load the image
    image = Image.open(image)
    # The user's text is the prompt
    prompt = message['text']

    # Call chat_llava to generate the output
    output = chat_llava(
        args=None,
        image_file=image,
        text=prompt,
        tokenizer=tokenizer,
        model=llava_model,
        image_processor=image_processor,
        context_len=context_len,
    )

    # Stream the output, yielding the progressively growing response
    buffer = ""
    for new_text in output:
        buffer += new_text
        yield buffer
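
# Hypothetical local smoke test (not part of the app): iterate the generator
# directly to watch the streamed output. Assumes an image exists at the given
# path and that the model loaded above fits on the local GPU.
#
#   for partial in bot_streaming({"text": "Describe this chart.",
#                                 "files": ["./example.png"]}, history=[]):
#       print(partial)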
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Llama-3-8B",
        # Example prompts; the referenced image files must ship with the app.
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)