import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

# Load the Vicuna 7B v1.3 LMSys model and tokenizer
model_name = "lmsys/vicuna-7b-v1.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
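# Note: a 7B model in full fp32 precision needs roughly 28 GB of RAM; on a GPU,
# passing torch_dtype=torch.float16 (and device_map="auto" with accelerate installed)
# to from_pretrained keeps memory usage manageable.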

def create_chatbot_tab(description, prompt_placeholder, strategy_labels):
    # Build one tab with a prompt box, one row of chatbots per strategy, and a
    # clear button. Must be called from inside the gr.Blocks() context below so
    # the components are rendered automatically.
    with gr.Tab(description) as tab:
        prompt_textbox = gr.Textbox(show_label=False, placeholder=prompt_placeholder)

        chatbots = []
        for strategy_label in strategy_labels:
            gr.Markdown(f"### {strategy_label}")
            with gr.Row():
                vicuna_chatbot = gr.Chatbot(label="vicuna-7b")
                llama_chatbot = gr.Chatbot(label="llama-7b")
                gpt_chatbot = gr.Chatbot(label="gpt-3.5")
            # Only the Vicuna chatbot is wired to the locally loaded model;
            # the llama-7b and gpt-3.5 panels are placeholders for now.
            chatbots.append(vicuna_chatbot)

        # ClearButton resets the prompt box and every tracked chatbot on click.
        clear_button = gr.ClearButton([prompt_textbox] + chatbots)

    return tab, prompt_textbox, chatbots

def create_submit_function(prompt_textbox, chatbots):
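    # respond() runs the user message through the loaded Vicuna model and
    # appends the (message, reply) pair to the chatbot history.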
    def respond(message, chat_history):
        input_ids = tokenizer.encode(message, return_tensors="pt")
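        # max_length counts prompt plus reply tokens, so generated replies stay short.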
        output_ids = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2)
        bot_message = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        chat_history.append((message, bot_message))
        time.sleep(2)
        return "", chat_history

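    # Register the same handler on the prompt box once per chatbot; a single
    # submit therefore updates every tracked chatbot in the tab.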
    for chatbot in chatbots:
        prompt_textbox.submit(respond, [prompt_textbox, chatbot], [prompt_textbox, chatbot])

# All components must be instantiated inside the Blocks context to be rendered
with gr.Blocks() as demo:
    # Create POS and Chunk tabs
    pos_tab, pos_prompt_textbox, pos_chatbots = create_chatbot_tab(
        "POS",
        "Write a prompt and press enter",
        ["Strategy 1 QA", "Strategy 2 Instruction", "Strategy 3 Structured Prompting"],
    )
    chunk_tab, chunk_prompt_textbox, chunk_chatbots = create_chatbot_tab(
        "Chunk",
        "Write a prompt and press enter",
        ["Strategy 1 QA", "Strategy 2 Instruction", "Strategy 3 Structured Prompting"],
    )

    # Wire the submit handlers for the POS and Chunk tabs
    create_submit_function(pos_prompt_textbox, pos_chatbots)
    create_submit_function(chunk_prompt_textbox, chunk_chatbots)

# Launch the demo with the POS and Chunk tabs
demo.launch()