OniXinO commited on
Commit
56843e1
·
1 Parent(s): 75eb9ca

4та спроба

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -13,26 +13,24 @@ st.title("Український Чат-бот")
13
  if "history" not in st.session_state:
14
  st.session_state.history = []
15
 
16
- user_input = st.text_input("Ви:", "")
17
 
18
  tokenizer, model = load_model()
19
 
20
- if st.button("Надіслати") or st.session_state.get("enter_pressed", False):
21
- st.session_state.enter_pressed = False
22
  if user_input:
23
- inputs = tokenizer(st.session_state.history + [user_input], return_tensors="pt")
24
  with torch.no_grad():
25
  outputs = model.generate(**inputs, max_length=100)
26
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
  st.session_state.history.extend([user_input, response])
 
 
 
 
28
 
29
  if st.session_state.history:
30
  for i in range(0, len(st.session_state.history), 2):
31
  st.write(f"Ви: {st.session_state.history[i]}")
32
  if i + 1 < len(st.session_state.history):
33
- st.write(f"Бот: {st.session_state.history[i+1]}")
34
-
35
- def set_enter_pressed():
36
- st.session_state.enter_pressed = True
37
-
38
- st.text_input("Ви:", key="user_input_enter", on_change=set_enter_pressed)
 
13
  if "history" not in st.session_state:
14
  st.session_state.history = []
15
 
16
+ user_input = st.text_input("Ви:", key="user_input_enter")
17
 
18
  tokenizer, model = load_model()
19
 
20
+ def send_message():
 
21
  if user_input:
22
+ inputs = tokenizer(st.session_state.history + [user_input], return_tensors="pt", padding=True, truncation=True)
23
  with torch.no_grad():
24
  outputs = model.generate(**inputs, max_length=100)
25
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
  st.session_state.history.extend([user_input, response])
27
+ st.session_state.user_input_enter = "" # clear the input field after sending
28
+
29
+ if st.button("Надіслати") or st.session_state.get("user_input_enter", "") != "":
30
+ send_message()
31
 
32
  if st.session_state.history:
33
  for i in range(0, len(st.session_state.history), 2):
34
  st.write(f"Ви: {st.session_state.history[i]}")
35
  if i + 1 < len(st.session_state.history):
36
+ st.write(f"Бот: {st.session_state.history[i+1]}")