"""Gradio demo comparing prompting strategies for POS tagging and chunking.

Each tab ("POS", "Chunk") has one shared prompt box and, for every prompting
strategy, a row of three chatbots labelled vicuna-7b, llama-7b and gpt-3.5.
Only the vicuna-7b chatbots are wired to a model; the other two are
placeholders.
"""
# NOTE(review): `time` is unused since the artificial 2 s delay in `respond`
# was removed; kept so no file-level import is dropped -- delete if unneeded.
import time

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Vicuna 7B v1.3 LMSys model and tokenizer once, at import time.
model_name = "lmsys/vicuna-7b-v1.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

STRATEGY_LABELS = [
    "Strategy 1 QA",
    "Strategy 2 Instruction",
    "Strategy 3 Structured Prompting",
]
PROMPT_PLACEHOLDER = "Write a prompt and press enter"


def create_chatbot_tab(description, prompt_placeholder, strategy_labels):
    """Build one tab with a prompt box, per-strategy chatbot rows and a clear button.

    Must be called inside an active ``gr.Blocks()`` context: Gradio only
    renders components that are created within one.

    Args:
        description: Title of the tab.
        prompt_placeholder: Placeholder text shown in the prompt textbox.
        strategy_labels: One label per prompting strategy; each gets a heading
            and its own row of chatbots.

    Returns:
        ``(tab, prompt_textbox, chatbots)`` where ``chatbots`` holds the
        vicuna-7b chatbot of every strategy row, in ``strategy_labels`` order.
    """
    vicuna_chatbots = []
    all_chatbots = []
    with gr.Tab(description) as tab:
        prompt_textbox = gr.Textbox(show_label=False, placeholder=prompt_placeholder)
        for strategy_label in strategy_labels:
            gr.Markdown(f"### {strategy_label}")
            with gr.Row():
                vicuna_chatbot = gr.Chatbot(label="vicuna-7b")
                # TODO: no model is wired to these two yet; they are placeholders.
                llama_chatbot = gr.Chatbot(label="llama-7b")
                gpt_chatbot = gr.Chatbot(label="gpt-3.5")
            vicuna_chatbots.append(vicuna_chatbot)
            all_chatbots.extend([vicuna_chatbot, llama_chatbot, gpt_chatbot])
        # Clear every chatbot in the tab, not only the vicuna ones.
        gr.ClearButton([prompt_textbox, *all_chatbots])
    return tab, prompt_textbox, vicuna_chatbots


def create_submit_function(
    prompt_textbox,
    chatbots,
    *,
    max_new_tokens=50,
    num_beams=5,
    no_repeat_ngram_size=2,
):
    """Wire Enter in ``prompt_textbox`` to append a Vicuna reply to each chatbot.

    Must be called inside the same ``gr.Blocks()`` context as the components.

    Args:
        prompt_textbox: Textbox whose ``submit`` event triggers generation.
        chatbots: Chatbot components that each receive the prompt and a reply.
        max_new_tokens: Upper bound on generated tokens, excluding the prompt.
        num_beams: Beam-search width passed to ``model.generate``.
        no_repeat_ngram_size: N-gram size that may not repeat in the output.
    """

    def respond(message, chat_history):
        # Copy so the incoming history is never mutated in place; Gradio may
        # pass None for a chatbot that has never been filled.
        history = list(chat_history or [])
        if not message or not message.strip():
            return message, history

        inputs = tokenizer(message, return_tensors="pt").to(model.device)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        # A causal LM's output starts with the prompt; decode only new tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        bot_message = tokenizer.decode(
            output_ids[0, prompt_length:], skip_special_tokens=True
        ).strip()
        history.append((message, bot_message))
        return "", history

    for chatbot in chatbots:
        prompt_textbox.submit(respond, [prompt_textbox, chatbot], [prompt_textbox, chatbot])


# Build the POS and Chunk tabs inside one Blocks app; components created
# outside a Blocks context would never be rendered.
with gr.Blocks() as demo:
    pos_tab, pos_prompt_textbox, pos_chatbots = create_chatbot_tab(
        "POS", PROMPT_PLACEHOLDER, STRATEGY_LABELS
    )
    chunk_tab, chunk_prompt_textbox, chunk_chatbots = create_chatbot_tab(
        "Chunk", PROMPT_PLACEHOLDER, STRATEGY_LABELS
    )
    create_submit_function(pos_prompt_textbox, pos_chatbots)
    create_submit_function(chunk_prompt_textbox, chunk_chatbots)

if __name__ == "__main__":
    demo.launch()