Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -19,8 +19,15 @@ def generate_response(input_text):
|
|
19 |
# Encode the new user input, add end of string token
|
20 |
new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
|
21 |
|
22 |
-
#
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# Generate a response from the model
|
26 |
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, top_k=50, top_p=0.95, temperature=0.7)
|
@@ -29,7 +36,7 @@ def generate_response(input_text):
|
|
29 |
chat_history_ids = chat_history_ids[:, bot_input_ids.shape[-1]:] # only take the latest generated tokens
|
30 |
bot_output = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
|
31 |
|
32 |
-
# Update session state history with the new tokens
|
33 |
st.session_state['history'] = chat_history_ids[0].tolist()
|
34 |
|
35 |
return bot_output
|
|
|
19 |
# Encode the new user input, add end of string token
|
20 |
new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
|
21 |
|
22 |
+
# If there is conversation history, append the new input to it
|
23 |
+
if st.session_state['history']:
|
24 |
+
# Convert history to a 2D tensor (batch_size x seq_len)
|
25 |
+
history_tensor = torch.tensor(st.session_state['history']).unsqueeze(0).to(device)
|
26 |
+
# Concatenate history with the new input
|
27 |
+
bot_input_ids = torch.cat([history_tensor, new_user_input_ids], dim=-1)
|
28 |
+
else:
|
29 |
+
# If no history, just use the new user input
|
30 |
+
bot_input_ids = new_user_input_ids
|
31 |
|
32 |
# Generate a response from the model
|
33 |
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, top_k=50, top_p=0.95, temperature=0.7)
|
|
|
36 |
chat_history_ids = chat_history_ids[:, bot_input_ids.shape[-1]:] # only take the latest generated tokens
|
37 |
bot_output = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
|
38 |
|
39 |
+
# Update session state history with the new tokens (flattened)
|
40 |
st.session_state['history'] = chat_history_ids[0].tolist()
|
41 |
|
42 |
return bot_output
|