File size: 5,121 Bytes
47793bc
 
 
 
 
 
 
4efd2d6
1efb174
 
 
4efd2d6
 
 
 
 
 
 
 
 
 
 
 
 
 
1efb174
 
 
 
47793bc
 
 
 
 
 
4efd2d6
47793bc
 
 
4efd2d6
 
47793bc
 
 
 
 
 
 
 
daa2868
47793bc
 
 
 
 
4efd2d6
47793bc
4efd2d6
47793bc
daa2868
 
 
47793bc
 
 
 
 
4efd2d6
 
 
 
47793bc
 
 
 
 
 
 
 
daa2868
 
47793bc
 
 
 
 
 
 
 
 
 
 
 
daa2868
 
 
47793bc
 
 
 
 
 
 
4efd2d6
47793bc
 
4efd2d6
47793bc
daa2868
 
 
 
 
 
 
47793bc
daa2868
47793bc
 
 
4efd2d6
daa2868
 
 
 
 
 
47793bc
 
 
daa2868
47793bc
4efd2d6
 
 
 
 
47793bc
4efd2d6
 
 
 
47793bc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from streamlit import session_state as ss
import os
import login  # Importa il file login.py che hai creato

# Stile del chat input e sfondo della pagina
st.markdown("""
    <style>
    section[data-testid="stTextInput"] input {
        color: black !important;
        background-color: #F0F2F6 !important;
        font-size: 16px;
        border-radius: 10px;
        padding: 10px;
    }
    .main {
        background-color: #0A0A1A;
        color: #FFFFFF;
    }
    .stChatMessage div[data-baseweb="block"] {
        background-color: rgba(255, 255, 255, 0.1) !important;
        color: #FFFFFF !important;
        border-radius: 10px !important;
    }
    </style>
""", unsafe_allow_html=True)

# Inizializza lo stato di login se non esiste
if "is_logged_in" not in st.session_state:
    st.session_state["is_logged_in"] = False

# Mostra la pagina di login solo se l'utente non è loggato
if not st.session_state["is_logged_in"]:
    login.login_page()
    st.stop()

# Recupera le secrets da Hugging Face
model_repo = st.secrets["MODEL_REPO"]
hf_token = st.secrets["HF_TOKEN"]

@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token)
    model.config.use_cache = True
    return tokenizer, model

# Funzione per generare una risposta in tempo reale con supporto per l'interruzione
def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
    eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
    input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
    response_text = ""

    response_placeholder = st.empty()

    # Genera un token alla volta e aggiorna il testo in response_text
    for i in range(max_length):
        if ss.get("stop_generation", False):
            break  # Interrompe il ciclo se l'utente ha premuto "stop"

        output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
        new_token_id = output[:, -1].item()
        new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)

        response_text += new_token
        response_placeholder.markdown(f"RacoGPT: {response_text}", unsafe_allow_html=True)

        # Salva il testo parziale in session_state per preservarlo in caso di interruzione
        ss["response_text_partial"] = response_text

        # Aggiungi il nuovo token alla sequenza di input
        input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1)

        # Interrompe se il token generato è <|endoftext|> o eos_token_id
        if new_token_id == tokenizer.eos_token_id:
            break

    # Reimposta lo stato di "stop"
    ss.stop_generation = False
    return response_text

# Inizializza lo stato della sessione
if 'is_chat_input_disabled' not in ss:
    ss.is_chat_input_disabled = False

if 'msg' not in ss:
    ss.msg = []

if 'chat_history' not in ss:
    ss.chat_history = None

if 'stop_generation' not in ss:
    ss.stop_generation = False

# Carica il modello e tokenizer
tokenizer, model = load_model()

# Mostra la cronologia dei messaggi con le label personalizzate
for message in ss.msg:
    if message["role"] == "user":
        with st.chat_message("user"):
            st.markdown(f"Tu: {message['content']}")
    elif message["role"] == "RacoGPT":
        with st.chat_message("RacoGPT"):
            st.markdown(f"RacoGPT: {message['content']}")

# Contenitore per gestire la mutua esclusione tra input e pulsante di stop
input_container = st.empty()

if not ss.is_chat_input_disabled:
    # Mostra la barra di input per inviare il messaggio
    with input_container:
        prompt = st.chat_input("Scrivi il tuo messaggio...")

    if prompt:
        ss.msg.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            ss.is_chat_input_disabled = True
            st.markdown(f"Tu: {prompt}")
            st.rerun()
else:
    # Mostra il pulsante di "Stop Generazione" al posto della barra di input
    with input_container:
        if st.button("🛑 Stop Generazione", key="stop_button"):
            ss.stop_generation = True  # Interrompe la generazione impostando il flag

    # Genera la risposta del bot con digitazione in tempo reale
    with st.spinner("RacoGPT sta generando una risposta..."):
        response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)

    # Usa il testo parziale se presente
    final_response = response or ss.get("response_text_partial", "")

    # Aggiungi la risposta finale nella cronologia dei messaggi
    ss.msg.append({"role": "RacoGPT", "content": final_response})
    with st.chat_message("RacoGPT"):
        st.markdown(f"RacoGPT: {final_response}")

    # Pulisce il testo parziale dalla sessione e riabilita l'input
    ss.pop("response_text_partial", None)
    ss.is_chat_input_disabled = False

    # Rerun per aggiornare l'interfaccia
    st.rerun()