Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import time | |
| from threading import Thread | |
| import copy | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer | |
| from llava_llama3.model.builder import load_pretrained_model | |
| from llava_llama3.serve.cli import chat_llava | |
| import os | |
| import argparse | |
# Resolve the directory that holds this script, echo it in green, and point
# Gradio's temporary-file directory at it.
root_path = os.path.split(os.path.abspath(__file__))[0]
print(f'\033[92m{root_path}\033[0m')
os.environ['GRADIO_TEMP_DIR'] = root_path

# Baseline settings shared by every request; bot_streaming deep-copies this
# namespace and overrides temperature / max_new_tokens per call.
default_args = argparse.Namespace(
    **{
        "model_path": "TheFinAI/FinLLaVA",
        "device": "cuda",
        "conv_mode": "llama_3",
        "temperature": 0.7,
        "max_new_tokens": 512,
        "load_8bit": False,
        "load_4bit": False,
    }
)

# Load the model once at import time so all chat turns reuse the same objects.
# NOTE(review): the second and third positional arguments are presumably the
# model base and model name -- confirm against llava_llama3.model.builder.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    default_args.model_path,
    None,
    'llava_llama3',
    default_args.load_8bit,
    default_args.load_4bit,
    device=default_args.device,
)
def bot_streaming(message, history, temperature, max_new_tokens):
    """Stream a model reply for one multimodal chat turn.

    Args:
        message: Dict with a "text" entry (the prompt) and a "files" entry
            (uploads for this turn; each item is either a path or a dict
            carrying the path under "path").
        history: Earlier turns. A turn whose first element is a tuple is
            treated as an image upload, and the tuple's first item is used
            as the image path.
        temperature: Value copied onto the per-request generation args.
        max_new_tokens: Value copied onto the per-request generation args.

    Yields:
        The reply text accumulated so far, once per chunk from the streamer.

    Raises:
        gr.Error: If neither this turn nor the history supplies an image.
    """
    image_file = None
    uploads = message["files"]
    if uploads:
        # Only the most recent upload of this turn is used.
        latest = uploads[-1]
        image_file = latest["path"] if isinstance(latest, dict) else latest
    else:
        # No upload this turn: fall back to the last image found in history.
        for hist in history:
            if isinstance(hist[0], tuple):
                image_file = hist[0][0]

    if image_file is None:
        # The error must be raised, not merely constructed, for Gradio to
        # surface it to the user.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Work on a private copy so per-request overrides never leak into the
    # shared defaults.
    args = copy.deepcopy(default_args)
    args.temperature = temperature
    args.max_new_tokens = max_new_tokens

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        print('\033[92mRunning chat\033[0m')
        return chat_llava(
            args=args,
            image_file=image_file,
            text=message['text'],
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
            streamer=streamer,
        )

    # Generation runs in a background thread while this generator drains the
    # streamer and forwards partial text to the UI.
    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Short pause between updates (presumably to pace the UI refresh).
        time.sleep(0.06)
        yield buffer
# Page-level CSS: base font, a centered container capped at 800px, and a
# fixed-height scrollable area for elements with the "chatbot" class.
css = """
body {
    font-family: Arial, sans-serif;
}
.gradio-container {
    max-width: 800px;
    margin: auto;
}
.chatbot {
    height: 400px;
    overflow-y: auto;
}
"""

# Build the demo page: title, chat area, multimodal input box, a collapsed
# "Advanced Settings" accordion holding the two generation sliders, and the
# ChatInterface that routes everything to bot_streaming.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# FinLLaVA Demo")
    chatbot = gr.Chatbot(scale=1)
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    with gr.Accordion("Advanced Settings", open=False):
        # Slider defaults mirror default_args so the UI and the model start
        # from the same values.
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=default_args.temperature,
        )
        max_new_tokens = gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=1024,
            step=1,
            value=default_args.max_new_tokens,
        )
    chat_interface = gr.ChatInterface(
        fn=bot_streaming,
        # bot_streaming indexes message["text"] / message["files"] and the
        # example below is a dict of that shape, so the interface has to run
        # in multimodal mode.
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        # Passed to bot_streaming after (message, history), in this order.
        additional_inputs=[temperature, max_new_tokens],
        examples=[
            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        title="",
        description="",
        # NOTE(review): a theme set on a ChatInterface nested inside Blocks
        # may not take effect -- confirm, and move it to gr.Blocks if needed.
        theme="soft",
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
    )
if __name__ == "__main__":
    # Script entry point: configure the queue with api_open=False, then
    # launch without a share link and with debug mode on.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(share=False, debug=True)
 
			
