File size: 1,858 Bytes
a96c72f
c585309
75eb9ca
a96c72f
6d7b830
 
e22ba0b
c585309
75eb9ca
a96c72f
6d7b830
a96c72f
6d7b830
 
a96c72f
9000ced
 
a96c72f
75eb9ca
 
56843e1
9000ced
 
75eb9ca
 
 
9000ced
14e602c
 
9000ced
 
e22ba0b
9000ced
14e602c
 
9417eab
 
14e602c
 
 
6d7b830
 
75eb9ca
 
 
56843e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

@st.cache_resource
def load_model():
    """Load the mT5-small tokenizer and seq2seq model.

    Wrapped in ``st.cache_resource`` so the loaded objects are reused
    instead of being re-created each time the script reruns.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``google/mt5-small``
        checkpoint.
    """
    checkpoint = "google/mt5-small"
    return (
        AutoTokenizer.from_pretrained(checkpoint, padding_side="left", use_fast=False),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )

st.title("Український Чат-бот")

# Seed the per-session state on the first run only: the flat chat
# transcript and the most recently committed text-box value.
for state_key, initial_value in (("history", []), ("user_input", "")):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

tokenizer, model = load_model()

def send_message():
    """Generate a bot reply to the pending user message and record the exchange.

    Reads the pending message from ``st.session_state.user_input`` and does
    nothing when it is empty. Otherwise the prior history plus the new
    message are encoded as a single input sequence, a reply is generated,
    and both the message and the reply are appended to
    ``st.session_state.history``. Finally the stored input and the text
    widget value (``temp_user_input``) are reset to empty strings.
    """
    user_message = st.session_state.user_input
    if not user_message:
        return

    # Encode the whole conversation as ONE sequence. Handing the tokenizer a
    # list of strings builds a batch with one row per message, in which case
    # ``outputs[0]`` is the generation for the oldest history entry rather
    # than a reply to the message that was just sent.
    # NOTE(review): truncation presumably trims from the end of the sequence
    # by default, which would drop the newest message once the history grows
    # past the tokenizer limit -- consider capping the history length.
    prompt = " ".join([*st.session_state.history, user_message])
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=100)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    st.session_state.history.extend([user_message, response])
    st.session_state.user_input = ""  # Clear the stored input
    st.session_state.temp_user_input = ""  # Clear the text field

def update_user_input():
    """Copy the text widget's value into ``st.session_state.user_input``."""
    typed_text = st.session_state.temp_user_input
    st.session_state.user_input = typed_text

# Callback that handles a click on the "Надіслати" (Send) button.
def on_send_button_click():
    """Handle a send-button click by delegating to ``send_message``."""
    send_message()

# The text box commits its value through ``on_change``; the message itself is
# sent by the button below. ``st.text_input`` has no ``on_submit`` parameter,
# so passing one raises TypeError as soon as the script runs.
st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)

st.button("Надіслати", on_click=on_send_button_click)

# Render the transcript in order: entries at even positions are user
# messages, entries at odd positions are bot replies.
for position, text in enumerate(st.session_state.history):
    if position % 2 == 0:
        st.write(f"Ви: {text}")
    else:
        st.write(f"Бот: {text}")