Spaces:
Sleeping
Sleeping
File size: 5,121 Bytes
47793bc 4efd2d6 1efb174 4efd2d6 1efb174 47793bc 4efd2d6 47793bc 4efd2d6 47793bc daa2868 47793bc 4efd2d6 47793bc 4efd2d6 47793bc daa2868 47793bc 4efd2d6 47793bc daa2868 47793bc daa2868 47793bc 4efd2d6 47793bc 4efd2d6 47793bc daa2868 47793bc daa2868 47793bc 4efd2d6 daa2868 47793bc daa2868 47793bc 4efd2d6 47793bc 4efd2d6 47793bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from streamlit import session_state as ss
import os
import login # Importa il file login.py che hai creato
# Stile del chat input e sfondo della pagina
st.markdown("""
<style>
section[data-testid="stTextInput"] input {
color: black !important;
background-color: #F0F2F6 !important;
font-size: 16px;
border-radius: 10px;
padding: 10px;
}
.main {
background-color: #0A0A1A;
color: #FFFFFF;
}
.stChatMessage div[data-baseweb="block"] {
background-color: rgba(255, 255, 255, 0.1) !important;
color: #FFFFFF !important;
border-radius: 10px !important;
}
</style>
""", unsafe_allow_html=True)
# Inizializza lo stato di login se non esiste
if "is_logged_in" not in st.session_state:
st.session_state["is_logged_in"] = False
# Mostra la pagina di login solo se l'utente non è loggato
if not st.session_state["is_logged_in"]:
login.login_page()
st.stop()
# Recupera le secrets da Hugging Face
model_repo = st.secrets["MODEL_REPO"]
hf_token = st.secrets["HF_TOKEN"]
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token)
model.config.use_cache = True
return tokenizer, model
# Funzione per generare una risposta in tempo reale con supporto per l'interruzione
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
response_text = ""
response_placeholder = st.empty()
# Genera un token alla volta e aggiorna il testo in response_text
for i in range(max_length):
if ss.get("stop_generation", False):
break # Interrompe il ciclo se l'utente ha premuto "stop"
output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
new_token_id = output[:, -1].item()
new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)
response_text += new_token
response_placeholder.markdown(f"RacoGPT: {response_text}", unsafe_allow_html=True)
# Salva il testo parziale in session_state per preservarlo in caso di interruzione
ss["response_text_partial"] = response_text
# Aggiungi il nuovo token alla sequenza di input
input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1)
# Interrompe se il token generato è <|endoftext|> o eos_token_id
if new_token_id == tokenizer.eos_token_id:
break
# Reimposta lo stato di "stop"
ss.stop_generation = False
return response_text
# Inizializza lo stato della sessione
if 'is_chat_input_disabled' not in ss:
ss.is_chat_input_disabled = False
if 'msg' not in ss:
ss.msg = []
if 'chat_history' not in ss:
ss.chat_history = None
if 'stop_generation' not in ss:
ss.stop_generation = False
# Carica il modello e tokenizer
tokenizer, model = load_model()
# Mostra la cronologia dei messaggi con le label personalizzate
for message in ss.msg:
if message["role"] == "user":
with st.chat_message("user"):
st.markdown(f"Tu: {message['content']}")
elif message["role"] == "RacoGPT":
with st.chat_message("RacoGPT"):
st.markdown(f"RacoGPT: {message['content']}")
# Contenitore per gestire la mutua esclusione tra input e pulsante di stop
input_container = st.empty()
if not ss.is_chat_input_disabled:
# Mostra la barra di input per inviare il messaggio
with input_container:
prompt = st.chat_input("Scrivi il tuo messaggio...")
if prompt:
ss.msg.append({"role": "user", "content": prompt})
with st.chat_message("user"):
ss.is_chat_input_disabled = True
st.markdown(f"Tu: {prompt}")
st.rerun()
else:
# Mostra il pulsante di "Stop Generazione" al posto della barra di input
with input_container:
if st.button("🛑 Stop Generazione", key="stop_button"):
ss.stop_generation = True # Interrompe la generazione impostando il flag
# Genera la risposta del bot con digitazione in tempo reale
with st.spinner("RacoGPT sta generando una risposta..."):
response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)
# Usa il testo parziale se presente
final_response = response or ss.get("response_text_partial", "")
# Aggiungi la risposta finale nella cronologia dei messaggi
ss.msg.append({"role": "RacoGPT", "content": final_response})
with st.chat_message("RacoGPT"):
st.markdown(f"RacoGPT: {final_response}")
# Pulisce il testo parziale dalla sessione e riabilita l'input
ss.pop("response_text_partial", None)
ss.is_chat_input_disabled = False
# Rerun per aggiornare l'interfaccia
st.rerun() |