Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from streamlit import session_state as ss | |
import os | |
import login # Importa il file login.py che hai creato | |
# Stile del chat input e sfondo della pagina | |
st.markdown(""" | |
<style> | |
section[data-testid="stTextInput"] input { | |
color: black !important; | |
background-color: #F0F2F6 !important; | |
font-size: 16px; | |
border-radius: 10px; | |
padding: 10px; | |
} | |
.main { | |
background-color: #0A0A1A; | |
color: #FFFFFF; | |
} | |
.stChatMessage div[data-baseweb="block"] { | |
background-color: rgba(255, 255, 255, 0.1) !important; | |
color: #FFFFFF !important; | |
border-radius: 10px !important; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Inizializza lo stato di login se non esiste | |
if "is_logged_in" not in st.session_state: | |
st.session_state["is_logged_in"] = False | |
# Mostra la pagina di login solo se l'utente non è loggato | |
if not st.session_state["is_logged_in"]: | |
login.login_page() | |
st.stop() | |
# Recupera le secrets da Hugging Face | |
model_repo = st.secrets["MODEL_REPO"] | |
hf_token = st.secrets["HF_TOKEN"] | |
def load_model(): | |
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token) | |
model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token) | |
model.config.use_cache = True | |
return tokenizer, model | |
# Funzione per generare una risposta in tempo reale con supporto per l'interruzione | |
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512): | |
eos_token = tokenizer.eos_token if tokenizer.eos_token else "" | |
input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt") | |
response_text = "" | |
response_placeholder = st.empty() | |
# Genera un token alla volta e aggiorna il testo in response_text | |
for i in range(max_length): | |
if ss.get("stop_generation", False): | |
break # Interrompe il ciclo se l'utente ha premuto "stop" | |
output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True) | |
new_token_id = output[:, -1].item() | |
new_token = tokenizer.decode([new_token_id], skip_special_tokens=True) | |
response_text += new_token | |
response_placeholder.markdown(f"RacoGPT: {response_text}", unsafe_allow_html=True) | |
# Salva il testo parziale in session_state per preservarlo in caso di interruzione | |
ss["response_text_partial"] = response_text | |
# Aggiungi il nuovo token alla sequenza di input | |
input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1) | |
# Interrompe se il token generato è <|endoftext|> o eos_token_id | |
if new_token_id == tokenizer.eos_token_id: | |
break | |
# Reimposta lo stato di "stop" | |
ss.stop_generation = False | |
return response_text | |
# Inizializza lo stato della sessione | |
if 'is_chat_input_disabled' not in ss: | |
ss.is_chat_input_disabled = False | |
if 'msg' not in ss: | |
ss.msg = [] | |
if 'chat_history' not in ss: | |
ss.chat_history = None | |
if 'stop_generation' not in ss: | |
ss.stop_generation = False | |
# Carica il modello e tokenizer | |
tokenizer, model = load_model() | |
# Mostra la cronologia dei messaggi con le label personalizzate | |
for message in ss.msg: | |
if message["role"] == "user": | |
with st.chat_message("user"): | |
st.markdown(f"Tu: {message['content']}") | |
elif message["role"] == "RacoGPT": | |
with st.chat_message("RacoGPT"): | |
st.markdown(f"RacoGPT: {message['content']}") | |
# Contenitore per gestire la mutua esclusione tra input e pulsante di stop | |
input_container = st.empty() | |
if not ss.is_chat_input_disabled: | |
# Mostra la barra di input per inviare il messaggio | |
with input_container: | |
prompt = st.chat_input("Scrivi il tuo messaggio...") | |
if prompt: | |
ss.msg.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
ss.is_chat_input_disabled = True | |
st.markdown(f"Tu: {prompt}") | |
st.rerun() | |
else: | |
# Mostra il pulsante di "Stop Generazione" al posto della barra di input | |
with input_container: | |
if st.button("🛑 Stop Generazione", key="stop_button"): | |
ss.stop_generation = True # Interrompe la generazione impostando il flag | |
# Genera la risposta del bot con digitazione in tempo reale | |
with st.spinner("RacoGPT sta generando una risposta..."): | |
response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model) | |
# Usa il testo parziale se presente | |
final_response = response or ss.get("response_text_partial", "") | |
# Aggiungi la risposta finale nella cronologia dei messaggi | |
ss.msg.append({"role": "RacoGPT", "content": final_response}) | |
with st.chat_message("RacoGPT"): | |
st.markdown(f"RacoGPT: {final_response}") | |
# Pulisce il testo parziale dalla sessione e riabilita l'input | |
ss.pop("response_text_partial", None) | |
ss.is_chat_input_disabled = False | |
# Rerun per aggiornare l'interfaccia | |
st.rerun() |