Spaces:
Sleeping
Sleeping
File size: 1,307 Bytes
12d1a48 f364d75 12d1a48 5fa43e8 12d1a48 5fa43e8 8a0b89f 5fa43e8 f364d75 5fa43e8 40ae621 5fa43e8 40ae621 5fa43e8 12d1a48 5fa43e8 11b75fb 5fa43e8 12d1a48 3738e2a 40ae621 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# The tokenizer and the causal-LM weights are pulled from the same Hugging
# Face Hub checkpoint, so its name is spelled out once. Both objects are
# created at import time and reused by every call to `respond`.
_CHECKPOINT = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
# Respond function
def respond(message, chat_history=None, *, max_history_tokens=768, max_new_tokens=256):
    """Generate one DialoGPT reply to ``message`` given the prior conversation.

    Args:
        message: The user's new utterance.
        chat_history: Token ids of the conversation so far. Pass the nested
            list (a single row) returned by a previous call, or ``None`` or
            an empty list to start a new conversation.
        max_history_tokens: Maximum number of prompt tokens fed to the model
            (must be positive). Older tokens are dropped from the front once
            the conversation exceeds this.
        max_new_tokens: Maximum number of tokens generated for the reply.

    Returns:
        A ``(bot_message, chat_history)`` tuple. ``bot_message`` is the
        decoded reply text. ``chat_history`` is the updated token history
        (prompt + reply) as a nested list, to be passed back on the next turn.
    """
    if chat_history is None:
        chat_history = []

    # Terminate the utterance with EOS so the model sees a turn boundary.
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
    if chat_history:
        history_ids = torch.tensor(chat_history, dtype=torch.long)
        bot_input_ids = torch.cat([history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Bound the prompt. Without this the history grows every turn. With a
    # total-length cap (max_length) the prompt eventually consumes the whole
    # budget, so replies come back empty. Later it exceeds GPT-2's 1024
    # position embeddings and generation fails. Keeping only the most recent
    # tokens and bounding the reply separately keeps prompt + reply
    # <= max_history_tokens + max_new_tokens (1024 with the defaults).
    bot_input_ids = bot_input_ids[:, -max_history_tokens:]

    chat_history_ids = model.generate(
        bot_input_ids,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(bot_input_ids),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation.
    bot_message = tokenizer.decode(
        chat_history_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True
    )
    # Store the (already bounded) prompt plus the reply as plain lists so the
    # value can live in Gradio state between calls.
    return bot_message, chat_history_ids.tolist()
# Gradio Interface: a free-text box in and a text reply out. A hidden
# gr.State input/output pair carries the `chat_history` value that `respond`
# returns back into its next call.
demo = gr.Interface(
    respond,
    ["text", gr.State()],
    ["text", gr.State()],
    title="DialoGPT Chatbot",
    description="A chatbot powered by Microsoft's DialoGPT.",
)

if __name__ == "__main__":
    # Only serve the UI when executed as a script, not when imported.
    demo.launch()
|