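# Gradio demo: a minimal multi-turn chatbot around microsoft/DialoGPT-large.
# Each turn re-feeds the full token history to the model, carried between
# turns via gr.State.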
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the DialoGPT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
model.eval()  # inference only: disables dropout

# Respond function: takes the user message plus the token history held in
# gr.State, and returns the bot's reply plus the updated history
def respond(message, chat_history=None):
    # Encode the user input, terminated by the end-of-sequence token
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")

    # Append the new input to the conversation so far (the state is None on the first turn)
    if chat_history:
        bot_input_ids = torch.cat([torch.tensor(chat_history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate the response; DialoGPT has no pad token, so EOS is reused for padding.
    # max_length caps the *total* conversation at 1000 tokens.
    with torch.no_grad():
        chat_history_ids = model.generate(
            bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (everything after the input)
    bot_message = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )

    # The full token history (as a plain list) becomes the state for the next turn
    return bot_message, chat_history_ids.tolist()

# Gradio Interface: the gr.State input/output pair carries the chat history between turns
demo = gr.Interface(
    fn=respond,
    inputs=["text", gr.State()],
    outputs=["text", gr.State()],
    title="DialoGPT Chatbot",
    description="A chatbot powered by Microsoft's DialoGPT.",
)

if __name__ == "__main__":
    demo.launch()
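
# Run this file directly (e.g. `python app.py`, assuming that is the script's name);
# Gradio serves the UI locally, by default at http://127.0.0.1:7860.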