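# Gradio demo: "LLM Evaluator With Linguistic Scrutiny".
# Each tab (POS, Chunk) shows vicuna-7b, llama-7b, and gpt-3.5 chatbots side
# by side under three prompting strategies; only the Vicuna chatbots are
# backed by a live model in this version.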
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Vicuna 7B model and tokenizer once at startup.
# Note: the full-precision checkpoint needs roughly 28 GB of RAM
# (7B params x 4 bytes); on a GPU you would typically pass
# torch_dtype=torch.float16 to from_pretrained.
model_name = "lmsys/vicuna-7b-v1.3"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def respond_vicuna(message, chat_history):
    # Tokenize the user message and generate a reply with beam search.
    # max_new_tokens bounds the length of the reply itself, independent of
    # how long the prompt is.
    input_ids = tokenizer.encode(message, return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=128, num_beams=5, no_repeat_ngram_size=2)
    # Decode only the newly generated tokens so the reply does not echo the prompt.
    bot_message = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    chat_history.append((message, bot_message))
    # Return an empty string to clear the textbox, plus the updated history.
    return "", chat_history

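# Build the two-tab UI; each tab repeats the same three-strategy layout.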
with gr.Blocks() as demo:
    gr.Markdown("# LLM Evaluator With Linguistic Scrutiny")

    with gr.Tab("POS"):
        gr.Markdown("Strategy 1: QA")
        with gr.Row():
            vicuna_chatbot1 = gr.Chatbot(label="vicuna-7b")
            llama_chatbot1 = gr.Chatbot(label="llama-7b")
            gpt_chatbot1 = gr.Chatbot(label="gpt-3.5")
        gr.Markdown("Strategy 2: Instruction")
        with gr.Row():
            vicuna_chatbot2 = gr.Chatbot(label="vicuna-7b")
            llama_chatbot2 = gr.Chatbot(label="llama-7b")
            gpt_chatbot2 = gr.Chatbot(label="gpt-3.5")
        gr.Markdown("Strategy 3: Structured Prompting")
        with gr.Row():
            vicuna_chatbot3 = gr.Chatbot(label="vicuna-7b")
            llama_chatbot3 = gr.Chatbot(label="llama-7b")
            gpt_chatbot3 = gr.Chatbot(label="gpt-3.5")
        with gr.Row():
            prompt = gr.Textbox(show_label=False, placeholder="Enter prompt")
            send_button_POS = gr.Button("Send", scale=0)
        clear = gr.ClearButton([prompt, vicuna_chatbot1])
    with gr.Tab("Chunk"):
        # Mirrors the POS tab with its own set of components.
        gr.Markdown("Strategy 1: QA")
        with gr.Row():
            vicuna_chatbot1_chunk = gr.Chatbot(label="vicuna-7b")
            llama_chatbot1_chunk = gr.Chatbot(label="llama-7b")
            gpt_chatbot1_chunk = gr.Chatbot(label="gpt-3.5")
        gr.Markdown("Strategy 2: Instruction")
        with gr.Row():
            vicuna_chatbot2_chunk = gr.Chatbot(label="vicuna-7b")
            llama_chatbot2_chunk = gr.Chatbot(label="llama-7b")
            gpt_chatbot2_chunk = gr.Chatbot(label="gpt-3.5")
        gr.Markdown("Strategy 3: Structured Prompting")
        with gr.Row():
            vicuna_chatbot3_chunk = gr.Chatbot(label="vicuna-7b")
            llama_chatbot3_chunk = gr.Chatbot(label="llama-7b")
            gpt_chatbot3_chunk = gr.Chatbot(label="gpt-3.5")
        with gr.Row():
            prompt_chunk = gr.Textbox(show_label=False, placeholder="Enter prompt")
            send_button_Chunk = gr.Button("Send", scale=0)
        clear_chunk = gr.ClearButton([prompt_chunk, vicuna_chatbot1_chunk])

    # Wire the POS tab's textbox to the Vicuna responder: on submit, pass the
    # textbox value and current history in, and write the cleared textbox and
    # updated history back out.
    prompt.submit(respond_vicuna, [prompt, vicuna_chatbot1], [prompt, vicuna_chatbot1])
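    # The Send buttons and the Chunk tab's textbox are laid out but not
    # connected to any handler; hooking them to the same responder is an
    # assumption about the intended behavior.
    send_button_POS.click(respond_vicuna, [prompt, vicuna_chatbot1], [prompt, vicuna_chatbot1])
    prompt_chunk.submit(respond_vicuna, [prompt_chunk, vicuna_chatbot1_chunk], [prompt_chunk, vicuna_chatbot1_chunk])
    send_button_Chunk.click(respond_vicuna, [prompt_chunk, vicuna_chatbot1_chunk], [prompt_chunk, vicuna_chatbot1_chunk])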

demo.launch()