hertogateis committed on
Commit
125c9f3
·
verified ·
1 Parent(s): 4f1a79d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -19,8 +19,15 @@ def generate_response(input_text):
19
  # Encode the new user input, add end of string token
20
  new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
21
 
22
- # Append the new user input tokens to the chat history
23
- bot_input_ids = torch.cat([torch.tensor(st.session_state['history']).to(device), new_user_input_ids], dim=-1) if st.session_state['history'] else new_user_input_ids
 
 
 
 
 
 
 
24
 
25
  # Generate a response from the model
26
  chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, top_k=50, top_p=0.95, temperature=0.7)
@@ -29,7 +36,7 @@ def generate_response(input_text):
29
  chat_history_ids = chat_history_ids[:, bot_input_ids.shape[-1]:] # only take the latest generated tokens
30
  bot_output = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
31
 
32
- # Update session state history with the new tokens
33
  st.session_state['history'] = chat_history_ids[0].tolist()
34
 
35
  return bot_output
 
19
  # Encode the new user input, add end of string token
20
  new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
21
 
22
+ # If there is conversation history, append the new input to it
23
+ if st.session_state['history']:
24
+ # Convert history to a 2D tensor (batch_size x seq_len)
25
+ history_tensor = torch.tensor(st.session_state['history']).unsqueeze(0).to(device)
26
+ # Concatenate history with the new input
27
+ bot_input_ids = torch.cat([history_tensor, new_user_input_ids], dim=-1)
28
+ else:
29
+ # If no history, just use the new user input
30
+ bot_input_ids = new_user_input_ids
31
 
32
  # Generate a response from the model
33
  chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, top_k=50, top_p=0.95, temperature=0.7)
 
36
  chat_history_ids = chat_history_ids[:, bot_input_ids.shape[-1]:] # only take the latest generated tokens
37
  bot_output = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
38
 
39
+ # Update session state history with the new tokens (flattened)
40
  st.session_state['history'] = chat_history_ids[0].tolist()
41
 
42
  return bot_output