Spaces:
Sleeping
Sleeping
File size: 5,839 Bytes
47793bc 1efb174 47793bc daa2868 47793bc daa2868 47793bc daa2868 47793bc daa2868 47793bc daa2868 47793bc 1efb174 47793bc daa2868 47793bc daa2868 47793bc daa2868 47793bc daa2868 47793bc daa2868 47793bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from streamlit import session_state as ss
import os
import login # Importa il file login.py che hai creato
st.markdown("""
<style>
/* Cambia il colore del testo all'interno del chat input */
section[data-testid="stTextInput"] input {
color: black !important; /* Cambia 'black' al colore che preferisci, es: #FFFFFF per bianco */
background-color: #F0F2F6 !important; /* Cambia il colore dello sfondo del chat input */
font-size: 16px; /* Cambia la dimensione del testo se necessario */
border-radius: 10px; /* Arrotonda i bordi del chat input */
padding: 10px; /* Aggiungi spazio interno per rendere l'aspetto più pulito */
}
</style>
""", unsafe_allow_html=True)
# Inizializza lo stato di login se non esiste
if "is_logged_in" not in st.session_state:
st.session_state["is_logged_in"] = False
# Mostra la pagina di login solo se l'utente non è loggato
if not st.session_state["is_logged_in"]:
login.login_page() # Mostra la pagina di login e blocca il caricamento della chat
st.stop()
# Recupera le secrets da Hugging Face
model_repo = st.secrets["MODEL_REPO"] # Repository del modello di base
hf_token = st.secrets["HF_TOKEN"] # Token Hugging Face
# Carica il modello di base con caching
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token)
model.config.use_cache = True
return tokenizer, model
# Funzione per generare una risposta in tempo reale con supporto per l'interruzione
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
response_text = ""
response_placeholder = st.empty() # Placeholder per mostrare la risposta progressiva
# Genera un token alla volta e aggiorna il placeholder
for i in range(max_length):
if ss.get("stop_generation", False):
break # Interrompe il ciclo se l'utente ha premuto "stop"
output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
new_token_id = output[:, -1].item()
new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)
response_text += new_token
response_placeholder.markdown(f"**RacoGPT:** {response_text}") # Aggiorna il testo progressivo
# Aggiungi il nuovo token alla sequenza di input
input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1)
# Interrompe se il token generato è <|endoftext|> o eos_token_id
if new_token_id == tokenizer.eos_token_id:
break
# Reimposta lo stato di "stop"
ss.stop_generation = False
return response_text
# Inizializza lo stato della sessione
if 'is_chat_input_disabled' not in ss:
ss.is_chat_input_disabled = False
if 'msg' not in ss:
ss.msg = []
if 'chat_history' not in ss:
ss.chat_history = None
if 'stop_generation' not in ss:
ss.stop_generation = False
# Carica il modello e tokenizer
tokenizer, model = load_model()
# Aggiungi il CSS per personalizzare il colore del testo e lo sfondo del chat input
st.markdown("""
<style>
/* Sfondo e colore del testo del chat input */
section[data-testid="stTextInput"] input {
color: black !important;
background-color: #F0F2F6 !important;
border-radius: 10px !important;
}
/* Sfondo scuro per l'intera pagina */
.main {
background-color: #0A0A1A;
color: #FFFFFF;
}
/* Modifica sfondo e colore dei messaggi di chat */
.stChatMessage div[data-baseweb="block"] {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #FFFFFF !important;
border-radius: 10px !important;
}
</style>
""", unsafe_allow_html=True)
# Mostra la cronologia dei messaggi con le label personalizzate
for message in ss.msg:
if message["role"] == "user":
with st.chat_message("user"):
st.markdown(f"**Tu:** {message['content']}")
elif message["role"] == "RacoGPT":
with st.chat_message("RacoGPT"):
st.markdown(f"**RacoGPT:** {message['content']}")
# Contenitore per gestire la mutua esclusione tra input e pulsante di stop
input_container = st.empty()
if not ss.is_chat_input_disabled:
# Mostra la barra di input per inviare il messaggio
with input_container:
prompt = st.chat_input("Scrivi il tuo messaggio...")
if prompt:
# Salva il messaggio dell'utente e disabilita l'input
ss.msg.append({"role": "user", "content": prompt})
with st.chat_message("user"):
ss.is_chat_input_disabled = True
st.markdown(f"**Tu:** {prompt}")
st.rerun()
else:
# Mostra il pulsante di "Stop Generazione" al posto della barra di input
with input_container:
if st.button("🛑 Stop Generazione", key="stop_button"):
ss.stop_generation = True # Interrompe la generazione impostando il flag
# Genera la risposta del bot con digitazione in tempo reale
with st.spinner("RacoGPT sta generando una risposta..."):
response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)
# Mostra il messaggio finale del bot dopo che la risposta è completata o interrotta
ss.msg.append({"role": "RacoGPT", "content": response})
with st.chat_message("RacoGPT"):
st.markdown(f"**RacoGPT:** {response}")
ss.is_chat_input_disabled = False
# Rerun per aggiornare l'interfaccia
st.rerun() |