Francesco26061993 commited on
Commit
daa2868
·
1 Parent(s): 47793bc

Stop bot response generation

Browse files
Files changed (1) hide show
  1. app.py +28 -9
app.py CHANGED
@@ -15,8 +15,8 @@ if not st.session_state["is_logged_in"]:
15
  st.stop()
16
 
17
  # Recupera le secrets da Hugging Face
18
- model_repo = os.getenv("MODEL_REPO") # Repository del modello di base
19
- hf_token = os.getenv("HF_TOKEN") # Token Hugging Face
20
 
21
  # Carica il modello di base con caching
22
  @st.cache_resource
@@ -26,7 +26,7 @@ def load_model():
26
  model.config.use_cache = True
27
  return tokenizer, model
28
 
29
- # Funzione per generare una risposta in tempo reale
30
  def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
31
  eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
32
  input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
@@ -36,6 +36,9 @@ def generate_llama_response_stream(user_input, tokenizer, model, max_length=512)
36
 
37
  # Genera un token alla volta e aggiorna il placeholder
38
  for i in range(max_length):
 
 
 
39
  output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
40
  new_token_id = output[:, -1].item()
41
  new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)
@@ -50,6 +53,8 @@ def generate_llama_response_stream(user_input, tokenizer, model, max_length=512)
50
  if new_token_id == tokenizer.eos_token_id:
51
  break
52
 
 
 
53
  return response_text
54
 
55
  # Inizializza lo stato della sessione
@@ -62,6 +67,9 @@ if 'msg' not in ss:
62
  if 'chat_history' not in ss:
63
  ss.chat_history = None
64
 
 
 
 
65
  # Carica il modello e tokenizer
66
  tokenizer, model = load_model()
67
 
@@ -74,21 +82,32 @@ for message in ss.msg:
74
  with st.chat_message("RacoGPT"):
75
  st.markdown(f"**RacoGPT:** {message['content']}")
76
 
77
- # Gestione dell'input e disabilitazione
78
- if (prompt := st.chat_input("Scrivi il tuo messaggio...", disabled=ss.is_chat_input_disabled) or "").strip():
 
 
 
 
 
79
 
80
- # Salva il messaggio dell'utente e disabilita l'input
81
- if not ss.is_chat_input_disabled:
82
  ss.msg.append({"role": "user", "content": prompt})
83
  with st.chat_message("user"):
84
  ss.is_chat_input_disabled = True
85
  st.markdown(f"**Tu:** {prompt}")
 
 
 
 
 
 
86
 
87
  # Genera la risposta del bot con digitazione in tempo reale
88
  with st.spinner("RacoGPT sta generando una risposta..."):
89
- response = generate_llama_response_stream(prompt, tokenizer, model)
90
 
91
- # Mostra il messaggio finale del bot dopo che la risposta è completata
92
  ss.msg.append({"role": "RacoGPT", "content": response})
93
  with st.chat_message("RacoGPT"):
94
  st.markdown(f"**RacoGPT:** {response}")
 
15
  st.stop()
16
 
17
  # Recupera le secrets da Hugging Face
18
+ model_repo = st.secrets["MODEL_REPO"] # Repository del modello di base
19
+ hf_token = st.secrets["HF_TOKEN"] # Token Hugging Face
20
 
21
  # Carica il modello di base con caching
22
  @st.cache_resource
 
26
  model.config.use_cache = True
27
  return tokenizer, model
28
 
29
+ # Funzione per generare una risposta in tempo reale con supporto per l'interruzione
30
  def generate_llama_response_stream(user_input, tokenizer, model, max_length=512):
31
  eos_token = tokenizer.eos_token if tokenizer.eos_token else ""
32
  input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt")
 
36
 
37
  # Genera un token alla volta e aggiorna il placeholder
38
  for i in range(max_length):
39
+ if ss.get("stop_generation", False):
40
+ break # Interrompe il ciclo se l'utente ha premuto "stop"
41
+
42
  output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True)
43
  new_token_id = output[:, -1].item()
44
  new_token = tokenizer.decode([new_token_id], skip_special_tokens=True)
 
53
  if new_token_id == tokenizer.eos_token_id:
54
  break
55
 
56
+ # Reimposta lo stato di "stop"
57
+ ss.stop_generation = False
58
  return response_text
59
 
60
  # Inizializza lo stato della sessione
 
67
  if 'chat_history' not in ss:
68
  ss.chat_history = None
69
 
70
+ if 'stop_generation' not in ss:
71
+ ss.stop_generation = False
72
+
73
  # Carica il modello e tokenizer
74
  tokenizer, model = load_model()
75
 
 
82
  with st.chat_message("RacoGPT"):
83
  st.markdown(f"**RacoGPT:** {message['content']}")
84
 
85
+ # Contenitore per gestire la mutua esclusione tra input e pulsante di stop
86
+ input_container = st.empty()
87
+
88
+ if not ss.is_chat_input_disabled:
89
+ # Mostra la barra di input per inviare il messaggio
90
+ with input_container:
91
+ prompt = st.chat_input("Scrivi il tuo messaggio...")
92
 
93
+ if prompt:
94
+ # Salva il messaggio dell'utente e disabilita l'input
95
  ss.msg.append({"role": "user", "content": prompt})
96
  with st.chat_message("user"):
97
  ss.is_chat_input_disabled = True
98
  st.markdown(f"**Tu:** {prompt}")
99
+ st.rerun()
100
+ else:
101
+ # Mostra il pulsante di "Stop Generazione" al posto della barra di input
102
+ with input_container:
103
+ if st.button("🛑 Stop Generazione", key="stop_button"):
104
+ ss.stop_generation = True # Interrompe la generazione impostando il flag
105
 
106
  # Genera la risposta del bot con digitazione in tempo reale
107
  with st.spinner("RacoGPT sta generando una risposta..."):
108
+ response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model)
109
 
110
+ # Mostra il messaggio finale del bot dopo che la risposta è completata o interrotta
111
  ss.msg.append({"role": "RacoGPT", "content": response})
112
  with st.chat_message("RacoGPT"):
113
  st.markdown(f"**RacoGPT:** {response}")