sagar007 committed
Commit aa6ca85 · verified · 1 Parent(s): 6560c55

Update app.py

Files changed (1)
  1. app.py +293 -389
app.py CHANGED
@@ -10,30 +10,31 @@ import subprocess
 import numpy as np
 from typing import List, Dict, Tuple, Any, Optional, Union
 from functools import lru_cache
-import asyncio
 import threading
-from concurrent.futures import ThreadPoolExecutor
 import warnings
 import traceback # For detailed error logging
 import re # For text cleaning
 import shutil # For checking sudo/file operations
 import html # For escaping HTML
 import sys # For sys.path manipulation

 # --- Configuration ---
 MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 MAX_SEARCH_RESULTS = 5
 TTS_SAMPLE_RATE = 24000
-MAX_TTS_CHARS = 1000 # Max characters for a single TTS chunk
 MAX_NEW_TOKENS = 300
 TEMPERATURE = 0.7
 TOP_P = 0.95
-KOKORO_PATH = 'Kokoro-82M' # Relative path to TTS model directory

 # --- Initialization ---
-# Thread Pool Executor for blocking tasks
-executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
-
 # Suppress specific warnings
 warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
 warnings.filterwarnings("ignore", message="Backend 'inductor' is not available.")
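For context on this first hunk: the deleted imports supported the old pattern of offloading blocking work onto a module-level thread pool so async handlers could await it; the rewritten half of this diff calls the same helpers directly and leaves device management to the `@spaces.GPU` decorator. A minimal sketch of the two styles (illustrative only, not lines from the commit):

    # Before: schedule a blocking helper on the module-level executor
    results = await asyncio.get_event_loop().run_in_executor(
        executor, get_web_results_sync, query)

    # After: plain synchronous call; ZeroGPU attaches a GPU only around
    # calls to functions decorated with @spaces.GPU
    results = get_web_results_sync(query)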
@@ -48,27 +49,21 @@ try:
     llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     llm_tokenizer.pad_token = llm_tokenizer.eos_token

-    if torch.cuda.is_available():
-        llm_device = "cuda"
-        torch_dtype = torch.float16
-        device_map = "auto"
-        print(f"[LLM Init] CUDA detected. Loading model with device_map='{device_map}', dtype={torch_dtype}")
-    else:
-        llm_device = "cpu"
-        torch_dtype = torch.float32
-        device_map = {"": "cpu"}
-        print(f"[LLM Init] CUDA not found. Loading model on CPU with dtype={torch_dtype}")

     llm_model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
-        device_map=device_map,
         low_cpu_mem_usage=True,
         torch_dtype=torch_dtype,
-        # attn_implementation="flash_attention_2" # Optional
     )
-    # Get the actual device map if using 'auto'
-    effective_device_map = llm_model.hf_device_map if hasattr(llm_model, 'hf_device_map') else device_map
-    print(f"[LLM Init] LLM loaded successfully. Device map: {effective_device_map}")
     llm_model.eval()

 except Exception as e:
@@ -80,6 +75,7 @@ except Exception as e:


 # --- TTS Initialization ---
 VOICE_CHOICES = {
     '🇺🇸 Female (Default)': 'af',
     '🇺🇸 Bella': 'af_bella',
@@ -91,18 +87,16 @@ tts_model: Optional[Any] = None
 voicepacks: Dict[str, Any] = {}
 tts_device = "cpu"

-# Helper for running subprocesses
 def _run_subprocess(cmd: List[str], check: bool = True, cwd: Optional[str] = None, timeout: int = 300) -> subprocess.CompletedProcess:
     """Runs a subprocess command, captures output, and handles errors."""
     print(f"Running command: {' '.join(cmd)}")
     try:
         result = subprocess.run(cmd, check=check, capture_output=True, text=True, cwd=cwd, timeout=timeout)
-        # Only print output details if check failed or for specific successful commands
         if not check or result.returncode != 0:
-            if result.stdout: print(f" Stdout: {result.stdout.strip()}")
-            if result.stderr: print(f" Stderr: {result.stderr.strip()}")
         elif result.returncode == 0 and ('clone' in cmd or 'pull' in cmd or 'install' in cmd):
-            print(f" Command successful.") # Concise success message
         return result
     except FileNotFoundError:
         print(f" Error: Command not found - {cmd[0]}")
@@ -116,189 +110,158 @@ def _run_subprocess(cmd: List[str], check: bool = True, cwd: Optional[str] = Non
         if e.stderr: print(f" Stderr: {e.stderr.strip()}")
         raise

-# TTS Setup Task (runs in background thread)
 def setup_tts_task():
     """Initializes Kokoro TTS model and dependencies."""
     global TTS_ENABLED, tts_model, voicepacks, tts_device
     print("[TTS Setup] Starting background initialization...")

-    tts_device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"[TTS Setup] Target device: {tts_device}")

     can_sudo = shutil.which('sudo') is not None
     apt_cmd_prefix = ['sudo'] if can_sudo else []
-    absolute_kokoro_path = os.path.abspath(KOKORO_PATH) # Use absolute path

     try:
-        # 1. Clone Kokoro Repo if needed
         if not os.path.exists(absolute_kokoro_path):
-            print(f"[TTS Setup] Cloning repository to {absolute_kokoro_path}...")
-            try:
-                _run_subprocess(['git', 'lfs', 'install', '--system', '--skip-repo'])
-            except Exception as lfs_err:
-                print(f"[TTS Setup] Warning: git lfs install failed: {lfs_err}. Continuing...")
-            _run_subprocess(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', absolute_kokoro_path])
-            try:
-                print("[TTS Setup] Running git lfs pull...")
-                _run_subprocess(['git', 'lfs', 'pull'], cwd=absolute_kokoro_path)
-            except Exception as lfs_pull_err:
-                print(f"[TTS Setup] Warning: git lfs pull failed: {lfs_pull_err}")
         else:
-            print(f"[TTS Setup] Directory {absolute_kokoro_path} already exists.")
-            # Optional: Run git pull and lfs pull to update if needed
-            # try:
-            #     print("[TTS Setup] Updating existing repo...")
-            #     _run_subprocess(['git', 'pull'], cwd=absolute_kokoro_path)
-            #     _run_subprocess(['git', 'lfs', 'pull'], cwd=absolute_kokoro_path)
-            # except Exception as update_err:
-            #     print(f"[TTS Setup] Warning: Failed to update repo: {update_err}")
-
-        # 2. Install espeak dependency
         print("[TTS Setup] Checking/Installing espeak...")
-        try:
-            # Run update quietly first
-            _run_subprocess(apt_cmd_prefix + ['apt-get', 'update', '-qq'])
-            # Try installing espeak-ng
-            _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak-ng'])
-            print("[TTS Setup] espeak-ng installed or already present.")
         except Exception:
-            print("[TTS Setup] espeak-ng installation failed, trying espeak...")
-            try:
-                # Fallback to legacy espeak
-                _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak'])
-                print("[TTS Setup] espeak installed or already present.")
-            except Exception as espeak_err:
-                print(f"[TTS Setup] ERROR: Failed to install both espeak-ng and espeak: {espeak_err}. TTS disabled.")
-                return # Cannot proceed

         # 3. Load Kokoro Model and Voices
         sys_path_updated = False
         if os.path.exists(absolute_kokoro_path):
-            print(f"[TTS Setup] Checking contents of: {absolute_kokoro_path}")
-            try:
-                dir_contents = os.listdir(absolute_kokoro_path)
-                print(f"[TTS Setup] Contents: {dir_contents}")
-                if 'models.py' not in dir_contents or 'kokoro.py' not in dir_contents:
-                    print("[TTS Setup] Warning: Core Kokoro python files ('models.py', 'kokoro.py') might be missing!")
-            except OSError as list_err:
-                print(f"[TTS Setup] Warning: Could not list directory contents: {list_err}")
-
-            # Add path temporarily for import
-            if absolute_kokoro_path not in sys.path:
-                sys.path.insert(0, absolute_kokoro_path) # Add to beginning
-                sys_path_updated = True
-                print(f"[TTS Setup] Temporarily added {absolute_kokoro_path} to sys.path.")
-
-            try:
-                print("[TTS Setup] Attempting to import Kokoro modules...")
-                from models import build_model
-                from kokoro import generate as generate_tts_internal
-                print("[TTS Setup] Kokoro modules imported successfully.")
-
-                # Make functions globally accessible IF NEEDED (alternative: pass them around)
-                globals()['build_model'] = build_model
-                globals()['generate_tts_internal'] = generate_tts_internal
-
-                model_file = os.path.join(absolute_kokoro_path, 'kokoro-v0_19.pth')
-                if not os.path.exists(model_file):
-                    print(f"[TTS Setup] ERROR: Model file {model_file} not found. TTS disabled.")
-                    return
-
-                print(f"[TTS Setup] Loading TTS model from {model_file} onto {tts_device}...")
-                tts_model = build_model(model_file, tts_device)
-                tts_model.eval()
-                print("[TTS Setup] TTS model loaded.")
-
-                # Load voices
-                loaded_voices = 0
-                for voice_name, voice_id in VOICE_CHOICES.items():
-                    voice_file_path = os.path.join(absolute_kokoro_path, 'voices', f'{voice_id}.pt')
-                    if os.path.exists(voice_file_path):
-                        try:
-                            print(f"[TTS Setup] Loading voice: {voice_id} ({voice_name})")
-                            voicepacks[voice_id] = torch.load(voice_file_path, map_location=tts_device)
-                            loaded_voices += 1
-                        except Exception as e:
-                            print(f"[TTS Setup] Warning: Failed to load voice {voice_id}: {str(e)}")
-                    else:
-                        print(f"[TTS Setup] Info: Voice file {voice_file_path} not found.")
-
-                if loaded_voices == 0:
-                    print("[TTS Setup] ERROR: No voicepacks could be loaded. TTS disabled.")
-                    tts_model = None # Free memory if no voices
-                    return
-
-                TTS_ENABLED = True
-                print(f"[TTS Setup] Initialization successful. {loaded_voices} voices loaded. TTS Enabled: {TTS_ENABLED}")
-
-            # Catch the specific import error
-            except ImportError as ie:
-                print(f"[TTS Setup] ERROR: Failed to import Kokoro modules: {ie}.")
-                print(f" Please ensure '{absolute_kokoro_path}' contains 'models.py' and 'kokoro.py'.")
-                print(traceback.format_exc())
-            except Exception as load_err:
-                print(f"[TTS Setup] ERROR: Exception during TTS model/voice loading: {load_err}. TTS disabled.")
-                print(traceback.format_exc())
-            finally:
-                # *** Crucial: Clean up sys.path ***
-                if sys_path_updated:
-                    try:
-                        if sys.path[0] == absolute_kokoro_path:
-                            sys.path.pop(0)
-                            print(f"[TTS Setup] Removed {absolute_kokoro_path} from sys.path.")
-                        else:
-                            # It might have been removed elsewhere, or wasn't at index 0
-                            if absolute_kokoro_path in sys.path:
-                                sys.path.remove(absolute_kokoro_path)
-                                print(f"[TTS Setup] Removed {absolute_kokoro_path} from sys.path (was not index 0).")
-                    except Exception as cleanup_err:
-                        print(f"[TTS Setup] Warning: Error removing path from sys.path: {cleanup_err}")
         else:
             print(f"[TTS Setup] ERROR: Directory {absolute_kokoro_path} not found. TTS disabled.")

     except Exception as e:
         print(f"[TTS Setup] ERROR: Unexpected error during setup: {str(e)}")
         print(traceback.format_exc())
-        TTS_ENABLED = False # Ensure disabled on any top-level error
-        tts_model = None
-        voicepacks.clear()

-# Start TTS setup in background
 print("Starting TTS setup thread...")
 tts_setup_thread = threading.Thread(target=setup_tts_task, daemon=True)
 tts_setup_thread.start()


-# --- Core Logic Functions ---

 @lru_cache(maxsize=128)
 def get_web_results_sync(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, Any]]:
     """Synchronous web search function with caching."""
     print(f"[Web Search] Searching (sync): '{query}' (max_results={max_results})")
     try:
         with DDGS() as ddgs:
             results = list(ddgs.text(query, max_results=max_results, safesearch='moderate', timelimit='y'))
             print(f"[Web Search] Found {len(results)} results.")
             formatted = [{
-                "id": i + 1,
-                "title": res.get("title", "No Title"),
-                "snippet": res.get("body", "No Snippet"),
-                "url": res.get("href", "#"),
             } for i, res in enumerate(results)]
             return formatted
     except Exception as e:
-        print(f"[Web Search] Error: {e}")
-        # Avoid printing full traceback repeatedly for common network errors maybe
-        return []


 def format_llm_prompt(query: str, context: List[Dict[str, Any]]) -> str:
-    """Formats the prompt for the LLM, including context and instructions."""
     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     context_str = "\n\n".join(
         [f"[{res['id']}] {html.escape(res['title'])}\n{html.escape(res['snippet'])}" for res in context]
     ) if context else "No relevant web context found."
-
-    # Using a clear, structured prompt
     return f"""SYSTEM: You are a helpful AI assistant. Answer the user's query based *only* on the provided web search context. Cite sources using bracket notation like [1], [2]. If the context is insufficient, state that clearly. Use markdown for formatting. Do not add external information. Current Time: {current_time}

CONTEXT:
@@ -308,67 +271,59 @@ CONTEXT:

USER: {html.escape(query)}

-ASSISTANT:""" # Using ASSISTANT: marker might help some models

 def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
     """Formats search results into HTML for display."""
-    if not web_results:
-        return "<div class='no-sources'>No sources found for this query.</div>"
     items_html = ""
     for res in web_results:
         title_safe = html.escape(res.get("title", "Source"))
         snippet_safe = html.escape(res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else ""))
-        url = html.escape(res.get("url", "#")) # Escape URL too
-        items_html += f"""
-        <div class='source-item'>
-            <div class='source-number'>[{res['id']}]</div>
-            <div class='source-content'>
-                <a href="{url}" target="_blank" class='source-title' title="{url}">{title_safe}</a>
-                <div class='source-snippet'>{snippet_safe}</div>
-            </div>
-        </div>
-        """
     return f"<div class='sources-container'>{items_html}</div>"

-async def generate_llm_answer(prompt: str) -> str:
-    """Generates answer using the loaded LLM (Async Wrapper)."""
     if not llm_model or not llm_tokenizer:
         print("[LLM Generate] LLM model or tokenizer not available.")
         return "Error: Language Model is not available."

-    print(f"[LLM Generate] Requesting generation (prompt length {len(prompt)})...")
     start_time = time.time()
     try:
         inputs = llm_tokenizer(
-            prompt,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=1024, # Adjust based on model limits
-            return_attention_mask=True
-        ).to(llm_model.device)

         with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(llm_model.dtype == torch.float16)):
-            outputs = await asyncio.get_event_loop().run_in_executor(
-                executor,
-                llm_model.generate,
                 inputs.input_ids,
                 attention_mask=inputs.attention_mask,
                 max_new_tokens=MAX_NEW_TOKENS,
-                temperature=TEMPERATURE,
-                top_p=TOP_P,
                 pad_token_id=llm_tokenizer.eos_token_id,
                 eos_token_id=llm_tokenizer.eos_token_id,
-                do_sample=True,
-                num_return_sequences=1
             )

-        # Decode only newly generated tokens
         output_ids = outputs[0][inputs.input_ids.shape[1]:]
         answer_part = llm_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
-
-        if not answer_part:
-            answer_part = "*Model generated an empty response.*"

         end_time = time.time()
         print(f"[LLM Generate] Generation complete in {end_time - start_time:.2f}s. Length: {len(answer_part)}")
@@ -377,46 +332,45 @@ async def generate_llm_answer(prompt: str) -> str:
     except Exception as e:
         print(f"[LLM Generate] Error: {e}")
         print(traceback.format_exc())
-        return f"Error during answer generation: Check logs for details." # User-friendly error

-async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
-    """Generates speech using the loaded TTS model (Async Wrapper)."""
     if not TTS_ENABLED or not tts_model or 'generate_tts_internal' not in globals():
         print("[TTS Generate] Skipping: TTS not ready.")
         return None
-    if not text or not text.strip() or text.startswith("Error:") or text.startswith("*Model generated"):
         print("[TTS Generate] Skipping: Invalid or empty text.")
         return None

-    print(f"[TTS Generate] Requesting speech (length {len(text)}, voice '{voice_id}')...")
     start_time = time.time()

     try:
         actual_voice_id = voice_id
         if voice_id not in voicepacks:
-            print(f"[TTS Generate] Warning: Voice '{voice_id}' not loaded. Trying default 'af'.")
             actual_voice_id = 'af'
-            if 'af' not in voicepacks:
-                print("[TTS Generate] Error: Default voice 'af' also not available.")
-                return None
-
-        # Clean text more thoroughly for TTS
-        clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text) # Remove citations [1], [2][3]
-        clean_text = re.sub(r'```.*?```', '', clean_text, flags=re.DOTALL) # Remove code blocks
-        clean_text = re.sub(r'`[^`]*`', '', clean_text) # Remove inline code
-        clean_text = re.sub(r'^\s*[\*->]\s*', '', clean_text, flags=re.MULTILINE) # Remove list markers/blockquotes at line start
-        clean_text = re.sub(r'[\*#_]', '', clean_text) # Remove remaining markdown emphasis/headers
-        clean_text = html.unescape(clean_text) # Decode HTML entities
-        clean_text = ' '.join(clean_text.split()) # Normalize whitespace
-
-        if not clean_text:
-            print("[TTS Generate] Skipping: Text empty after cleaning.")
-            return None

         if len(clean_text) > MAX_TTS_CHARS:
             print(f"[TTS Generate] Truncating cleaned text from {len(clean_text)} to {MAX_TTS_CHARS} chars.")
             clean_text = clean_text[:MAX_TTS_CHARS]
-            last_punct = max(clean_text.rfind(p) for p in '.?!; ') # Find reasonable cut-off
             if last_punct != -1: clean_text = clean_text[:last_punct+1]
             clean_text += "..."
@@ -424,22 +378,39 @@ async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple
         gen_func = globals()['generate_tts_internal']
         voice_pack_data = voicepacks[actual_voice_id]

-        # Execute in thread pool
-        # Verify the expected language code ('afr', 'eng', etc.) for Kokoro
-        audio_data, _ = await asyncio.get_event_loop().run_in_executor(
-            executor, gen_func, tts_model, clean_text, voice_pack_data, 'afr'
-        )
-
-        # Process output
-        if isinstance(audio_data, torch.Tensor):
-            audio_np = audio_data.detach().cpu().numpy()
-        elif isinstance(audio_data, np.ndarray):
-            audio_np = audio_data
-        else:
-            print("[TTS Generate] Warning: Unexpected audio data type.")
-            return None

-        audio_np = audio_np.flatten().astype(np.float32) # Ensure 1D float32

         end_time = time.time()
         print(f"[TTS Generate] Audio generated in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")
@@ -450,97 +421,90 @@ async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple
         print(traceback.format_exc())
         return None

 def get_voice_id_from_display(voice_display_name: str) -> str:
-    """Maps the user-friendly voice name to the internal voice ID."""
-    return VOICE_CHOICES.get(voice_display_name, 'af') # Default to 'af'


-# --- Gradio Interaction Logic ---
-ChatHistoryType = List[Dict[str, Optional[str]]] # Allow None for content during streaming

-async def handle_interaction(
     query: str,
     history: ChatHistoryType,
     selected_voice_display_name: str
-):
-    """Main async generator function to handle user queries and update Gradio UI."""
-    print(f"\n--- Handling Query ---")
-    query = query.strip() # Clean input query
     print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")

     if not query:
         print("Empty query received.")
-        yield history, "*Please enter a non-empty query.*", "<div class='no-sources'>Enter a query to search.</div>", None, gr.Button(value="Search", interactive=True)
-        return

-    # Use 'messages' format: List of {'role': 'user'/'assistant', 'content': '...'}
     current_history: ChatHistoryType = history + [{"role": "user", "content": query}]
-    # Add placeholder for assistant response
-    current_history.append({"role": "assistant", "content": None}) # Content starts as None
-
-    # Define states to yield
-    chatbot_state = current_history
-    status_state = "*Searching...*"
-    sources_state = "<div class='searching'><span>Searching the web...</span></div>"
-    audio_state = None
-    button_state = gr.Button(value="Searching...", interactive=False)
-
-    # 1. Initial State: Searching
-    current_history[-1]["content"] = status_state # Update placeholder
-    yield chatbot_state, status_state, sources_state, audio_state, button_state
-
-    # 2. Perform Web Search (in executor)
-    web_results = await asyncio.get_event_loop().run_in_executor(
-        executor, get_web_results_sync, query
-    )
-    sources_state = format_sources_html(web_results)
-
-    # Update state: Generating Answer
-    status_state = "*Generating answer...*"
-    button_state = gr.Button(value="Generating...", interactive=False)
-    current_history[-1]["content"] = status_state # Update placeholder
-    yield chatbot_state, status_state, sources_state, audio_state, button_state
-
-    # 3. Generate LLM Answer (async)
-    llm_prompt = format_llm_prompt(query, web_results)
-    final_answer = await generate_llm_answer(llm_prompt)
-    status_state = final_answer # Now status holds the actual answer
-
-    # Update assistant message in history fully
-    current_history[-1]["content"] = final_answer
-
-    # Update state: Generating Audio (if applicable)
-    button_state = gr.Button(value="Audio...", interactive=False) if TTS_ENABLED else gr.Button(value="Search", interactive=True)
-    yield chatbot_state, status_state, sources_state, audio_state, button_state
-
-    # 4. Generate TTS Speech (async)
-    tts_status_message = ""
-    if not TTS_ENABLED:
-        if tts_setup_thread.is_alive():
-            tts_status_message = "\n\n*(TTS initializing...)*"
-        else:
-            # Check if setup failed vs just disabled
-            # This info isn't easily available here, assume failed/disabled
-            tts_status_message = "\n\n*(TTS unavailable)*"
-    else:
-        voice_id = get_voice_id_from_display(selected_voice_display_name)
-        audio_state = await generate_tts_speech(final_answer, voice_id) # Returns (rate, data) or None
-        if audio_state is None and not final_answer.startswith("Error"): # Don't show TTS fail if LLM failed
-            tts_status_message = "\n\n*(Audio generation failed)*"

-    # 5. Final State: Show all results
-    final_answer_with_status = final_answer + tts_status_message
-    status_state = final_answer_with_status # Update status display
-    current_history[-1]["content"] = final_answer_with_status # Update history *again* with status msg

-    button_state = gr.Button(value="Search", interactive=True) # Re-enable button

-    print("--- Query Handling Complete ---")
-    yield chatbot_state, status_state, sources_state, audio_state, button_state


 # --- Gradio UI Definition ---
-# (CSS from previous response)
 css = """
 /* ... [Your existing refined CSS] ... */
 .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
@@ -559,7 +523,7 @@ css = """
 .search-box button:hover { background: #1d4ed8 !important; }
 .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
 .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
-.answer-box { /* Now used for status/interim text */ background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1rem; color: #1f2937; margin-bottom: 0.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.05); min-height: 50px;}
 .answer-box p { color: #374151; line-height: 1.7; margin:0;}
 .answer-box code { background: #f3f4f6; border-radius: 4px; padding: 2px 4px; color: #4b5563; font-size: 0.9em; }
 .sources-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; }
@@ -572,8 +536,8 @@ css = """
 .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
 .source-title:hover { color: #1d4ed8; text-decoration: underline; }
 .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
-.chat-history { /* Style the chatbot container */ max-height: 500px; overflow-y: auto; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; /* margin-top: 1rem; */ scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
-.chat-history > div { padding: 1rem; } /* Add padding inside the chatbot display area */
 .chat-history::-webkit-scrollbar { width: 6px; }
 .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
 .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
@@ -594,8 +558,6 @@ css = """
 .markdown-content table { border-collapse: collapse !important; width: 100% !important; margin: 1em 0; }
 .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
 .markdown-content th { background: #f9fafb !important; font-weight: 600; }
-/* .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; } */
-/* .accordion > .label-wrap { padding: 10px 15px !important; } */
 .voice-selector { margin: 0; padding: 0; height: 100%; }
 .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
 .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
@@ -646,8 +608,6 @@ css = """
 .dark .markdown-content blockquote { border-left-color: #4b5563 !important; color: #9ca3af !important; }
 .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
 .dark .markdown-content th { background: #374151 !important; }
-/* .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; } */
-/* .dark .accordion > .label-wrap { color: #d1d5db !important; } */
 .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
 .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
 .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
@@ -660,125 +620,69 @@ css = """
 .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
 """

-with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
-    # Use gr.State for chat history in 'messages' format
     chat_history_state = gr.State([])

     with gr.Column():
-        # Header
         with gr.Column(elem_id="header"):
-            gr.Markdown("# 🔍 AI Search Assistant")
             gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")

-        # Search Area
         with gr.Column(elem_classes="search-container"):
             with gr.Row(elem_classes="search-box"):
                 search_input = gr.Textbox(label="", placeholder="Ask anything...", scale=5, container=False)
                 voice_select = gr.Dropdown(choices=list(VOICE_CHOICES.keys()), value=list(VOICE_CHOICES.keys())[0], label="", scale=1, min_width=180, container=False, elem_classes="voice-selector")
                 search_btn = gr.Button("Search", variant="primary", scale=0, min_width=100)

-        # Results Area
         with gr.Row(elem_classes="results-container"):
-            # Left Column: Chatbot, Status, Audio
             with gr.Column(scale=3):
                 chatbot_display = gr.Chatbot(
-                    label="Conversation",
-                    bubble_full_width=True,
-                    height=500, # Adjusted height
-                    elem_classes="chat-history",
-                    type="messages", # IMPORTANT: Use 'messages' format
-                    show_label=False,
-                    avatar_images=(None, os.path.join(KOKORO_PATH, "icon.png") if os.path.exists(os.path.join(KOKORO_PATH, "icon.png")) else "https://huggingface.co/spaces/gradio/chatbot-streaming/resolve/main/avatar.png") # User/Assistant avatars
                 )
                 answer_status_output = gr.Markdown(value="*Enter a query to start.*", elem_classes="answer-box markdown-content")
                 audio_player = gr.Audio(label="Voice Response", type="numpy", autoplay=False, show_label=False, elem_classes="audio-player")

-            # Right Column: Sources
             with gr.Column(scale=2):
                 with gr.Column(elem_classes="sources-box"):
                     gr.Markdown("### Sources")
                     sources_output_html = gr.HTML(value="<div class='no-sources'>Sources will appear here.</div>")

-    # Examples Area
     with gr.Row(elem_classes="examples-container"):
         gr.Examples(
-            examples=[
-                "Latest news about renewable energy",
-                "Explain Large Language Models (LLMs)",
-                "Symptoms and prevention tips for the flu",
-                "Compare Python and JavaScript for web development",
-                "Summarize the main points of the Paris Agreement",
-            ],
-            inputs=search_input,
-            label="Try these examples:",
-            # elem_classes removed
         )

-    # --- Event Handling Setup ---
     event_inputs = [search_input, chat_history_state, voice_select]
-    event_outputs = [
-        chatbot_display, # Output 1: Updated chat history
-        answer_status_output, # Output 2: Status/final text
-        sources_output_html, # Output 3: Sources HTML
-        audio_player, # Output 4: Audio data
-        search_btn # Output 5: Button state
-    ]
-
-    async def stream_interaction_updates(query, history, voice_display_name):
-        """Wraps the async generator to handle streaming updates and errors."""
-        print("[Gradio Stream] Starting interaction...")
-        final_state_tuple = None # To store the last successful state
-        try:
-            async for state_update_tuple in handle_interaction(query, history, voice_display_name):
-                yield state_update_tuple # Yield the tuple for Gradio to update outputs
-                final_state_tuple = state_update_tuple # Keep track of the last state
-            print("[Gradio Stream] Interaction completed successfully.")
-
-        except Exception as e:
-            print(f"[Gradio Stream] Error during interaction: {e}")
-            print(traceback.format_exc())
-            # Construct error state to yield
-            error_history = history + [{"role":"user", "content":query}, {"role":"assistant", "content":f"*An error occurred. Please check logs.*"}]
-            error_state_tuple = (
-                error_history,
-                f"An error occurred: {e}",
-                "<div class='error'>Request failed.</div>",
-                None,
-                gr.Button(value="Search", interactive=True) # Ensure button is re-enabled
-            )
-            yield error_state_tuple # Yield the error state to UI
-            final_state_tuple = error_state_tuple # Store error state as last state
-
-        # Optionally clear input ONLY if the interaction finished (success or error)
-        # Requires adding search_input to event_outputs and handling the update dict
-        # Example (if search_input is the 6th output):
-        # if final_state_tuple:
-        #     yield (*final_state_tuple, gr.Textbox(value=""))
-        # else: # Handle case where no state was ever yielded (e.g., immediate empty query return)
-        #     yield (history, "*Please enter a query.*", "...", None, gr.Button(value="Search", interactive=True), gr.Textbox(value=""))

-    # Connect the streaming function
     search_btn.click(
-        fn=stream_interaction_updates,
         inputs=event_inputs,
         outputs=event_outputs
     )
     search_input.submit(
-        fn=stream_interaction_updates,
         inputs=event_inputs,
         outputs=event_outputs
     )

 # --- Main Execution ---
 if __name__ == "__main__":
-    print("Starting Gradio application...")
-    # Optional: Wait a moment for TTS setup thread to start and potentially print messages
-    # time.sleep(1)
     demo.queue(max_size=20).launch(
         debug=True,
-        share=True, # Set to False if not running on Spaces or don't need public link
-        # server_name="0.0.0.0", # Uncomment to bind to all network interfaces
-        # server_port=7860 # Optional: Specify port
     )
     print("Gradio application stopped.")

 import numpy as np
 from typing import List, Dict, Tuple, Any, Optional, Union
 from functools import lru_cache
+# No asyncio needed for synchronous version
 import threading
+# No ThreadPoolExecutor needed for synchronous version
 import warnings
 import traceback # For detailed error logging
 import re # For text cleaning
 import shutil # For checking sudo/file operations
 import html # For escaping HTML
 import sys # For sys.path manipulation
+import spaces # <<<--- IMPORT SPACES FOR THE DECORATOR

 # --- Configuration ---
 MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 MAX_SEARCH_RESULTS = 5
 TTS_SAMPLE_RATE = 24000
+MAX_TTS_CHARS = 1000
 MAX_NEW_TOKENS = 300
 TEMPERATURE = 0.7
 TOP_P = 0.95
+KOKORO_PATH = 'Kokoro-82M'
+# Define expected durations for ZeroGPU decorator
+LLM_GPU_DURATION = 120 # Seconds (adjust based on expected LLM generation time)
+TTS_GPU_DURATION = 45 # Seconds (adjust based on expected TTS generation time)
 
37
  # --- Initialization ---
 
 
 
38
  # Suppress specific warnings
39
  warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
40
  warnings.filterwarnings("ignore", message="Backend 'inductor' is not available.")
 
49
  llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
50
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
51
 
52
+ # For ZeroGPU, we assume GPU will be available when needed, load with cuda preference
53
+ # If running locally without GPU, it might try CPU based on device_map="auto" fallback
54
+ llm_device = "cuda" if torch.cuda.is_available() else "cpu" # Check initial availability info
55
+ torch_dtype = torch.float16 if llm_device == "cuda" else torch.float32
56
+ # device_map="auto" is generally okay, ZeroGPU handles the actual assignment during decorated function call
57
+ device_map = "auto"
58
+ print(f"[LLM Init] Preparing model load (target device via ZeroGPU: cuda, dtype={torch_dtype})")
 
 
 
59
 
60
  llm_model = AutoModelForCausalLM.from_pretrained(
61
  MODEL_NAME,
62
+ device_map=device_map, # Let accelerate/ZeroGPU handle placement
63
  low_cpu_mem_usage=True,
64
  torch_dtype=torch_dtype,
 
65
  )
66
+ print(f"[LLM Init] LLM loaded configuration successfully. Ready for GPU assignment via @spaces.GPU.")
 
 
67
  llm_model.eval()
68
 
69
  except Exception as e:
 


 # --- TTS Initialization ---
+# (TTS setup remains the same, runs in background)
 VOICE_CHOICES = {
     '🇺🇸 Female (Default)': 'af',
     '🇺🇸 Bella': 'af_bella',

 voicepacks: Dict[str, Any] = {}
 tts_device = "cpu"

 def _run_subprocess(cmd: List[str], check: bool = True, cwd: Optional[str] = None, timeout: int = 300) -> subprocess.CompletedProcess:
     """Runs a subprocess command, captures output, and handles errors."""
     print(f"Running command: {' '.join(cmd)}")
     try:
         result = subprocess.run(cmd, check=check, capture_output=True, text=True, cwd=cwd, timeout=timeout)
         if not check or result.returncode != 0:
+            if result.stdout: print(f" Stdout: {result.stdout.strip()}")
+            if result.stderr: print(f" Stderr: {result.stderr.strip()}")
         elif result.returncode == 0 and ('clone' in cmd or 'pull' in cmd or 'install' in cmd):
+            print(f" Command successful.")
         return result
     except FileNotFoundError:
         print(f" Error: Command not found - {cmd[0]}")

         if e.stderr: print(f" Stderr: {e.stderr.strip()}")
         raise

 def setup_tts_task():
     """Initializes Kokoro TTS model and dependencies."""
     global TTS_ENABLED, tts_model, voicepacks, tts_device
     print("[TTS Setup] Starting background initialization...")

+    # TTS device determination depends on where generate_tts_speech will run.
+    # If decorated with @spaces.GPU, it will use CUDA when called.
+    tts_device = "cuda" # Assume it will run on GPU via the decorator
+    print(f"[TTS Setup] Target device for TTS model (via @spaces.GPU): {tts_device}")

     can_sudo = shutil.which('sudo') is not None
     apt_cmd_prefix = ['sudo'] if can_sudo else []
+    absolute_kokoro_path = os.path.abspath(KOKORO_PATH)

     try:
+        # 1. Clone/Update Repo
         if not os.path.exists(absolute_kokoro_path):
+            print(f"[TTS Setup] Cloning repository to {absolute_kokoro_path}...")
+            # (Cloning logic as before)
+            try: _run_subprocess(['git', 'lfs', 'install', '--system', '--skip-repo'])
+            except Exception as lfs_err: print(f"[TTS Setup] Warning: git lfs install failed: {lfs_err}")
+            _run_subprocess(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', absolute_kokoro_path])
+            try: _run_subprocess(['git', 'lfs', 'pull'], cwd=absolute_kokoro_path)
+            except Exception as lfs_pull_err: print(f"[TTS Setup] Warning: git lfs pull failed: {lfs_pull_err}")
         else:
+            print(f"[TTS Setup] Directory {absolute_kokoro_path} already exists.")
+
+        # 2. Install espeak
         print("[TTS Setup] Checking/Installing espeak...")
+        try: # (espeak install logic as before)
+            _run_subprocess(apt_cmd_prefix + ['apt-get', 'update', '-qq'])
+            _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak-ng'])
+            print("[TTS Setup] espeak-ng installed or already present.")
         except Exception:
+            print("[TTS Setup] espeak-ng installation failed, trying espeak...")
+            try:
+                _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak'])
+                print("[TTS Setup] espeak installed or already present.")
+            except Exception as espeak_err:
+                print(f"[TTS Setup] ERROR: Failed to install espeak: {espeak_err}. TTS disabled.")
+                return

         # 3. Load Kokoro Model and Voices
         sys_path_updated = False
         if os.path.exists(absolute_kokoro_path):
+            print(f"[TTS Setup] Checking contents of: {absolute_kokoro_path}")
+            try: print(f"[TTS Setup] Contents: {os.listdir(absolute_kokoro_path)}")
+            except OSError as list_err: print(f"[TTS Setup] Warning: Could not list directory contents: {list_err}")
+
+            if absolute_kokoro_path not in sys.path:
+                sys.path.insert(0, absolute_kokoro_path)
+                sys_path_updated = True
+                print(f"[TTS Setup] Temporarily added {absolute_kokoro_path} to sys.path.")
+
+            try:
+                print("[TTS Setup] Attempting to import Kokoro modules...")
+                from models import build_model
+                from kokoro import generate as generate_tts_internal
+                print("[TTS Setup] Kokoro modules imported successfully.")
+
+                globals()['build_model'] = build_model
+                globals()['generate_tts_internal'] = generate_tts_internal
+
+                model_file = os.path.join(absolute_kokoro_path, 'kokoro-v0_19.pth')
+                if not os.path.exists(model_file):
+                    print(f"[TTS Setup] ERROR: Model file {model_file} not found. TTS disabled.")
+                    return
+
+                # Load the model onto CPU initially; the ZeroGPU decorator handles moving/using the GPU.
+                print(f"[TTS Setup] Loading TTS model config from {model_file} (target device: {tts_device} via @spaces.GPU)...")
+                # Loading onto CPU first avoids issues before the GPU is attached. This assumes
+                # build_model can load the structure and let the decorator handle device use;
+                # if build_model *requires* a device at load time, this may need adjustment.
+                tts_model = build_model(model_file, 'cpu') # <<< Load to CPU first
+                tts_model.eval()
+                print("[TTS Setup] TTS model structure loaded (CPU).")
+
+                # Load voices onto CPU
+                loaded_voices = 0
+                for voice_name, voice_id in VOICE_CHOICES.items():
+                    voice_file_path = os.path.join(absolute_kokoro_path, 'voices', f'{voice_id}.pt')
+                    if os.path.exists(voice_file_path):
+                        try:
+                            print(f"[TTS Setup] Loading voice: {voice_id} ({voice_name}) to CPU")
+                            voicepacks[voice_id] = torch.load(voice_file_path, map_location='cpu') # <<< Load to CPU
+                            loaded_voices += 1
+                        except Exception as e: print(f"[TTS Setup] Warning: Failed to load voice {voice_id}: {str(e)}")
+                    else: print(f"[TTS Setup] Info: Voice file {voice_file_path} not found.")
+
+                if loaded_voices == 0:
+                    print("[TTS Setup] ERROR: No voicepacks loaded. TTS disabled.")
+                    tts_model = None; return
+
+                TTS_ENABLED = True
+                print(f"[TTS Setup] Initialization successful. {loaded_voices} voices loaded. TTS Enabled: {TTS_ENABLED}")
+
+            except ImportError as ie:
+                print(f"[TTS Setup] ERROR: Failed to import Kokoro modules: {ie}.")
+                print(traceback.format_exc())
+            except Exception as load_err:
+                print(f"[TTS Setup] ERROR: Exception during TTS model/voice loading: {load_err}. TTS disabled.")
+                print(traceback.format_exc())
+            finally:
+                if sys_path_updated: # Cleanup sys.path
+                    try:
+                        if sys.path[0] == absolute_kokoro_path: sys.path.pop(0)
+                        elif absolute_kokoro_path in sys.path: sys.path.remove(absolute_kokoro_path)
+                        print(f"[TTS Setup] Cleaned up sys.path.")
+                    except Exception as cleanup_err: print(f"[TTS Setup] Warning: Error cleaning sys.path: {cleanup_err}")
         else:
             print(f"[TTS Setup] ERROR: Directory {absolute_kokoro_path} not found. TTS disabled.")

     except Exception as e:
         print(f"[TTS Setup] ERROR: Unexpected error during setup: {str(e)}")
         print(traceback.format_exc())
+        TTS_ENABLED = False; tts_model = None; voicepacks.clear()

+# Start TTS setup thread
 print("Starting TTS setup thread...")
 tts_setup_thread = threading.Thread(target=setup_tts_task, daemon=True)
 tts_setup_thread.start()

+# --- Core Logic Functions (SYNCHRONOUS + @spaces.GPU) ---

+# Web search remains synchronous
 @lru_cache(maxsize=128)
 def get_web_results_sync(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, Any]]:
     """Synchronous web search function with caching."""
+    # (Implementation remains the same as before)
     print(f"[Web Search] Searching (sync): '{query}' (max_results={max_results})")
     try:
         with DDGS() as ddgs:
             results = list(ddgs.text(query, max_results=max_results, safesearch='moderate', timelimit='y'))
             print(f"[Web Search] Found {len(results)} results.")
             formatted = [{
+                "id": i + 1, "title": res.get("title", "No Title"),
+                "snippet": res.get("body", "No Snippet"), "url": res.get("href", "#"),
             } for i, res in enumerate(results)]
             return formatted
     except Exception as e:
+        print(f"[Web Search] Error: {e}"); return []
 
 

+# Prompt formatting remains the same
 def format_llm_prompt(query: str, context: List[Dict[str, Any]]) -> str:
+    """Formats the prompt for the LLM."""
+    # (Implementation remains the same as before)
     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     context_str = "\n\n".join(
         [f"[{res['id']}] {html.escape(res['title'])}\n{html.escape(res['snippet'])}" for res in context]
     ) if context else "No relevant web context found."
     return f"""SYSTEM: You are a helpful AI assistant. Answer the user's query based *only* on the provided web search context. Cite sources using bracket notation like [1], [2]. If the context is insufficient, state that clearly. Use markdown for formatting. Do not add external information. Current Time: {current_time}

CONTEXT:

USER: {html.escape(query)}

+ASSISTANT:"""

+# Source formatting remains the same
 def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
     """Formats search results into HTML for display."""
+    # (Implementation remains the same as before)
+    if not web_results: return "<div class='no-sources'>No sources found.</div>"
     items_html = ""
     for res in web_results:
         title_safe = html.escape(res.get("title", "Source"))
         snippet_safe = html.escape(res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else ""))
+        url = html.escape(res.get("url", "#"))
+        items_html += f"""<div class='source-item'><div class='source-number'>[{res['id']}]</div><div class='source-content'><a href="{url}" target="_blank" class='source-title' title="{url}">{title_safe}</a><div class='source-snippet'>{snippet_safe}</div></div></div>"""
     return f"<div class='sources-container'>{items_html}</div>"
+
+# <<<--- ADD @spaces.GPU decorator AND MAKE SYNCHRONOUS --->>>
+@spaces.GPU(duration=LLM_GPU_DURATION)
+def generate_llm_answer(prompt: str) -> str:
+    """Generates answer using the LLM (Synchronous, GPU-decorated)."""
     if not llm_model or not llm_tokenizer:
         print("[LLM Generate] LLM model or tokenizer not available.")
         return "Error: Language Model is not available."

+    print(f"[LLM Generate] Requesting generation (sync, GPU) (prompt length {len(prompt)})...")
     start_time = time.time()
     try:
+        # Ensure the model is on the GPU (ZeroGPU should handle this).
+        # It might be safer to move the model explicitly if ZeroGPU doesn't guarantee it;
+        # assume ZeroGPU handles the context for now.
+        current_device = next(llm_model.parameters()).device
+        print(f"[LLM Generate] Model currently on device: {current_device}") # Debug device
+
         inputs = llm_tokenizer(
+            prompt, return_tensors="pt", padding=True, truncation=True,
+            max_length=1024, return_attention_mask=True
+        ).to(current_device) # Send input to the model's device

         with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(llm_model.dtype == torch.float16)):
+            # Direct synchronous call
+            outputs = llm_model.generate(
                 inputs.input_ids,
                 attention_mask=inputs.attention_mask,
                 max_new_tokens=MAX_NEW_TOKENS,
+                temperature=TEMPERATURE, top_p=TOP_P,
                 pad_token_id=llm_tokenizer.eos_token_id,
                 eos_token_id=llm_tokenizer.eos_token_id,
+                do_sample=True, num_return_sequences=1
             )

         output_ids = outputs[0][inputs.input_ids.shape[1]:]
         answer_part = llm_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
+        if not answer_part: answer_part = "*Model generated an empty response.*"

         end_time = time.time()
         print(f"[LLM Generate] Generation complete in {end_time - start_time:.2f}s. Length: {len(answer_part)}")

     except Exception as e:
         print(f"[LLM Generate] Error: {e}")
         print(traceback.format_exc())
+        return f"Error during answer generation: Check logs."
+
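Since the function is now a plain synchronous callable, each invocation of a `@spaces.GPU`-decorated function is a separate GPU attachment (LLM and TTS attach independently). The handler further down composes it with the search and prompt helpers roughly like this (a condensed sketch of code that appears later in the diff):

    web_results = get_web_results_sync(query)        # CPU, possibly cached
    prompt = format_llm_prompt(query, web_results)
    answer = generate_llm_answer(prompt)             # GPU attached for this call only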

+# <<<--- ADD @spaces.GPU decorator AND MAKE SYNCHRONOUS --->>>
+@spaces.GPU(duration=TTS_GPU_DURATION)
+def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
+    """Generates speech using TTS model (Synchronous, GPU-decorated)."""
     if not TTS_ENABLED or not tts_model or 'generate_tts_internal' not in globals():
         print("[TTS Generate] Skipping: TTS not ready.")
         return None
+    if not text or not text.strip() or text.startswith("Error:") or text.startswith("*Model"):
         print("[TTS Generate] Skipping: Invalid or empty text.")
         return None

+    print(f"[TTS Generate] Requesting speech (sync, GPU) (length {len(text)}, voice '{voice_id}')...")
     start_time = time.time()

     try:
         actual_voice_id = voice_id
         if voice_id not in voicepacks:
+            print(f"[TTS Generate] Warning: Voice '{voice_id}' not loaded. Trying 'af'.")
             actual_voice_id = 'af'
+            if 'af' not in voicepacks: print("[TTS Generate] Error: Default voice 'af' unavailable."); return None
+
+        # Clean text (same cleaning logic as before)
+        clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text)
+        clean_text = re.sub(r'```.*?```', '', clean_text, flags=re.DOTALL)
+        clean_text = re.sub(r'`[^`]*`', '', clean_text)
+        clean_text = re.sub(r'^\s*[\*->]\s*', '', clean_text, flags=re.MULTILINE)
+        clean_text = re.sub(r'[\*#_]', '', clean_text)
+        clean_text = html.unescape(clean_text)
+        clean_text = ' '.join(clean_text.split())
+
+        if not clean_text: print("[TTS Generate] Skipping: Text empty after cleaning."); return None

         if len(clean_text) > MAX_TTS_CHARS:
             print(f"[TTS Generate] Truncating cleaned text from {len(clean_text)} to {MAX_TTS_CHARS} chars.")
             clean_text = clean_text[:MAX_TTS_CHARS]
+            last_punct = max(clean_text.rfind(p) for p in '.?!; ')
             if last_punct != -1: clean_text = clean_text[:last_punct+1]
             clean_text += "..."

         gen_func = globals()['generate_tts_internal']
         voice_pack_data = voicepacks[actual_voice_id]

+        # *** Crucial for ZeroGPU: move the TTS model and voicepack to CUDA within the decorated function ***
+        current_device = 'cuda' # Assume the GPU is attached by the decorator
+        try:
+            print(f"[TTS Generate] Moving TTS model to {current_device}...")
+            tts_model.to(current_device)
+            # Move voicepack data (might be a dict of tensors)
+            if isinstance(voice_pack_data, dict):
+                moved_voice_pack = {k: v.to(current_device) if isinstance(v, torch.Tensor) else v for k, v in voice_pack_data.items()}
+            elif isinstance(voice_pack_data, torch.Tensor):
+                moved_voice_pack = voice_pack_data.to(current_device)
+            else:
+                moved_voice_pack = voice_pack_data # Assume no tensors if not a dict/tensor
+            print(f"[TTS Generate] TTS model and voicepack on {current_device}.")
+
+            # Direct synchronous call on GPU
+            audio_data, _ = gen_func(tts_model, clean_text, moved_voice_pack, 'afr')
+
+        finally:
+            # *** Optional but recommended: move the model back to CPU to free GPU memory ***
+            # ZeroGPU might handle this, but an explicit move-back is safer when running locally too
+            try:
+                print("[TTS Generate] Moving TTS model back to CPU...")
+                tts_model.to('cpu')
+                # No need to move the voicepack back; it is loaded to CPU initially
+            except Exception as move_back_err:
+                print(f"[TTS Generate] Warning: Could not move TTS model back to CPU: {move_back_err}")

+        # Process output (remains the same)
+        if isinstance(audio_data, torch.Tensor): audio_np = audio_data.detach().cpu().numpy()
+        elif isinstance(audio_data, np.ndarray): audio_np = audio_data
+        else: print("[TTS Generate] Warning: Unexpected audio data type."); return None
+        audio_np = audio_np.flatten().astype(np.float32)

         end_time = time.time()
         print(f"[TTS Generate] Audio generated in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")

         print(traceback.format_exc())
         return None
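The gap above elides the unchanged success return and the `except` header. Per the signature, a successful call yields an `Optional[Tuple[int, np.ndarray]]`, presumably `(TTS_SAMPLE_RATE, audio_np)`, which is exactly the `(sample_rate, samples)` tuple that a `gr.Audio(type="numpy")` component consumes; for example (illustrative call, not from the commit):

    audio = generate_tts_speech("Hello there.", voice_id='af')
    if audio is not None:
        rate, samples = audio   # e.g. 24000, 1-D float32 array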

+# Voice ID mapping remains the same
 def get_voice_id_from_display(voice_display_name: str) -> str:
+    return VOICE_CHOICES.get(voice_display_name, 'af')


+# --- Gradio Interaction Logic (SYNCHRONOUS) ---
+ChatHistoryType = List[Dict[str, Optional[str]]]

+def handle_interaction(
     query: str,
     history: ChatHistoryType,
     selected_voice_display_name: str
+) -> Tuple[ChatHistoryType, str, str, Optional[Tuple[int, np.ndarray]], Any]: # Return type matches outputs
+    """Synchronous function to handle user queries for ZeroGPU."""
+    print(f"\n--- Handling Query (Sync) ---")
+    query = query.strip()
     print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")

     if not query:
         print("Empty query received.")
+        # Return the initial state immediately
+        return history, "*Please enter a non-empty query.*", "<div class='no-sources'>Enter a query to search.</div>", None, gr.Button(value="Search", interactive=True)

+    # Initial state updates (won't be seen until the end in Gradio)
     current_history: ChatHistoryType = history + [{"role": "user", "content": query}]
+    current_history.append({"role": "assistant", "content": "*Processing... Please wait.*"}) # Placeholder
+    status_update = "*Processing... Please wait.*"
+    sources_html = "<div class='searching'><span>Searching & Processing...</span></div>"
+    audio_data = None
+    button_update = gr.Button(value="Processing...", interactive=False) # Disabled during processing

+    # --- Start Blocking Operations ---
+    try:
+        # 1. Perform Web Search (Sync)
+        print("[Handler] Performing web search...")
+        web_results = get_web_results_sync(query)
+        sources_html = format_sources_html(web_results) # Update sources now
+
+        # 2. Generate LLM Answer (Sync, Decorated)
+        print("[Handler] Generating LLM answer...")
+        status_update = "*Generating answer...*" # Update status text
+        # (UI won't update here yet)
+        llm_prompt = format_llm_prompt(query, web_results)
+        final_answer = generate_llm_answer(llm_prompt) # This call triggers GPU attachment
+        status_update = final_answer # Answer generated
+
+        # 3. Generate TTS Speech (Sync, Decorated, Optional)
+        tts_status_message = ""
+        if TTS_ENABLED and not final_answer.startswith("Error"):
+            print("[Handler] Generating TTS speech...")
+            status_update += "\n\n*(Generating audio...)*" # Append status
+            # (UI won't update here yet)
+            voice_id = get_voice_id_from_display(selected_voice_display_name)
+            audio_data = generate_tts_speech(final_answer, voice_id) # This call triggers GPU attachment
+            if audio_data is None:
+                tts_status_message = "\n\n*(Audio generation failed)*"
+        elif not TTS_ENABLED:
+            if tts_setup_thread.is_alive(): tts_status_message = "\n\n*(TTS initializing...)*"
+            else: tts_status_message = "\n\n*(TTS unavailable)*"
+
+        # Combine the final answer with the status message
+        final_answer_with_status = final_answer + tts_status_message
+        status_update = final_answer_with_status
+        current_history[-1]["content"] = final_answer_with_status # Update history
+
+        button_update = gr.Button(value="Search", interactive=True) # Re-enable button
+        print("--- Query Handling Complete (Sync) ---")

+    except Exception as e:
+        print(f"[Handler] Error during processing: {e}")
+        print(traceback.format_exc())
+        error_message = f"*An error occurred: {e}*"
+        current_history[-1]["content"] = error_message # Update history with the error
+        status_update = error_message
+        sources_html = "<div class='error'>Request failed.</div>"
+        audio_data = None
+        button_update = gr.Button(value="Search", interactive=True) # Re-enable button on error

+    # Return the final state tuple for all outputs
+    return current_history, status_update, sources_html, audio_data, button_update
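The event wiring that consumes this return value falls outside the rendered hunks, but since `handle_interaction` now returns one tuple instead of yielding intermediate states, the `click`/`submit` hookups presumably point `fn` straight at it. A minimal sketch, assuming the same `event_inputs`/`event_outputs` lists as before:

    search_btn.click(fn=handle_interaction, inputs=event_inputs, outputs=event_outputs)
    search_input.submit(fn=handle_interaction, inputs=event_inputs, outputs=event_outputs)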
504
 
505
 
506
  # --- Gradio UI Definition ---
507
+ # (CSS remains the same)
508
  css = """
509
  /* ... [Your existing refined CSS] ... */
510
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
 
523
  .search-box button:hover { background: #1d4ed8 !important; }
524
  .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
525
  .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
526
+ .answer-box { /* Now used for status/final text */ background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1rem; color: #1f2937; margin-bottom: 0.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.05); min-height: 50px;}
527
  .answer-box p { color: #374151; line-height: 1.7; margin:0;}
528
  .answer-box code { background: #f3f4f6; border-radius: 4px; padding: 2px 4px; color: #4b5563; font-size: 0.9em; }
529
  .sources-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; }
 
536
  .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
537
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
538
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
539
+ .chat-history { max-height: 500px; overflow-y: auto; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
540
+ .chat-history > div { padding: 1rem; }
541
  .chat-history::-webkit-scrollbar { width: 6px; }
542
  .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
543
  .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
 
558
  .markdown-content table { border-collapse: collapse !important; width: 100% !important; margin: 1em 0; }
559
  .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
560
  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
561
  .voice-selector { margin: 0; padding: 0; height: 100%; }
562
  .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
563
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
 
608
  .dark .markdown-content blockquote { border-left-color: #4b5563 !important; color: #9ca3af !important; }
609
  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
610
  .dark .markdown-content th { background: #374151 !important; }
 
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
612
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
613
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
 
620
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
621
  """
622
 
623
+ with gr.Blocks(title="AI Search Assistant (ZeroGPU Sync)", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
 
624
  chat_history_state = gr.State([])
625
 
626
  with gr.Column():
 
627
  with gr.Column(elem_id="header"):
628
+ gr.Markdown("# ๐Ÿ” AI Search Assistant (ZeroGPU Version)")
629
  gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
630
+ gr.Markdown("*(UI will block during processing for ZeroGPU compatibility)*")
631
 
 
632
  with gr.Column(elem_classes="search-container"):
633
  with gr.Row(elem_classes="search-box"):
634
  search_input = gr.Textbox(label="", placeholder="Ask anything...", scale=5, container=False)
635
  voice_select = gr.Dropdown(choices=list(VOICE_CHOICES.keys()), value=list(VOICE_CHOICES.keys())[0], label="", scale=1, min_width=180, container=False, elem_classes="voice-selector")
636
  search_btn = gr.Button("Search", variant="primary", scale=0, min_width=100)
637
 
 
638
  with gr.Row(elem_classes="results-container"):
 
639
  with gr.Column(scale=3):
640
  chatbot_display = gr.Chatbot(
641
+ label="Conversation", bubble_full_width=True, height=500,
642
+ elem_classes="chat-history", type="messages", show_label=False,
643
+ avatar_images=(None, os.path.join(KOKORO_PATH, "icon.png") if os.path.exists(os.path.join(KOKORO_PATH, "icon.png")) else "https://huggingface.co/spaces/gradio/chatbot-streaming/resolve/main/avatar.png")
644
  )
645
+ # This Markdown will only show the *final* status/answer text
646
  answer_status_output = gr.Markdown(value="*Enter a query to start.*", elem_classes="answer-box markdown-content")
647
  audio_player = gr.Audio(label="Voice Response", type="numpy", autoplay=False, show_label=False, elem_classes="audio-player")
648
 
 
649
  with gr.Column(scale=2):
650
  with gr.Column(elem_classes="sources-box"):
651
  gr.Markdown("### Sources")
652
  sources_output_html = gr.HTML(value="<div class='no-sources'>Sources will appear here.</div>")
653
 
 
654
  with gr.Row(elem_classes="examples-container"):
655
  gr.Examples(
656
+ examples=[ "Latest news about renewable energy", "Explain Large Language Models (LLMs)",
657
+ "Symptoms and prevention tips for the flu", "Compare Python and JavaScript",
658
+ "Summarize the Paris Agreement", ],
659
+ inputs=search_input, label="Try these examples:",
660
  )
661
 
662
+ # --- Event Handling Setup (Synchronous) ---
663
  event_inputs = [search_input, chat_history_state, voice_select]
664
+ event_outputs = [ chatbot_display, answer_status_output, sources_output_html,
665
+ audio_player, search_btn ]
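# Note: the order of event_outputs must match the 5-tuple returned by
# handle_interaction (chat history, status markdown, sources HTML, audio,
# button update); Gradio maps return values to outputs strictly by position.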
666
 
667
+ # Connect the SYNCHRONOUS handle_interaction function directly
668
  search_btn.click(
669
+ fn=handle_interaction, # Use the synchronous handler
670
  inputs=event_inputs,
671
  outputs=event_outputs
672
  )
673
  search_input.submit(
674
+ fn=handle_interaction, # Use the synchronous handler
675
  inputs=event_inputs,
676
  outputs=event_outputs
677
  )
678
 
679
  # --- Main Execution ---
680
  if __name__ == "__main__":
681
+ print("Starting Gradio application (Synchronous for ZeroGPU)...")
682
+ # Ensure TTS setup thread has a chance to start
683
+ time.sleep(1) # Brief pause so the TTS setup thread can log its startup messages before launch
684
  demo.queue(max_size=20).launch(
685
  debug=True,
686
+ share=True,
687
  )
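# Note: because handle_interaction is synchronous, each request holds a queue
# worker for the whole search + LLM + TTS round trip; max_size=20 only caps
# how many requests may wait in the queue.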
688
  print("Gradio application stopped.")