File size: 5,839 Bytes
47793bc
 
 
 
 
 
 
1efb174
 
 
 
 
 
 
 
 
 
 
 
 
47793bc
 
 
 
 
 
 
 
 
 
daa2868
 
47793bc
 
 
 
 
 
 
 
 
daa2868
47793bc
 
 
 
 
 
 
 
 
daa2868
 
 
47793bc
 
 
 
 
 
 
 
 
 
 
 
 
 
daa2868
 
47793bc
 
 
 
 
 
 
 
 
 
 
 
daa2868
 
 
47793bc
 
 
1efb174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47793bc
 
 
 
 
 
 
 
 
daa2868
 
 
 
 
 
 
47793bc
daa2868
 
47793bc
 
 
 
daa2868
 
 
 
 
 
47793bc
 
 
daa2868
47793bc
daa2868
47793bc
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from streamlit import session_state as ss
import os
import login  # Importa il file login.py che hai creato

st.markdown("""
    <style>
    /* Cambia il colore del testo all'interno del chat input */
    section[data-testid="stTextInput"] input {
        color: black !important; /* Cambia 'black' al colore che preferisci, es: #FFFFFF per bianco */
        background-color: #F0F2F6 !important; /* Cambia il colore dello sfondo del chat input */
        font-size: 16px; /* Cambia la dimensione del testo se necessario */
        border-radius: 10px; /* Arrotonda i bordi del chat input */
        padding: 10px; /* Aggiungi spazio interno per rendere l'aspetto più pulito */
    }
    </style>
""", unsafe_allow_html=True)

# Inizializza lo stato di login se non esiste
if "is_logged_in" not in st.session_state:
    st.session_state["is_logged_in"] = False

# Mostra la pagina di login solo se l'utente non è loggato
if not st.session_state["is_logged_in"]:
    login.login_page()  # Mostra la pagina di login e blocca il caricamento della chat
    st.stop()

# Recupera le secrets da Hugging Face
model_repo = st.secrets["MODEL_REPO"]  # Repository del modello di base
hf_token = st.secrets["HF_TOKEN"]  # Token Hugging Face

# Carica il modello di base con caching
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token)
    model.config.use_cache = True
    return tokenizer, model

# Funzione per generare una risposta in tempo reale con supporto per l'interruzione
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
    eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
    input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
    response_text = ""

    response_placeholder = st.empty()  # Placeholder per mostrare la risposta progressiva

    # Genera un token alla volta e aggiorna il placeholder
    for i in range(max_length):
        if ss.get("stop_generation", False):
            break  # Interrompe il ciclo se l'utente ha premuto "stop"

        output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
        new_token_id = output[:, -1].item()
        new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)

        response_text += new_token
        response_placeholder.markdown(f"**RacoGPT:** {response_text}")  # Aggiorna il testo progressivo

        # Aggiungi il nuovo token alla sequenza di input
        input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1)

        # Interrompe se il token generato è <|endoftext|> o eos_token_id
        if new_token_id == tokenizer.eos_token_id:
            break

    # Reimposta lo stato di "stop"
    ss.stop_generation = False
    return response_text

# Inizializza lo stato della sessione
if 'is_chat_input_disabled' not in ss:
    ss.is_chat_input_disabled = False

if 'msg' not in ss:
    ss.msg = []

if 'chat_history' not in ss:
    ss.chat_history = None

if 'stop_generation' not in ss:
    ss.stop_generation = False

# Carica il modello e tokenizer
tokenizer, model = load_model()

# Aggiungi il CSS per personalizzare il colore del testo e lo sfondo del chat input
st.markdown("""
    <style>
    /* Sfondo e colore del testo del chat input */
    section[data-testid="stTextInput"] input {
        color: black !important;
        background-color: #F0F2F6 !important;
        border-radius: 10px !important;
    }
    
    /* Sfondo scuro per l'intera pagina */
    .main {
        background-color: #0A0A1A;
        color: #FFFFFF;
    }

    /* Modifica sfondo e colore dei messaggi di chat */
    .stChatMessage div[data-baseweb="block"] {
        background-color: rgba(255, 255, 255, 0.1) !important;
        color: #FFFFFF !important;
        border-radius: 10px !important;
    }
    </style>
""", unsafe_allow_html=True)

# Mostra la cronologia dei messaggi con le label personalizzate
for message in ss.msg:
    if message["role"] == "user":
        with st.chat_message("user"):
            st.markdown(f"**Tu:** {message['content']}")
    elif message["role"] == "RacoGPT":
        with st.chat_message("RacoGPT"):
            st.markdown(f"**RacoGPT:** {message['content']}")

# Contenitore per gestire la mutua esclusione tra input e pulsante di stop
input_container = st.empty()

if not ss.is_chat_input_disabled:
    # Mostra la barra di input per inviare il messaggio
    with input_container:
        prompt = st.chat_input("Scrivi il tuo messaggio...")

    if prompt:
        # Salva il messaggio dell'utente e disabilita l'input
        ss.msg.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            ss.is_chat_input_disabled = True
            st.markdown(f"**Tu:** {prompt}")
            st.rerun()
else:
    # Mostra il pulsante di "Stop Generazione" al posto della barra di input
    with input_container:
        if st.button("🛑 Stop Generazione", key="stop_button"):
            ss.stop_generation = True  # Interrompe la generazione impostando il flag

    # Genera la risposta del bot con digitazione in tempo reale
    with st.spinner("RacoGPT sta generando una risposta..."):
        response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)

    # Mostra il messaggio finale del bot dopo che la risposta è completata o interrotta
    ss.msg.append({"role": "RacoGPT", "content": response})
    with st.chat_message("RacoGPT"):
        st.markdown(f"**RacoGPT:** {response}")
    ss.is_chat_input_disabled = False

    # Rerun per aggiornare l'interfaccia
    st.rerun()