RacoGPT / app.py
Francesco26061993's picture
Dark theme
1efb174
raw
history blame
5.84 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from streamlit import session_state as ss
import os
import login # Importa il file login.py che hai creato
st.markdown("""
<style>
/* Cambia il colore del testo all'interno del chat input */
section[data-testid="stTextInput"] input {
color: black !important; /* Cambia 'black' al colore che preferisci, es: #FFFFFF per bianco */
background-color: #F0F2F6 !important; /* Cambia il colore dello sfondo del chat input */
font-size: 16px; /* Cambia la dimensione del testo se necessario */
border-radius: 10px; /* Arrotonda i bordi del chat input */
padding: 10px; /* Aggiungi spazio interno per rendere l'aspetto più pulito */
}
</style>
""", unsafe_allow_html=True)
# Inizializza lo stato di login se non esiste
if "is_logged_in" not in st.session_state:
st.session_state["is_logged_in"] = False
# Mostra la pagina di login solo se l'utente non è loggato
if not st.session_state["is_logged_in"]:
login.login_page() # Mostra la pagina di login e blocca il caricamento della chat
st.stop()
# Recupera le secrets da Hugging Face
model_repo = st.secrets["MODEL_REPO"] # Repository del modello di base
hf_token = st.secrets["HF_TOKEN"] # Token Hugging Face
# Carica il modello di base con caching
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token)
model.config.use_cache = True
return tokenizer, model
# Funzione per generare una risposta in tempo reale con supporto per l'interruzione
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
response_text = ""
response_placeholder = st.empty() # Placeholder per mostrare la risposta progressiva
# Genera un token alla volta e aggiorna il placeholder
for i in range(max_length):
if ss.get("stop_generation", False):
break # Interrompe il ciclo se l'utente ha premuto "stop"
output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
new_token_id = output[:, -1].item()
new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)
response_text += new_token
response_placeholder.markdown(f"**RacoGPT:** {response_text}") # Aggiorna il testo progressivo
# Aggiungi il nuovo token alla sequenza di input
input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1)
# Interrompe se il token generato è <|endoftext|> o eos_token_id
if new_token_id == tokenizer.eos_token_id:
break
# Reimposta lo stato di "stop"
ss.stop_generation = False
return response_text
# Inizializza lo stato della sessione
if 'is_chat_input_disabled' not in ss:
ss.is_chat_input_disabled = False
if 'msg' not in ss:
ss.msg = []
if 'chat_history' not in ss:
ss.chat_history = None
if 'stop_generation' not in ss:
ss.stop_generation = False
# Carica il modello e tokenizer
tokenizer, model = load_model()
# Aggiungi il CSS per personalizzare il colore del testo e lo sfondo del chat input
st.markdown("""
<style>
/* Sfondo e colore del testo del chat input */
section[data-testid="stTextInput"] input {
color: black !important;
background-color: #F0F2F6 !important;
border-radius: 10px !important;
}
/* Sfondo scuro per l'intera pagina */
.main {
background-color: #0A0A1A;
color: #FFFFFF;
}
/* Modifica sfondo e colore dei messaggi di chat */
.stChatMessage div[data-baseweb="block"] {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #FFFFFF !important;
border-radius: 10px !important;
}
</style>
""", unsafe_allow_html=True)
# Mostra la cronologia dei messaggi con le label personalizzate
for message in ss.msg:
if message["role"] == "user":
with st.chat_message("user"):
st.markdown(f"**Tu:** {message['content']}")
elif message["role"] == "RacoGPT":
with st.chat_message("RacoGPT"):
st.markdown(f"**RacoGPT:** {message['content']}")
# Contenitore per gestire la mutua esclusione tra input e pulsante di stop
input_container = st.empty()
if not ss.is_chat_input_disabled:
# Mostra la barra di input per inviare il messaggio
with input_container:
prompt = st.chat_input("Scrivi il tuo messaggio...")
if prompt:
# Salva il messaggio dell'utente e disabilita l'input
ss.msg.append({"role": "user", "content": prompt})
with st.chat_message("user"):
ss.is_chat_input_disabled = True
st.markdown(f"**Tu:** {prompt}")
st.rerun()
else:
# Mostra il pulsante di "Stop Generazione" al posto della barra di input
with input_container:
if st.button("🛑 Stop Generazione", key="stop_button"):
ss.stop_generation = True # Interrompe la generazione impostando il flag
# Genera la risposta del bot con digitazione in tempo reale
with st.spinner("RacoGPT sta generando una risposta..."):
response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)
# Mostra il messaggio finale del bot dopo che la risposta è completata o interrotta
ss.msg.append({"role": "RacoGPT", "content": response})
with st.chat_message("RacoGPT"):
st.markdown(f"**RacoGPT:** {response}")
ss.is_chat_input_disabled = False
# Rerun per aggiornare l'interfaccia
st.rerun()