File size: 1,851 Bytes
3f0a74a
422ac12
052bd8b
 
 
 
 
 
 
422ac12
 
 
 
 
 
 
052bd8b
422ac12
 
052bd8b
 
422ac12
 
052bd8b
422ac12
052bd8b
422ac12
 
 
 
 
 
052bd8b
 
422ac12
 
 
 
 
 
 
 
 
 
 
 
3f0a74a
052bd8b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
from transformers import pipeline, Conversation, AutoTokenizer

# Special tokens of the Llama-3-style chat format, used by format_conversation
# to hand-build a prompt string.
bos_token = "<|begin_of_text|>"
eos_token = "<|eot_id|>"
start_header_id, end_header_id = "<|start_header_id|>", "<|end_header_id|>"

# Both objects below are created at import time, so importing this module
# loads the model/tokenizer for `model_id` (from the Hub or a local cache).
model_id = "H-D-T/Buzz-3b-small-v0.6.3"
# NOTE(review): the "conversational" pipeline task (and `Conversation`) were
# removed in transformers v4.42 — this needs an older transformers release,
# or a migration to the "text-generation" task with chat messages.
chatbot = pipeline(task="conversational", model=model_id)
# NOTE(review): `tokenizer` is not referenced by any code visible in this file.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

def format_conversation(chat_history):
    """Render ``(user, assistant)`` turn pairs as a Llama-3-style prompt string.

    Each completed turn becomes a user segment
    ``{start_header_id}user{end_header_id}\\n\\n{user}{eos_token}`` followed by
    the matching assistant segment in the same shape. The first user segment
    is prefixed with ``bos_token``. Message text is stripped of surrounding
    whitespace.

    Args:
        chat_history: Sequence of ``(user, assistant)`` string pairs. The
            assistant entry may be ``None`` for a turn still awaiting a reply
            (normally only the last pair). That turn gets an open assistant
            header with no ``eos_token``, so a model can continue from it.

    Returns:
        The concatenated prompt, or ``""`` for an empty history.
    """
    segments = []
    for i, (user, assistant) in enumerate(chat_history):
        user_msg = f"{start_header_id}user{end_header_id}\n\n{user.strip()}{eos_token}"
        if i == 0:
            user_msg = bos_token + user_msg
        segments.append(user_msg)
        if assistant is None:
            # Pending turn: leave the assistant header open instead of crashing
            # on None.strip(); this is the generation prompt for the model.
            segments.append(f"{start_header_id}assistant{end_header_id}\n\n")
        else:
            segments.append(
                f"{start_header_id}assistant{end_header_id}\n\n"
                f"{assistant.strip()}{eos_token}"
            )
    # Single join instead of repeated `+=` on a growing string.
    return "".join(segments)

def predict(message, chat_history):
    """Send ``message`` plus the prior turns to the model and record the reply.

    Args:
        message: The user's new message text.
        chat_history: List of ``(user, assistant)`` pairs from earlier turns
            (the ``gr.State`` value). It is not mutated.

    Returns:
        ``("", history)``: ``""`` clears the textbox, and ``history`` is a new
        list with ``(message, response)`` appended. A blank ``message``
        returns a copy of the history unchanged, without calling the model.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    # Pass structured role/content messages: the conversational pipeline
    # applies the model's own chat template. Pre-rendering special tokens into
    # a single string (as before) would be wrapped as one user message and
    # templated twice.
    messages = []
    for user_turn, assistant_turn in history:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    conversation = chatbot(Conversation(messages))
    response = conversation.generated_responses[-1]
    # Store one (user, assistant) pair per turn, which is the shape that
    # format_conversation and Gradio's tuple-style Chatbot both expect.
    history.append((message, response))
    return "", history

GREETING = "Hi, how can I help you today?"


def _respond(message, chat_history):
    """Run predict() and also build the value shown in the Chatbot component.

    Returns ``(textbox_value, new_history, chatbot_value)``. ``chatbot_value``
    is the greeting followed by every stored turn, as ``[user, assistant]``
    lists (Gradio's tuple chat format).
    """
    cleared, new_history = predict(message, chat_history)
    display = [[None, GREETING]] + [list(turn) for turn in new_history]
    return cleared, new_history, display


with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# Buzz-3B-Small Conversational Demo")
    # gr.Chatbot is a display component, not a layout container: it cannot be
    # used with `with` and has no .append(). The greeting is seeded through
    # `value`; a [None, text] row renders a bot-only message.
    chatbot_ui = gr.Chatbot(value=[[None, GREETING]])
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(label="Your message:")
        with gr.Column():
            submit_btn = gr.Button("Send")

    # Model-facing history of (user, assistant) pairs. It is kept separate from
    # the Chatbot display so the greeting is never sent to the model.
    chat_history = gr.State([])

    # The Chatbot must be an output, otherwise replies are never shown.
    _outputs = [textbox, chat_history, chatbot_ui]
    submit_btn.click(_respond, inputs=[textbox, chat_history], outputs=_outputs)
    # Pressing Enter in the textbox sends the message too.
    textbox.submit(_respond, inputs=[textbox, chat_history], outputs=_outputs)

if __name__ == "__main__":
    # Enable Gradio's request queue (capped at 20 waiting events), then serve.
    # queue() returns the same Blocks object, so launching it is equivalent.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()