Alignment-Lab-AI's picture
Update app.py
422ac12 verified
raw
history blame
1.85 kB
import gradio as gr
from transformers import pipeline, Conversation, AutoTokenizer
# Special tokens that format_conversation() stitches around each chat turn.
bos_token, eos_token = "<|begin_of_text|>", "<|eot_id|>"
start_header_id, end_header_id = "<|start_header_id|>", "<|end_header_id|>"
# Hub checkpoint backing the demo; the pipeline and the standalone tokenizer
# are both loaded from it.
model_id = "H-D-T/Buzz-3b-small-v0.6.3"
# NOTE(review): the "conversational" pipeline task and the `Conversation`
# class were deprecated and then removed from transformers (v4.42 release
# notes) — confirm the Space pins an older transformers, or migrate to the
# "text-generation" pipeline.
chatbot = pipeline(task="conversational", model=model_id)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)
def format_conversation(chat_history):
    """Render chat turns as one prompt string using the module's special tokens.

    Args:
        chat_history: Iterable of ``(user_message, assistant_reply)`` pairs.
            A reply of ``None`` marks a pending turn: its assistant header is
            emitted open (no text, no ``eos_token``) so a model can continue
            from it.

    Returns:
        The concatenated turns. Each message is stripped and wrapped as
        ``start_header_id + role + end_header_id + "\\n\\n" + text + eos_token``,
        and ``bos_token`` prefixes the first user turn. An empty history
        yields ``""``.
    """
    parts = []
    for i, (user, assistant) in enumerate(chat_history):
        if i == 0:
            # Only the very first turn carries the beginning-of-text token.
            parts.append(bos_token)
        parts.append(f"{start_header_id}user{end_header_id}\n\n{user.strip()}{eos_token}")
        parts.append(f"{start_header_id}assistant{end_header_id}\n\n")
        if assistant is not None:
            parts.append(f"{assistant.strip()}{eos_token}")
    # A single join avoids the quadratic cost of repeated string `+=`.
    return "".join(parts)
def predict(message, chat_history):
    """Generate the assistant's reply to *message* and extend the history.

    Args:
        message: Text the user just submitted.
        chat_history: Completed ``(user_message, assistant_reply)`` pairs held
            in the Gradio ``State`` — the layout format_conversation() reads.
            It is treated as read-only.

    Returns:
        ``("", new_history)``. The empty string clears the textbox.
        ``new_history`` is a new list equal to ``chat_history`` plus
        ``(message, response)``; it is ``chat_history`` itself when *message*
        is blank.
    """
    # Nothing to send: clear the textbox and leave the history untouched.
    if not message or not message.strip():
        return "", chat_history

    # The pending user turn has no reply yet, so render it here instead of
    # feeding a half-filled pair to format_conversation().
    pending = f"{start_header_id}user{end_header_id}\n\n{message.strip()}{eos_token}"
    if not chat_history:
        # Match format_conversation(): the first turn carries BOS.
        pending = bos_token + pending
    prompt = format_conversation(chat_history) + pending

    # NOTE(review): the conversational pipeline treats this already-templated
    # string as a single user message and presumably applies the tokenizer's
    # chat template on top — confirm the model copes with the nested special
    # tokens, or switch to a raw text-generation call.
    conversation = chatbot(Conversation(prompt))
    response = conversation.generated_responses[-1]

    # Build a new list rather than appending to the State in place, so a
    # failed generation cannot leave a dangling half-turn behind.
    return "", [*chat_history, (message, response)]
# Greeting shown in the chat window only; it is never sent to the model.
_GREETING = [(None, "Hi, how can I help you today?")]


def _render_chat(history):
    """Turn the model-facing history into Chatbot display rows, greeting first."""
    return _GREETING + [tuple(turn) for turn in history]


with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# Buzz-3B-Small Conversational Demo")
    # gr.Chatbot is a component, not a layout context, and has no append();
    # seed its display through the initial value (tuples format:
    # [user_message, bot_message], None for an empty side).
    chatbot_ui = gr.Chatbot(value=_GREETING)
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(label="Your message:")
        with gr.Column():
            submit_btn = gr.Button("Send")
    # Model-facing history, kept separate from the display so the greeting
    # never becomes part of the prompt.
    chat_history = gr.State([])
    # predict() only updates the textbox and the state; the chained step
    # copies the updated history into the visible chat window.
    submit_btn.click(
        predict,
        inputs=[textbox, chat_history],
        outputs=[textbox, chat_history],
    ).then(_render_chat, inputs=chat_history, outputs=chatbot_ui)
if __name__ == "__main__":
    # Cap the event queue at 20 pending requests, then start serving the app.
    demo.queue(max_size=20)
    demo.launch()