Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from streamlit import session_state as ss | |
import os | |
import login # Importa il file login.py che hai creato | |
st.markdown(""" | |
<style> | |
/* Cambia il colore del testo all'interno del chat input */ | |
section[data-testid="stTextInput"] input { | |
color: black !important; /* Cambia 'black' al colore che preferisci, es: #FFFFFF per bianco */ | |
background-color: #F0F2F6 !important; /* Cambia il colore dello sfondo del chat input */ | |
font-size: 16px; /* Cambia la dimensione del testo se necessario */ | |
border-radius: 10px; /* Arrotonda i bordi del chat input */ | |
padding: 10px; /* Aggiungi spazio interno per rendere l'aspetto più pulito */ | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Inizializza lo stato di login se non esiste | |
if "is_logged_in" not in st.session_state: | |
st.session_state["is_logged_in"] = False | |
# Mostra la pagina di login solo se l'utente non è loggato | |
if not st.session_state["is_logged_in"]: | |
login.login_page() # Mostra la pagina di login e blocca il caricamento della chat | |
st.stop() | |
# Recupera le secrets da Hugging Face | |
model_repo = st.secrets["MODEL_REPO"] # Repository del modello di base | |
hf_token = st.secrets["HF_TOKEN"] # Token Hugging Face | |
# Carica il modello di base con caching | |
def load_model(): | |
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token) | |
model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token) | |
model.config.use_cache = True | |
return tokenizer, model | |
# Funzione per generare una risposta in tempo reale con supporto per l'interruzione | |
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512): | |
eos_token = tokenizer.eos_token if tokenizer.eos_token else "" | |
input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt") | |
response_text = "" | |
response_placeholder = st.empty() # Placeholder per mostrare la risposta progressiva | |
# Genera un token alla volta e aggiorna il placeholder | |
for i in range(max_length): | |
if ss.get("stop_generation", False): | |
break # Interrompe il ciclo se l'utente ha premuto "stop" | |
output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True) | |
new_token_id = output[:, -1].item() | |
new_token = tokenizer.decode([new_token_id], skip_special_tokens=True) | |
response_text += new_token | |
response_placeholder.markdown(f"**RacoGPT:** {response_text}") # Aggiorna il testo progressivo | |
# Aggiungi il nuovo token alla sequenza di input | |
input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1) | |
# Interrompe se il token generato è <|endoftext|> o eos_token_id | |
if new_token_id == tokenizer.eos_token_id: | |
break | |
# Reimposta lo stato di "stop" | |
ss.stop_generation = False | |
return response_text | |
# Inizializza lo stato della sessione | |
if 'is_chat_input_disabled' not in ss: | |
ss.is_chat_input_disabled = False | |
if 'msg' not in ss: | |
ss.msg = [] | |
if 'chat_history' not in ss: | |
ss.chat_history = None | |
if 'stop_generation' not in ss: | |
ss.stop_generation = False | |
# Carica il modello e tokenizer | |
tokenizer, model = load_model() | |
# Aggiungi il CSS per personalizzare il colore del testo e lo sfondo del chat input | |
st.markdown(""" | |
<style> | |
/* Sfondo e colore del testo del chat input */ | |
section[data-testid="stTextInput"] input { | |
color: black !important; | |
background-color: #F0F2F6 !important; | |
border-radius: 10px !important; | |
} | |
/* Sfondo scuro per l'intera pagina */ | |
.main { | |
background-color: #0A0A1A; | |
color: #FFFFFF; | |
} | |
/* Modifica sfondo e colore dei messaggi di chat */ | |
.stChatMessage div[data-baseweb="block"] { | |
background-color: rgba(255, 255, 255, 0.1) !important; | |
color: #FFFFFF !important; | |
border-radius: 10px !important; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Mostra la cronologia dei messaggi con le label personalizzate | |
for message in ss.msg: | |
if message["role"] == "user": | |
with st.chat_message("user"): | |
st.markdown(f"**Tu:** {message['content']}") | |
elif message["role"] == "RacoGPT": | |
with st.chat_message("RacoGPT"): | |
st.markdown(f"**RacoGPT:** {message['content']}") | |
# Contenitore per gestire la mutua esclusione tra input e pulsante di stop | |
input_container = st.empty() | |
if not ss.is_chat_input_disabled: | |
# Mostra la barra di input per inviare il messaggio | |
with input_container: | |
prompt = st.chat_input("Scrivi il tuo messaggio...") | |
if prompt: | |
# Salva il messaggio dell'utente e disabilita l'input | |
ss.msg.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
ss.is_chat_input_disabled = True | |
st.markdown(f"**Tu:** {prompt}") | |
st.rerun() | |
else: | |
# Mostra il pulsante di "Stop Generazione" al posto della barra di input | |
with input_container: | |
if st.button("🛑 Stop Generazione", key="stop_button"): | |
ss.stop_generation = True # Interrompe la generazione impostando il flag | |
# Genera la risposta del bot con digitazione in tempo reale | |
with st.spinner("RacoGPT sta generando una risposta..."): | |
response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model) | |
# Mostra il messaggio finale del bot dopo che la risposta è completata o interrotta | |
ss.msg.append({"role": "RacoGPT", "content": response}) | |
with st.chat_message("RacoGPT"): | |
st.markdown(f"**RacoGPT:** {response}") | |
ss.is_chat_input_disabled = False | |
# Rerun per aggiornare l'interfaccia | |
st.rerun() |