RacoGPT / app.py
Francesco26061993's picture
Partial bot response on user's stop generation
4efd2d6
raw
history blame
5.12 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from streamlit import session_state as ss
import os
import login # Importa il file login.py che hai creato
# Stile del chat input e sfondo della pagina
st.markdown("""
<style>
section[data-testid="stTextInput"] input {
color: black !important;
background-color: #F0F2F6 !important;
font-size: 16px;
border-radius: 10px;
padding: 10px;
}
.main {
background-color: #0A0A1A;
color: #FFFFFF;
}
.stChatMessage div[data-baseweb="block"] {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #FFFFFF !important;
border-radius: 10px !important;
}
</style>
""", unsafe_allow_html=True)
# Inizializza lo stato di login se non esiste
if "is_logged_in" not in st.session_state:
st.session_state["is_logged_in"] = False
# Mostra la pagina di login solo se l'utente non è loggato
if not st.session_state["is_logged_in"]:
login.login_page()
st.stop()
# Recupera le secrets da Hugging Face
model_repo = st.secrets["MODEL_REPO"]
hf_token = st.secrets["HF_TOKEN"]
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token)
model.config.use_cache = True
return tokenizer, model
# Funzione per generare una risposta in tempo reale con supporto per l'interruzione
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
response_text = ""
response_placeholder = st.empty()
# Genera un token alla volta e aggiorna il testo in response_text
for i in range(max_length):
if ss.get("stop_generation", False):
break # Interrompe il ciclo se l'utente ha premuto "stop"
output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
new_token_id = output[:, -1].item()
new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)
response_text += new_token
response_placeholder.markdown(f"RacoGPT: {response_text}", unsafe_allow_html=True)
# Salva il testo parziale in session_state per preservarlo in caso di interruzione
ss["response_text_partial"] = response_text
# Aggiungi il nuovo token alla sequenza di input
input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1)
# Interrompe se il token generato è <|endoftext|> o eos_token_id
if new_token_id == tokenizer.eos_token_id:
break
# Reimposta lo stato di "stop"
ss.stop_generation = False
return response_text
# Inizializza lo stato della sessione
if 'is_chat_input_disabled' not in ss:
ss.is_chat_input_disabled = False
if 'msg' not in ss:
ss.msg = []
if 'chat_history' not in ss:
ss.chat_history = None
if 'stop_generation' not in ss:
ss.stop_generation = False
# Carica il modello e tokenizer
tokenizer, model = load_model()
# Mostra la cronologia dei messaggi con le label personalizzate
for message in ss.msg:
if message["role"] == "user":
with st.chat_message("user"):
st.markdown(f"Tu: {message['content']}")
elif message["role"] == "RacoGPT":
with st.chat_message("RacoGPT"):
st.markdown(f"RacoGPT: {message['content']}")
# Contenitore per gestire la mutua esclusione tra input e pulsante di stop
input_container = st.empty()
if not ss.is_chat_input_disabled:
# Mostra la barra di input per inviare il messaggio
with input_container:
prompt = st.chat_input("Scrivi il tuo messaggio...")
if prompt:
ss.msg.append({"role": "user", "content": prompt})
with st.chat_message("user"):
ss.is_chat_input_disabled = True
st.markdown(f"Tu: {prompt}")
st.rerun()
else:
# Mostra il pulsante di "Stop Generazione" al posto della barra di input
with input_container:
if st.button("🛑 Stop Generazione", key="stop_button"):
ss.stop_generation = True # Interrompe la generazione impostando il flag
# Genera la risposta del bot con digitazione in tempo reale
with st.spinner("RacoGPT sta generando una risposta..."):
response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)
# Usa il testo parziale se presente
final_response = response or ss.get("response_text_partial", "")
# Aggiungi la risposta finale nella cronologia dei messaggi
ss.msg.append({"role": "RacoGPT", "content": final_response})
with st.chat_message("RacoGPT"):
st.markdown(f"RacoGPT: {final_response}")
# Pulisce il testo parziale dalla sessione e riabilita l'input
ss.pop("response_text_partial", None)
ss.is_chat_input_disabled = False
# Rerun per aggiornare l'interfaccia
st.rerun()