OniXinO
Перенесено st.session_state.temp_user_input = "" всередину функції send_message, після того як повідомлення було успішно оброблено та додано до історії. Тепер очищення текстового поля відбувається через оновлення st.session_state.temp_user_input у контексті функції, викликаної дією користувача (натисканням кнопки або зміною тексту).
1d45e7e
import streamlit as st | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import torch | |
def load_model(): | |
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small", padding_side="left", use_fast=False) | |
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small") | |
return tokenizer, model | |
st.title("Український Чат-бот") | |
if "history" not in st.session_state: | |
st.session_state.history = [] | |
if "user_input" not in st.session_state: | |
st.session_state.user_input = "" | |
tokenizer, model = load_model() | |
def send_message(): | |
if st.session_state.user_input: | |
inputs = tokenizer(st.session_state.history + [st.session_state.user_input], return_tensors="pt", padding=True, truncation=True) | |
with torch.no_grad(): | |
outputs = model.generate(**inputs, max_length=100) | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
st.session_state.history.extend([st.session_state.user_input, response]) | |
st.session_state.user_input = "" # clear the stored user input | |
st.session_state.temp_user_input = "" # clear the text input field | |
def update_user_input(): | |
st.session_state.user_input = st.session_state.temp_user_input | |
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input) | |
if st.button("Надіслати") or st.session_state.get("user_input", "") != "": | |
if st.session_state.get("user_input", "") != "": | |
send_message() | |
if st.session_state.history: | |
for i in range(0, len(st.session_state.history), 2): | |
st.write(f"Ви: {st.session_state.history[i]}") | |
if i + 1 < len(st.session_state.history): | |
st.write(f"Бот: {st.session_state.history[i+1]}") |