nugentc commited on
Commit
0a042d8
·
1 Parent(s): 703a8b7

add chat agent

Browse files
Files changed (2) hide show
  1. app.py +24 -23
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,27 +1,28 @@
1
- import random
 
2
 
3
- import gradio as gr
 
 
 
 
 
 
 
 
4
 
 
 
5
 
6
- def chat(message, history):
7
- history = history or []
8
- if message.startswith("How many"):
9
- response = random.randint(1, 10)
10
- elif message.startswith("How"):
11
- response = random.choice(["Great", "Good", "Okay", "Bad"])
12
- elif message.startswith("Where"):
13
- response = random.choice(["Here", "There", "Somewhere"])
14
- else:
15
- response = "I don't know"
16
- history.append((message, response))
17
- return history, history
18
 
19
- chatbot = gr.Chatbot(color_map=("green", "gray"))
20
- demo = gr.Interface(
21
- chat,
22
- ["text", "state"],
23
- [chatbot, "state"],
24
- allow_screenshot=False,
25
- allow_flagging="never",
26
- )
27
- demo.launch()
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
 
4
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
5
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
6
+
7
+ def predict(input, history=[]):
8
+ # tokenize the new input sentence
9
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
10
+
11
+ # append the new user input tokens to the chat history
12
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
13
 
14
+ # generate a response
15
+ history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
16
 
17
+ # convert the tokens to text, and then split the responses into the right format
18
+ response = tokenizer.decode(history[0]).split("<|endoftext|>")
19
+ response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
20
+ return response, history
21
+
22
+ import gradio as gr
 
 
 
 
 
 
23
 
24
+ gr.Interface(fn=predict,
25
+ theme="default",
26
+ css=".footer {display:none !important}",
27
+ inputs=["text", "state"],
28
+ outputs=["chatbot", "state"]).launch()
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch