File size: 1,858 Bytes
a96c72f c585309 75eb9ca a96c72f 6d7b830 e22ba0b c585309 75eb9ca a96c72f 6d7b830 a96c72f 6d7b830 a96c72f 9000ced a96c72f 75eb9ca 56843e1 9000ced 75eb9ca 9000ced 14e602c 9000ced e22ba0b 9000ced 14e602c 9417eab 14e602c 6d7b830 75eb9ca 56843e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
@st.cache_resource
def load_model():
    """Load and return the ``(tokenizer, model)`` pair for google/mt5-small.

    Decorated with ``st.cache_resource`` so the loading work is done once and
    the same objects are reused on later Streamlit reruns.

    NOTE(review): mT5 checkpoints are published as pre-trained only (no
    dialogue fine-tuning); confirm this is the intended model for a chatbot.
    """
    model_name = "google/mt5-small"
    tok = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="left",
        use_fast=False,
    )
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tok, seq2seq
st.title("Український Чат-бот")

# Per-session state, created only if missing so it survives reruns:
#   history    - alternating user/bot messages shown in the chat log
#   user_input - the pending message waiting to be sent
for _state_key, _make_initial in (("history", list), ("user_input", str)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_initial()

tokenizer, model = load_model()
def send_message(max_context_messages: int = 6, max_input_tokens: int = 512):
    """Generate a bot reply to the pending user message and record the turn.

    Reads the pending text from ``st.session_state.user_input``. If it is
    empty or whitespace-only, nothing happens. Otherwise the most recent
    history entries plus the new message are joined into a single prompt,
    the model generates a reply, and both the (stripped) user message and
    the reply are appended to ``st.session_state.history``. Finally the
    stored input and the text field (``temp_user_input``) are cleared.

    Args:
        max_context_messages: Number of trailing history entries (user and
            bot messages alike) to include as context; ``0`` sends only the
            new message.
        max_input_tokens: Positive upper bound on the encoder input length.
            When the prompt is longer, the OLDEST tokens are dropped so the
            new message is always kept.
    """
    text = st.session_state.user_input.strip()
    if not text:
        return

    history = st.session_state.history
    # Join the context into ONE string. Passing a list to the tokenizer
    # encodes a batch of separate sequences, and outputs[0] would then be
    # the reply to the oldest history entry instead of to the new message.
    recent = history[-max_context_messages:] if max_context_messages > 0 else []
    prompt = "\n".join(recent + [text])

    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep the tail of the sequence: the new message (and the final special
    # token appended by the tokenizer) sit at the end.
    input_ids = encoded["input_ids"][:, -max_input_tokens:]
    attention_mask = encoded["attention_mask"][:, -max_input_tokens:]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=100,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    history.extend([text, response])
    st.session_state.user_input = ""  # Clear the stored input
    st.session_state.temp_user_input = ""  # Clear the text field
def update_user_input():
    """Text-field callback: mirror the field's value into ``user_input``."""
    state = st.session_state
    state.user_input = state.temp_user_input
# Callback for the "Надіслати" (Send) button.
def on_send_button_click():
    """Send whatever is currently typed in the text field.

    Copies ``temp_user_input`` into ``user_input`` first. A click can land
    in the same rerun as the text field's own ``on_change``, and Streamlit
    does not document the order of those callbacks. Syncing here means the
    send always uses the text the user actually sees.
    """
    update_user_input()
    send_message()
# Input field: when its value is committed (e.g. on Enter), the text is
# copied into st.session_state.user_input; the button below sends it.
# NOTE: st.text_input has no ``on_submit`` parameter — passing one raises
# TypeError on every script run, so only ``on_change`` is used.
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)
st.button("Надіслати", on_click=on_send_button_click)

# Render the conversation; history is stored as user message, bot reply,
# user message, bot reply, ...
chat = st.session_state.history
for i in range(0, len(chat), 2):
    st.write(f"Ви: {chat[i]}")
    # A trailing user message may not have a reply yet.
    if i + 1 < len(chat):
        st.write(f"Бот: {chat[i + 1]}")