File size: 1,748 Bytes
a96c72f c585309 75eb9ca a96c72f 6d7b830 e22ba0b c585309 75eb9ca a96c72f 6d7b830 a96c72f 6d7b830 a96c72f 9000ced a96c72f 75eb9ca 56843e1 9000ced 75eb9ca 9000ced 1d45e7e 9000ced e22ba0b 9000ced b434f91 9000ced b434f91 6d7b830 75eb9ca 56843e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
@st.cache_resource
def load_model():
    """Load the tokenizer and seq2seq model for the "google/mt5-small" checkpoint.

    Wrapped in ``st.cache_resource`` so the loaded objects are cached and
    reused across Streamlit reruns instead of being reloaded every time.

    Returns:
        tuple: ``(tokenizer, model)``, with the tokenizer loaded first.
    """
    checkpoint = "google/mt5-small"
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding_side="left",
        use_fast=False,  # slow (sentencepiece-based) tokenizer, as in the original
    )
    return tokenizer, AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
st.title("Український Чат-бот")

# Seed per-session state only when a key is absent, so reruns keep existing values:
#   history    -- alternating user / bot messages
#   user_input -- text staged for sending ("" when nothing is pending)
for _state_key, _initial_value in (("history", []), ("user_input", "")):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

tokenizer, model = load_model()
# How many recent history entries (user and bot messages) are fed back to the
# model as context. Bounding it keeps the prompt from growing with the chat.
CONTEXT_MESSAGES = 6


def send_message():
    """Generate a bot reply to the staged user message and record the exchange.

    Reads ``st.session_state.user_input``; when it is empty, nothing happens.
    Otherwise the most recent history entries and the new message are joined
    into ONE prompt, a reply is generated, both the message and the reply are
    appended to ``st.session_state.history``, and ``user_input`` is reset to "".
    """
    user_text = st.session_state.user_input
    if not user_text:
        return

    # Build a single string, not a list. A list makes the tokenizer build a batch
    # (one sequence per history entry), and outputs[0] would then be the reply
    # to the oldest message instead of the current one.
    context = st.session_state.history[-CONTEXT_MESSAGES:] if CONTEXT_MESSAGES > 0 else []
    prompt = "\n".join([*context, user_text])

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=100)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    st.session_state.history.extend([user_text, response])
    # Mark the message as consumed so the next rerun does not send it again.
    # The text field itself is NOT cleared here: this function runs after
    # st.text_input(key="temp_user_input") is created, and writing that key
    # at this point raises StreamlitAPIException.
    st.session_state.user_input = ""


def update_user_input():
    """on_change callback for the text field: stage its text and clear the field.

    Callbacks run before the script re-creates its widgets, so this is the
    place where the widget's own session-state value may be written.
    """
    st.session_state.user_input = st.session_state.temp_user_input
    st.session_state.temp_user_input = ""


st.text_input("Ви:", key="temp_user_input", on_change=update_user_input)

# Clicking the button reruns the script. Sending is driven only by whether text
# has been staged in user_input, so the button's return value is not needed.
st.button("Надіслати")
if st.session_state.get("user_input", "") != "":
    send_message()

# History alternates user (even index) and bot (odd index) messages.
for index, message in enumerate(st.session_state.history):
    speaker = "Ви" if index % 2 == 0 else "Бот"
    st.write(f"{speaker}: {message}")