teaevo commited on
Commit
f6a2d88
·
1 Parent(s): 04883bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -35,8 +35,6 @@ data = {
35
  }
36
  table = pd.DataFrame.from_dict(data)
37
 
38
- step = 0
39
-
40
  def chatbot_response(user_message):
41
  # Generate chatbot response using the chatbot model
42
  #inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
@@ -47,7 +45,10 @@ def chatbot_response(user_message):
47
  new_user_input_ids = chatbot_tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
48
 
49
  # append the new user input tokens to the chat history
50
- bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
 
 
 
51
 
52
  # generated a response while limiting the total chat history to 1000 tokens,
53
  chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
@@ -55,7 +56,6 @@ def chatbot_response(user_message):
55
  # pretty print last ouput tokens from bot
56
  response = "DialoGPT: {}".format(chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
57
 
58
- step += 1
59
  return response
60
 
61
  def sql_response(user_query):
 
35
  }
36
  table = pd.DataFrame.from_dict(data)
37
 
 
 
38
  def chatbot_response(user_message):
39
  # Generate chatbot response using the chatbot model
40
  #inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
 
45
  new_user_input_ids = chatbot_tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
46
 
47
  # append the new user input tokens to the chat history
48
+ if bot_input_ids is None:
49
+ bot_input_ids = new_user_input_ids
50
+ else:
51
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
52
 
53
  # generated a response while limiting the total chat history to 1000 tokens,
54
  chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
 
56
  # pretty print last ouput tokens from bot
57
  response = "DialoGPT: {}".format(chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
58
 
 
59
  return response
60
 
61
  def sql_response(user_query):