ajsbsd committed · verified
Commit add83be · 1 Parent(s): 7666164

Update app.py

Files changed (1)
  1. app.py +52 -17
app.py CHANGED
@@ -123,28 +123,23 @@ def load_models():
 
 
 # --- Generate Response and Audio Function ---
-@spaces.GPU # Decorate with @spaces.GPU as this function performs GPU-intensive inference
-def generate_response_and_audio(
-    message: str, # Current user message
-    history: list # Gradio Chatbot history format (list of dictionaries with 'role' and 'content')
-) -> tuple: # Returns (updated_history, audio_file_path)
-    """
-    Generates a text response from the loaded LLM and then converts it to audio
-    using the loaded TTS model.
-    """
+@spaces.GPU
+def generate_response_and_audio(message: str, history: list) -> tuple:
     global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
 
-    # Initialize all models if not already loaded
     if tokenizer is None or llm_model is None or tts_model is None:
         load_models()
 
-    if tokenizer is None or llm_model is None: # Check LLM loading status
+    if tokenizer is None or llm_model is None:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
        return history, None
 
+    # Initialize generated_text early
+    generated_text = ""
+
     # --- 1. Generate Text Response (LLM) ---
-    messages = history
+    messages = history.copy()
     messages.append({"role": "user", "content": message})
 
     try:
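
Note on this hunk: the switch from "messages = history" to "messages = history.copy()" fixes a Python aliasing issue, since plain assignment does not copy a list. A minimal sketch of the behavior (the values here are illustrative, not from app.py):

    history = [{"role": "user", "content": "hi"}]
    messages = history                  # alias: both names point at the same list
    messages.append({"role": "user", "content": "again"})
    assert len(history) == 2            # the chat history was mutated as a side effect

    history = [{"role": "user", "content": "hi"}]
    messages = history.copy()           # shallow copy: appends no longer leak back
    messages.append({"role": "user", "content": "again"})
    assert len(history) == 1            # history is untouched

Without the copy, appending the current message to messages also mutated the Gradio history in place.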
@@ -153,17 +148,15 @@ def generate_response_and_audio(
        print(f"Error applying chat template: {e}")
        input_text = ""
        for item in history:
-           if item["role"] == "user":
-               input_text += f"User: {item['content']}\n"
-           elif item["role"] == "assistant":
-               input_text += f"Assistant: {item['content']}\n"
+           input_text += f"{item['role'].capitalize()}: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"
 
+    try:
        input_ids = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
        with torch.no_grad():
            output_ids = llm_model.generate(
                input_ids["input_ids"],
-               attention_mask=input_ids["attention_mask"], # <-- Add this line
+               attention_mask=input_ids["attention_mask"],
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=DO_SAMPLE,
                temperature=TEMPERATURE,
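
Note on this hunk: passing attention_mask alongside input_ids matters whenever padding is enabled, because generate() otherwise has to guess which positions are padding (and transformers warns when the pad token equals the EOS token). A minimal, self-contained sketch of the same call pattern; "gpt2" is a stand-in checkpoint here, not necessarily the model app.py loads:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token        # gpt2 ships without a pad token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    enc = tokenizer("User: hello\nAssistant:", return_tensors="pt", padding=True)
    with torch.no_grad():
        out = model.generate(
            enc["input_ids"],
            attention_mask=enc["attention_mask"],    # explicit mask, as in the hunk
            max_new_tokens=32,
            pad_token_id=tokenizer.eos_token_id,
        )
    print(tokenizer.decode(out[0], skip_special_tokens=True))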
@@ -172,6 +165,48 @@
                pad_token_id=tokenizer.eos_token_id
            )
 
+        generated_token_ids = output_ids[0][input_ids["input_ids"].shape[-1]:]
+        generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
+
+    except Exception as e:
+        print(f"Error during LLM generation: {e}")
+        history.append({"role": "assistant", "content": "I encountered an error while generating a response."})
+        return history, None
+
+    # --- 2. Generate Audio from Response (TTS) ---
+    audio_path = None
+    if all([tts_processor, tts_model, tts_vocoder, speaker_embeddings]):
+        try:
+            device = llm_model.device if llm_model else 'cpu'
+            tts_model.to(device)
+            tts_vocoder.to(device)
+            speaker_embeddings = speaker_embeddings.to(device)
+
+            tts_inputs = tts_processor(
+                text=generated_text,
+                return_tensors="pt",
+                max_length=550,
+                truncation=True
+            ).to(device)
+
+            with torch.no_grad():
+                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
+
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                audio_path = tmp_file.name
+                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
+
+            print(f"Audio saved to: {audio_path}")
+
+        except Exception as e:
+            print(f"Error generating audio: {e}")
+            audio_path = None
+    else:
+        print("TTS components not fully loaded. Skipping audio generation.")
+
+    # --- 3. Update Chat History ---
+    history.append({"role": "assistant", "content": generated_text})
+    return history, audio_path
     # --- 2. Generate Audio from Response (TTS) ---
     audio_path = None
     if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
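
Note on the TTS section: the slice output_ids[0][input_ids["input_ids"].shape[-1]:] strips the prompt so only newly generated tokens are decoded, and generate_speech() with a HiFi-GAN vocoder is the standard SpeechT5 path, which emits 16 kHz audio (hence samplerate=16000). One caveat: if speaker_embeddings is a multi-element tensor, all([...]) will raise a "Boolean value of Tensor with more than one element is ambiguous" error; an explicit "speaker_embeddings is not None" check, as in the old condition, avoids that. A minimal sketch of the same pipeline, assuming the usual microsoft/speecht5_tts checkpoints and the CMU ARCTIC x-vector set from the SpeechT5 docs (app.py may source its speaker embedding differently):

    import torch
    import soundfile as sf
    from datasets import load_dataset
    from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    # One 512-dim x-vector serves as the speaker identity.
    xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embeddings = torch.tensor(xvectors[7306]["xvector"]).unsqueeze(0)

    inputs = processor(text="Hello from the demo.", return_tensors="pt")
    with torch.no_grad():
        speech = tts_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
    sf.write("out.wav", speech.numpy(), samplerate=16000)   # SpeechT5 outputs 16 kHz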