nugentc commited on
Commit
fbd4cfd
·
1 Parent(s): d7c2aca

rearrange history tracking

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -21,11 +21,17 @@ import gradio as gr
21
 
22
 
23
  def chat(message, history=[]):
24
- new_user_input_ids = tokenizer.encode(message+ tokenizer.eos_token, return_tensors='pt')
25
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
26
- history = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
27
- response = tokenizer.decode(history[0]).replace("<|endoftext|>", "")
28
- return history, history, 'haha'
 
 
 
 
 
 
29
 
30
 
31
  def feedback(text):
 
21
 
22
 
23
  def chat(message, history=[]):
24
+ new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
25
+ if len(history) > 0:
26
+ last_set_of_ids = history[len(history)-1][3]
27
+ bot_input_ids = torch.cat([last_set_of_ids, new_user_input_ids], dim=-1)
28
+ else:
29
+ new_user_input_ids
30
+ response_ids = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
31
+ response = tokenizer.decode(response_ids).replace("<|endoftext|>", "")
32
+ bot_input_ids = torch.cat([last_set_of_ids, response_ids], dim=-1)
33
+ history.push((message, response, bot_input_ids))
34
+ return response, history, feedback(message)
35
 
36
 
37
  def feedback(text):