import gradio as gr
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
from PIL import Image
import argparse
import torch
import spaces

# Model configuration
model_path = "TheFinAI/FinLLaVA"
device = "cuda"
conv_mode = "llama_3"
temperature = 0
max_new_tokens = 512
load_8bit = False
load_4bit = False

# Bundle the generation settings into a namespace: chat_llava reads
# conv_mode, temperature, and max_new_tokens from its `args` parameter,
# so passing args=None would fail at generation time.
gen_args = argparse.Namespace(
    conv_mode=conv_mode,
    temperature=temperature,
    max_new_tokens=max_new_tokens,
)

# Load the pretrained model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    'llava_llama3',
    load_8bit,
    load_4bit,
    device=device
)

# Define the prediction function
@spaces.GPU
def bot_streaming(image, text, history):
    output = chat_llava(
        args=gen_args,
        image_file=image,
        text=text,
        tokenizer=tokenizer,
        model=llava_model,
        image_processor=image_processor,
        context_len=context_len
    )
    history.append((text, output))
    # Return the updated chat history and clear the input textbox
    return history, gr.update(value="")

# Create the Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="FinLLaVA Chatbot")
    image_input = gr.Image(type="filepath", label="Upload Image")
    text_input = gr.Textbox(label="Enter your message")
    submit_btn = gr.Button("Submit")

    # When Submit is clicked, call bot_streaming and update the chatbot
    submit_btn.click(
        fn=bot_streaming,
        inputs=[image_input, text_input, chatbot],
        outputs=[chatbot, text_input]
    )

    # Add example inputs
    gr.Examples(
        examples=[
            ["./bee.jpg", "What is on the flower?"],
            ["./baklava.png", "How to make this pastry?"]
        ],
        inputs=[image_input, text_input]
    )

# Launch the Gradio app
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
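
# Usage note (assumed, not confirmed by the script itself): the @spaces.GPU
# decorator suggests this is written for a Hugging Face Space with ZeroGPU,
# where the `spaces` package is preinstalled. A local run is assumed to look
# roughly like:
#
#   pip install gradio spaces torch   # plus the llava_llama3 package from its repo
#   python app.py                     # "app.py" is a hypothetical filename
#
# The ./bee.jpg and ./baklava.png example images must exist alongside the
# script for the Examples panel to load.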