# LingEval / app.py
# Uploaded by research14 — commit 85bd1c9 ("test"), 2.33 kB.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
# Hugging Face checkpoint used for every "vicuna-7b" reply in this demo.
model_name = "lmsys/vicuna-7b-v1.3"
# Tokenizer first, then model; both are created once at import time and are
# shared by the chat handler that encodes prompts and generates replies.
# NOTE(review): from_pretrained is called with default arguments here, so
# dtype/device placement are the library defaults — confirm the host has
# enough memory for a 7B checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
def create_chatbot_tab(description, prompt_placeholder, strategy_labels):
    """Build one tab of the model-comparison UI and return its key components.

    Must be called while a ``gr.Blocks`` context is active. Every component is
    created inside ``with gr.Tab(...)`` so Gradio renders it into this tab
    (components created outside a layout context are not placed in it).

    Args:
        description: Title shown on the tab (e.g. "POS").
        prompt_placeholder: Placeholder text for the tab's prompt textbox.
        strategy_labels: One heading per prompting strategy. Each strategy
            gets its own row of vicuna-7b / llama-7b / gpt-3.5 chatbots.

    Returns:
        ``(tab, prompt_textbox, chatbots)``, where ``chatbots`` holds the
        vicuna-7b chatbot of each strategy row, in ``strategy_labels`` order.
    """
    with gr.Tab(description) as tab:
        prompt_textbox = gr.Textbox(show_label=False, placeholder=prompt_placeholder)
        # vicuna-7b chatbot of each row; returned to the caller for wiring.
        chatbots = []
        # Every chatbot in the tab, so the Clear button resets all of them.
        all_chatbots = []
        for strategy_label in strategy_labels:
            # Heading so the otherwise identical rows can be told apart.
            gr.Markdown(strategy_label)
            with gr.Row():
                vicuna_chatbot = gr.Chatbot(label="vicuna-7b")
                llama_chatbot = gr.Chatbot(label="llama-7b")
                gpt_chatbot = gr.Chatbot(label="gpt-3.5")
            chatbots.append(vicuna_chatbot)
            all_chatbots.extend([vicuna_chatbot, llama_chatbot, gpt_chatbot])
        gr.ClearButton([prompt_textbox] + all_chatbots)
    return tab, prompt_textbox, chatbots
def create_submit_function(prompt_textbox, chatbots):
    """Wire a tab's prompt textbox to each of the given chatbots.

    Pressing Enter in ``prompt_textbox`` runs the module-level ``model`` on
    the prompt once per chatbot. The (prompt, reply) pair is then appended to
    that chatbot's history. Must be called while a ``gr.Blocks`` context is
    active, because Gradio only registers event listeners there.

    Args:
        prompt_textbox: The ``gr.Textbox`` the user types prompts into.
        chatbots: ``gr.Chatbot`` components that should receive replies.
    """

    def respond(message, chat_history):
        """Generate a reply to ``message``; return ``("", updated_history)``."""
        # Work on a copy instead of mutating the object Gradio passed in, and
        # tolerate an empty (None) chatbot value.
        chat_history = list(chat_history or [])
        if not message or not message.strip():
            # Nothing to answer; keep the history unchanged.
            return "", chat_history
        input_ids = tokenizer.encode(message, return_tensors="pt").to(model.device)
        output_ids = model.generate(
            input_ids,
            # Budget for the reply only; max_length would also count the
            # prompt, leaving no room to generate for long prompts.
            max_new_tokens=50,
            num_beams=5,
            no_repeat_ngram_size=2,
        )
        # The causal-LM output starts with the prompt tokens; decode only the
        # continuation so the reply does not echo the user's message.
        new_tokens = output_ids[0][input_ids.shape[-1]:]
        bot_message = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        chat_history.append((message, bot_message))
        # First output clears the textbox; second updates the chatbot.
        return "", chat_history

    for chatbot in chatbots:
        prompt_textbox.submit(respond, [prompt_textbox, chatbot], [prompt_textbox, chatbot])
# Prompting strategies compared side by side in every tab.
STRATEGY_LABELS = ["Strategy 1 QA", "Strategy 2 Instruction", "Strategy 3 Structured Prompting"]
PROMPT_PLACEHOLDER = "Write a prompt and press enter"

# Tabs and their `.submit` listeners must be created while a gr.Blocks context
# is active: Gradio refuses to register event listeners outside one, and the
# public gr.Blocks API offers no `append` for attaching tabs afterwards.
with gr.Blocks() as demo:
    # Create POS and Chunk tabs
    pos_tab, pos_prompt_textbox, pos_chatbots = create_chatbot_tab(
        "POS", PROMPT_PLACEHOLDER, STRATEGY_LABELS
    )
    chunk_tab, chunk_prompt_textbox, chunk_chatbots = create_chatbot_tab(
        "Chunk", PROMPT_PLACEHOLDER, STRATEGY_LABELS
    )
    # Register submit handlers for POS and Chunk tabs
    create_submit_function(pos_prompt_textbox, pos_chatbots)
    create_submit_function(chunk_prompt_textbox, chunk_chatbots)

# Launch the demo with POS and Chunk tabs
demo.launch()