File size: 1,748 Bytes
a96c72f
c585309
75eb9ca
a96c72f
6d7b830
 
e22ba0b
c585309
75eb9ca
a96c72f
6d7b830
a96c72f
6d7b830
 
a96c72f
9000ced
 
a96c72f
75eb9ca
 
56843e1
9000ced
 
75eb9ca
 
 
9000ced
1d45e7e
 
9000ced
 
e22ba0b
9000ced
b434f91
9000ced
b434f91
 
 
6d7b830
 
75eb9ca
 
 
56843e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

@st.cache_resource
def load_model():
    """Return the ``(tokenizer, model)`` pair for the google/mt5-small checkpoint.

    Decorated with ``st.cache_resource``, so the objects are created on the
    first call and the same instances are handed back on later calls and
    script reruns instead of being reloaded every time the script executes.

    NOTE(review): google/mt5-small is published as a pre-trained checkpoint
    without task-specific fine-tuning, so its raw generations may not be
    usable chat replies — confirm this is the intended model.
    """
    checkpoint = "google/mt5-small"
    # use_fast=False selects the Python ("slow") tokenizer implementation;
    # padding, when applied, goes on the left of each sequence.
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding_side="left",
        use_fast=False,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tokenizer, model

st.title("Український Чат-бот")

# Seed per-session state on a session's first run only; later reruns
# (triggered by widget interaction) keep whatever is already stored.
for _state_key, _make_default in (("history", list), ("user_input", str)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# load_model is cached by st.cache_resource, so reruns reuse the loaded objects.
tokenizer, model = load_model()

# How many earlier history entries (user and bot messages together) are fed
# back to the model as context. Bounding the prompt keeps it short, so the
# tokenizer's truncation (right-side by default) is unlikely to cut off the
# newest message, and the encoder input does not grow with the conversation.
MAX_CONTEXT_MESSAGES = 10


def _build_prompt(history, message, max_context=MAX_CONTEXT_MESSAGES):
    """Join the most recent *history* entries and *message* into one string.

    A single string makes the tokenizer encode ONE sequence (batch size 1),
    so ``outputs[0]`` from ``generate`` is the reply to this turn. Passing a
    list would encode every entry as a separate batch row instead.

    Args:
        history: Earlier messages, alternating user / bot.
        message: The new user message.
        max_context: Maximum number of trailing history entries to include;
            ``0`` or less means "no history".

    Returns:
        The newline-separated prompt text.
    """
    # Guard needed because history[-0:] would return the WHOLE list.
    recent = history[-max_context:] if max_context > 0 else []
    return "\n".join([*recent, message])


def update_user_input():
    """Mirror the text box value (key ``temp_user_input``) into ``user_input``."""
    st.session_state.user_input = st.session_state.get("temp_user_input", "")


def send_message():
    """Answer the pending user message and append both turns to ``history``.

    Must run as a widget callback (``on_change`` / ``on_click``). Streamlit
    forbids assigning a widget-bound key such as ``temp_user_input`` once the
    widget has been created in the current run. Callbacks execute before the
    script body reruns, so clearing the text box here is allowed. Calling this
    from the script body after ``st.text_input`` raises StreamlitAPIException.
    """
    update_user_input()
    message = st.session_state.user_input
    if not message.strip():
        # Nothing to send: empty/whitespace box, or the message was already
        # consumed by another callback in the same rerun.
        st.session_state.user_input = ""
        return

    prompt = _build_prompt(st.session_state.history, message)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=100)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    st.session_state.history.extend([message, response])
    st.session_state.user_input = ""  # clear the stored user input
    st.session_state.temp_user_input = ""  # clear the text box (legal inside a callback)


# Pressing Enter in the box and clicking the button both route through the
# same callback. Whichever fires first consumes and clears the message, so
# the other sees an empty box and does nothing.
st.text_input("Ви:", key="temp_user_input", on_change=send_message)
st.button("Надіслати", on_click=send_message)

# History alternates user / bot entries: even indices are the user's.
for turn_index, text in enumerate(st.session_state.history):
    speaker = "Ви" if turn_index % 2 == 0 else "Бот"
    st.write(f"{speaker}: {text}")