tstone87 commited on
Commit
9111270
·
verified ·
1 Parent(s): 8c1843a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -8
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import gradio as gr
 
3
 
4
  # Load the model and tokenizer
5
  model_name = "microsoft/DialoGPT-small"
@@ -8,30 +9,48 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
  # Function to generate a response
10
  def dialoGPT_response(user_input, history):
 
 
 
 
 
 
11
  # Encode the new user input, with the history
12
  new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
13
 
14
  # Append the new user input tokens to the chat history
15
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1) if history else new_user_input_ids
16
 
17
  # Generate a response
18
  chat_history_ids = model.generate(
19
  bot_input_ids,
20
- max_length=1000,
21
- pad_token_id=tokenizer.eos_token_id
 
22
  )
23
 
24
- # Decode the response
25
  response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
26
- return response
 
 
 
 
27
 
28
  # Gradio interface
29
  iface = gr.Interface(
30
  fn=dialoGPT_response,
31
- inputs=[gr.Textbox(placeholder="Enter your message..."), "state"],
32
- outputs="text",
 
 
 
 
 
 
33
  title="DialoGPT Chat",
34
- description="Chat with DialoGPT-small model."
 
35
  )
36
 
37
  iface.launch()
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import gradio as gr
3
+ import torch
4
 
5
  # Load the model and tokenizer
6
  model_name = "microsoft/DialoGPT-small"
 
9
 
10
  # Function to generate a response
11
  def dialoGPT_response(user_input, history):
12
+ # Convert history to tensor if it's not None
13
+ if history:
14
+ history_tensor = torch.LongTensor(history)
15
+ else:
16
+ history_tensor = torch.LongTensor([])
17
+
18
  # Encode the new user input, with the history
19
  new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
20
 
21
  # Append the new user input tokens to the chat history
22
+ bot_input_ids = torch.cat([history_tensor, new_user_input_ids], dim=-1)
23
 
24
  # Generate a response
25
  chat_history_ids = model.generate(
26
  bot_input_ids,
27
+ max_length=1000, # You might want to adjust this based on your needs
28
+ pad_token_id=tokenizer.eos_token_id,
29
+ no_repeat_ngram_size=3 # This prevents repeating phrases
30
  )
31
 
32
+ # Decode the response, keeping only the new tokens
33
  response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
34
+
35
+ # Update history with new input and response
36
+ new_history = chat_history_ids.tolist()[0] # Convert tensor to list for Gradio state
37
+
38
+ return response, new_history
39
 
40
  # Gradio interface
41
  iface = gr.Interface(
42
  fn=dialoGPT_response,
43
+ inputs=[
44
+ gr.Textbox(placeholder="Enter your message..."),
45
+ "state"
46
+ ],
47
+ outputs=[
48
+ "text", # The response
49
+ "state" # Updated history
50
+ ],
51
  title="DialoGPT Chat",
52
+ description="Chat with DialoGPT-small model. Your conversation history is maintained.",
53
+ allow_flagging="never" # Disabling flagging since this is a chat model
54
  )
55
 
56
  iface.launch()