teaevo commited on
Commit
07679b3
·
1 Parent(s): 2714773

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -81,11 +81,7 @@ def predict(input, conversation_history): #history=[]):
81
  else:
82
  '''
83
 
84
- global conversation_history
85
-
86
- conversation_history.append(user_input)
87
-
88
- bot_input = dialogpt_tokenizer.encode(input + " ".join(conversation_history)+ tokenizer.eos_token, return_tensors="pt")
89
  chat_history_ids = model.generate(bot_input, max_length=1000, pad_token_id=tokenizer.eos_token_id)
90
  response = tokenizer.decode(chat_history_ids[:, bot_input.shape[-1]:][0], skip_special_tokens=True)
91
  #return response
@@ -107,7 +103,16 @@ def predict(input, conversation_history): #history=[]):
107
 
108
  return response #, history
109
 
 
 
 
 
110
 
 
 
 
 
 
111
  def sqlquery(input):
112
 
113
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
@@ -118,6 +123,8 @@ def sqlquery(input):
118
 
119
 
120
  chat_interface = gr.Interface(
 
 
121
  fn=predict,
122
  theme="default",
123
  css=".footer {display:none !important}",
@@ -125,6 +132,7 @@ chat_interface = gr.Interface(
125
  outputs="text", #["chatbot", "state"],
126
  title="ST Chatbot",
127
  description="Type your message in the box above, and the chatbot will respond.",
 
128
  )
129
 
130
  sql_interface = gr.Interface(
 
81
  else:
82
  '''
83
 
84
+ bot_input = dialogpt_tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
 
 
 
 
85
  chat_history_ids = model.generate(bot_input, max_length=1000, pad_token_id=tokenizer.eos_token_id)
86
  response = tokenizer.decode(chat_history_ids[:, bot_input.shape[-1]:][0], skip_special_tokens=True)
87
  #return response
 
103
 
104
  return response #, history
105
 
106
+ def chatbot_interface(user_input, table=gr.inputs.Textbox()):
107
+ global conversation_history
108
+
109
+ conversation_history.append(user_input)
110
 
111
+ dialog_prompt = "User: " + " ".join(conversation_history) + "\nBot:"
112
+ response = predict(dialog_prompt, conversation_history)
113
+ conversation_history.append(response)
114
+ return "Bot (DialoGPT): " + response
115
+
116
  def sqlquery(input):
117
 
118
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
 
123
 
124
 
125
  chat_interface = gr.Interface(
126
+ fn=chatbot_interface, inputs=["text", "text"], outputs="text", live=True
127
+ '''
128
  fn=predict,
129
  theme="default",
130
  css=".footer {display:none !important}",
 
132
  outputs="text", #["chatbot", "state"],
133
  title="ST Chatbot",
134
  description="Type your message in the box above, and the chatbot will respond.",
135
+ '''
136
  )
137
 
138
  sql_interface = gr.Interface(