# Hugging Face Space status at time of capture: Runtime error
import time

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Vicuna 7B v1.3 (LMSYS) checkpoint and its tokenizer once, at import
# time, so every chat handler in this module reuses the same instances instead
# of reloading several GB of weights per request.
# NOTE(review): without a ``torch_dtype`` argument the weights are loaded in
# full precision; a 7B model in fp32 needs roughly 28 GB of RAM — confirm the
# Space hardware can hold it.
model_name = "lmsys/vicuna-7b-v1.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def create_chatbot_tab(description, prompt_placeholder, strategy_labels):
    """Build one tab with a prompt box, one row of chatbots per strategy, and a clear button.

    Must be called inside an active ``gr.Blocks()`` context. Gradio attaches a
    component to whichever layout block (``gr.Tab``, ``gr.Row``, ...) is open
    when the component is constructed. Everything is therefore created inside
    the ``with`` blocks below rather than created first and wrapped later.

    Args:
        description: Title shown on the tab.
        prompt_placeholder: Placeholder text for the prompt textbox.
        strategy_labels: One label per prompting strategy; each strategy gets
            its own row of chatbots, labelled with it.

    Returns:
        ``(tab, prompt_textbox, chatbots)``, where ``chatbots`` holds the
        Vicuna chatbot of each strategy row, in ``strategy_labels`` order.
    """
    with gr.Tab(description) as tab:
        prompt_textbox = gr.Textbox(show_label=False, placeholder=prompt_placeholder)
        chatbots = []
        for strategy_label in strategy_labels:
            with gr.Row():
                vicuna_chatbot = gr.Chatbot(label=f"{strategy_label} - vicuna-7b")
                # Display-only placeholders: no llama/gpt backend is loaded
                # here, so these are neither returned nor wired to a handler.
                gr.Chatbot(label=f"{strategy_label} - llama-7b")
                gr.Chatbot(label=f"{strategy_label} - gpt-3.5")
            chatbots.append(vicuna_chatbot)
        # Clears the prompt and every model-backed chatbot in this tab.
        gr.ClearButton([prompt_textbox] + chatbots)
    return tab, prompt_textbox, chatbots
def _generate_reply(message, max_new_tokens=50):
    """Return the module-level model's reply to *message*, without the prompt echoed back.

    For a causal LM, ``model.generate`` returns the prompt tokens followed by
    the continuation. Only the tokens after the prompt are decoded.
    """
    input_ids = tokenizer.encode(message, return_tensors="pt")
    output_ids = model.generate(
        input_ids,
        # ``max_new_tokens`` (unlike ``max_length``) does not count the prompt,
        # so long prompts still get a reply.
        max_new_tokens=max_new_tokens,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def create_submit_function(prompt_textbox, chatbots):
    """Wire *prompt_textbox*'s submit event to append the model's reply to every chatbot.

    Must be called inside an active ``gr.Blocks()`` context; Gradio only
    accepts event listeners registered there.

    Args:
        prompt_textbox: ``gr.Textbox`` whose submit event triggers generation;
            it is cleared after a reply is produced.
        chatbots: ``gr.Chatbot`` components; each receives the
            ``(message, reply)`` pair.
    """
    components = [prompt_textbox] + list(chatbots)

    def respond(message, *chat_histories):
        if not (message and message.strip()):
            # Nothing to send: leave the textbox and histories unchanged.
            return (message, *chat_histories)
        # Every chatbot is sent the same prompt and beam search is
        # deterministic, so generate once and share the reply rather than
        # running the 7B model once per chatbot.
        # If strategies later get their own prompts, this must generate per chatbot.
        bot_message = _generate_reply(message)
        updated = [list(history or []) + [(message, bot_message)] for history in chat_histories]
        return ("", *updated)

    # A single listener also avoids several handlers racing to clear the textbox.
    prompt_textbox.submit(respond, components, components)
# Prompting strategies compared side by side in every tab.
STRATEGY_LABELS = ["Strategy 1 QA", "Strategy 2 Instruction", "Strategy 3 Structured Prompting"]
PROMPT_PLACEHOLDER = "Write a prompt and press enter"

# Gradio requires layout blocks, components and event listeners to be created
# inside an active ``gr.Blocks`` context. ``gr.Blocks`` has no ``append``
# method, so both tabs and their submit handlers are built here.
with gr.Blocks() as demo:
    # Create POS and Chunk tabs
    pos_tab, pos_prompt_textbox, pos_chatbots = create_chatbot_tab(
        "POS", PROMPT_PLACEHOLDER, STRATEGY_LABELS
    )
    chunk_tab, chunk_prompt_textbox, chunk_chatbots = create_chatbot_tab(
        "Chunk", PROMPT_PLACEHOLDER, STRATEGY_LABELS
    )
    # Create submit functions for POS and Chunk tabs
    create_submit_function(pos_prompt_textbox, pos_chatbots)
    create_submit_function(chunk_prompt_textbox, chunk_chatbots)

if __name__ == "__main__":
    # Generation with a 7B model is long-running; Gradio's queue is the
    # documented way to serve such event handlers.
    demo.queue()
    demo.launch()