artifix / app.py
prakhardoneria's picture
Changed Structure
5fa43e8 verified
raw
history blame
1.31 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face Hub checkpoint shared by the tokenizer and the model, so the
# two are guaranteed to come from the same release.
_CHECKPOINT = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
def respond(message, chat_history=None, *, max_reply_tokens=256, max_context_tokens=None):
    """Generate the bot's reply to ``message`` given the running conversation.

    Args:
        message: The user's latest utterance. ``None`` is treated as "".
        chat_history: Token ids of the conversation so far, in the form this
            function returned on the previous call (a nested list with one
            row). ``None`` or an empty list starts a fresh conversation.
        max_reply_tokens: Upper bound on the number of tokens generated for
            the reply (keyword-only).
        max_context_tokens: Total token budget for prompt + reply
            (keyword-only). Defaults to the model config's
            ``max_position_embeddings``, falling back to 1024 when the config
            does not provide it.

    Returns:
        A ``(bot_message, chat_history)`` tuple. ``bot_message`` is the
        decoded reply text. ``chat_history`` is the updated token-id history
        (nested list) to pass back in on the next call.
    """
    if chat_history is None:
        chat_history = []
    if message is None:
        message = ""

    # Each user turn is terminated with the EOS token before being appended
    # to the running history, as in the DialoGPT usage example.
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
    if chat_history:
        history_ids = torch.tensor(chat_history, dtype=torch.long)
        bot_input_ids = torch.cat([history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # The history grows every turn. Without a cap, prompt + reply eventually
    # exceeds the model's positional-embedding range and generate() fails, so
    # keep only the most recent tokens and reserve room for the reply.
    if max_context_tokens is None:
        max_context_tokens = getattr(model.config, "max_position_embeddings", None) or 1024
    reply_budget = max(1, min(max_reply_tokens, max_context_tokens - 1))
    max_prompt_tokens = max(1, max_context_tokens - reply_budget)
    bot_input_ids = bot_input_ids[:, -max_prompt_tokens:]

    chat_history_ids = model.generate(
        bot_input_ids,
        # No padding is used, so every position is attended to.
        attention_mask=torch.ones_like(bot_input_ids),
        max_new_tokens=reply_budget,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    bot_message = tokenizer.decode(chat_history_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)
    return bot_message, chat_history_ids.tolist()
# Web UI: a text box for the user's message, plus a gr.State pair that feeds
# the token history returned by `respond` back into its next call.
demo = gr.Interface(
    respond,
    ["text", gr.State()],
    ["text", gr.State()],
    title="DialoGPT Chatbot",
    description="A chatbot powered by Microsoft's DialoGPT.",
)
def _main():
    """Launch the Gradio app defined above."""
    demo.launch()


# Only start the server when run as a script, not on import.
if __name__ == "__main__":
    _main()