sagar007 committed on
Commit
6560c55
·
verified ·
1 Parent(s): ffc273f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -209
app.py CHANGED
@@ -1,6 +1,6 @@
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- # import spaces # Removed as @spaces.GPU is not used with async
4
  from duckduckgo_search import DDGS
5
  import time
6
  import torch
@@ -16,8 +16,9 @@ from concurrent.futures import ThreadPoolExecutor
16
  import warnings
17
  import traceback # For detailed error logging
18
  import re # For text cleaning
19
- import shutil # For checking sudo
20
  import html # For escaping HTML
 
21
 
22
  # --- Configuration ---
23
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
@@ -27,11 +28,11 @@ MAX_TTS_CHARS = 1000 # Max characters for a single TTS chunk
27
  MAX_NEW_TOKENS = 300
28
  TEMPERATURE = 0.7
29
  TOP_P = 0.95
30
- KOKORO_PATH = 'Kokoro-82M' # Path to TTS model directory
31
 
32
  # --- Initialization ---
33
- # Use a ThreadPoolExecutor for blocking I/O or CPU-bound tasks
34
- executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4) # Use available cores
35
 
36
  # Suppress specific warnings
37
  warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
@@ -40,41 +41,42 @@ warnings.filterwarnings("ignore", message="Backend 'inductor' is not available."
40
  # --- LLM Initialization ---
41
  llm_model: Optional[AutoModelForCausalLM] = None
42
  llm_tokenizer: Optional[AutoTokenizer] = None
43
- llm_device = "cpu" # Default device
44
 
45
  try:
46
- print("Initializing LLM...")
47
  llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
48
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
49
 
50
  if torch.cuda.is_available():
51
  llm_device = "cuda"
52
  torch_dtype = torch.float16
53
- device_map = "auto" # Let accelerate handle distribution
54
- print(f"CUDA detected. Loading model with device_map='{device_map}', dtype={torch_dtype}")
55
  else:
56
  llm_device = "cpu"
57
- torch_dtype = torch.float32 # float32 for CPU
58
  device_map = {"": "cpu"}
59
- print(f"CUDA not found. Loading model on CPU with dtype={torch_dtype}")
60
 
61
  llm_model = AutoModelForCausalLM.from_pretrained(
62
  MODEL_NAME,
63
  device_map=device_map,
64
  low_cpu_mem_usage=True,
65
  torch_dtype=torch_dtype,
66
- # attn_implementation="flash_attention_2" # Optional: Uncomment if flash-attn is installed and compatible GPU
67
  )
68
- print(f"LLM loaded successfully. Device map: {llm_model.hf_device_map if hasattr(llm_model, 'hf_device_map') else 'N/A'}")
69
- llm_model.eval() # Set to evaluation mode
 
 
70
 
71
  except Exception as e:
72
- print(f"FATAL: Error initializing LLM model: {str(e)}")
73
  print(traceback.format_exc())
74
- # Depending on environment, you might exit or just disable LLM features
75
  llm_model = None
76
  llm_tokenizer = None
77
- print("LLM features will be unavailable.")
78
 
79
 
80
  # --- TTS Initialization ---
@@ -85,147 +87,189 @@ VOICE_CHOICES = {
85
  '🇺🇸 Nicole': 'af_nicole'
86
  }
87
  TTS_ENABLED = False
88
- tts_model: Optional[Any] = None # Define type more specifically if Kokoro provides it
89
- voicepacks: Dict[str, Any] = {} # Cache voice packs
90
- tts_device = "cpu" # Default device for TTS model
91
 
92
- # Use a lock for thread-safe access during initialization if needed, though Thread ensures sequential execution
93
- # tts_init_lock = threading.Lock()
94
-
95
- def _run_subprocess(cmd: List[str], check: bool = True, cwd: Optional[str] = None) -> subprocess.CompletedProcess:
96
- """Helper to run subprocess and capture output."""
97
  print(f"Running command: {' '.join(cmd)}")
98
  try:
99
- result = subprocess.run(cmd, check=check, capture_output=True, text=True, cwd=cwd)
100
- if result.stdout: print(f"Stdout: {result.stdout.strip()}")
101
- if result.stderr: print(f"Stderr: {result.stderr.strip()}")
 
 
 
 
102
  return result
103
  except FileNotFoundError:
104
- print(f"Error: Command not found - {cmd[0]}")
 
 
 
105
  raise
106
  except subprocess.CalledProcessError as e:
107
- print(f"Error running command: {' '.join(e.cmd)}")
108
- if e.stdout: print(f"Stdout: {e.stdout.strip()}")
109
- if e.stderr: print(f"Stderr: {e.stderr.strip()}")
110
  raise
111
 
 
112
  def setup_tts_task():
113
  """Initializes Kokoro TTS model and dependencies."""
114
  global TTS_ENABLED, tts_model, voicepacks, tts_device
115
  print("[TTS Setup] Starting background initialization...")
116
 
117
- # Determine TTS device
118
  tts_device = "cuda" if torch.cuda.is_available() else "cpu"
119
  print(f"[TTS Setup] Target device: {tts_device}")
120
 
121
  can_sudo = shutil.which('sudo') is not None
122
  apt_cmd_prefix = ['sudo'] if can_sudo else []
 
123
 
124
  try:
125
  # 1. Clone Kokoro Repo if needed
126
- if not os.path.exists(KOKORO_PATH):
127
- print(f"[TTS Setup] Cloning repository to {KOKORO_PATH}...")
128
  try:
129
  _run_subprocess(['git', 'lfs', 'install', '--system', '--skip-repo'])
130
  except Exception as lfs_err:
131
- print(f"[TTS Setup] Warning: git lfs install command failed: {lfs_err}. Continuing clone...")
132
- _run_subprocess(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', KOKORO_PATH])
133
  try:
134
  print("[TTS Setup] Running git lfs pull...")
135
- _run_subprocess(['git', 'lfs', 'pull'], cwd=KOKORO_PATH)
136
  except Exception as lfs_pull_err:
137
  print(f"[TTS Setup] Warning: git lfs pull failed: {lfs_pull_err}")
138
  else:
139
- print(f"[TTS Setup] Directory {KOKORO_PATH} already exists.")
 
 
 
 
 
 
 
140
 
141
  # 2. Install espeak dependency
142
  print("[TTS Setup] Checking/Installing espeak...")
143
  try:
 
144
  _run_subprocess(apt_cmd_prefix + ['apt-get', 'update', '-qq'])
 
145
  _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak-ng'])
146
  print("[TTS Setup] espeak-ng installed or already present.")
147
  except Exception:
148
- print("[TTS Setup] espeak-ng failed, trying espeak...")
149
  try:
 
150
  _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak'])
151
  print("[TTS Setup] espeak installed or already present.")
152
  except Exception as espeak_err:
153
  print(f"[TTS Setup] ERROR: Failed to install both espeak-ng and espeak: {espeak_err}. TTS disabled.")
154
- return # Critical dependency missing
155
 
156
  # 3. Load Kokoro Model and Voices
157
- if os.path.exists(KOKORO_PATH):
158
- sys_path_updated = False
159
- if KOKORO_PATH not in sys.path:
160
- sys.path.append(KOKORO_PATH)
 
 
 
 
 
 
 
 
 
 
161
  sys_path_updated = True
 
 
162
  try:
 
163
  from models import build_model
164
  from kokoro import generate as generate_tts_internal
 
165
 
166
- globals()['build_model'] = build_model # Make available globally
 
167
  globals()['generate_tts_internal'] = generate_tts_internal
168
 
169
- model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
170
  if not os.path.exists(model_file):
171
  print(f"[TTS Setup] ERROR: Model file {model_file} not found. TTS disabled.")
172
  return
173
 
174
  print(f"[TTS Setup] Loading TTS model from {model_file} onto {tts_device}...")
175
  tts_model = build_model(model_file, tts_device)
176
- tts_model.eval() # Set to eval mode
177
  print("[TTS Setup] TTS model loaded.")
178
 
179
  # Load voices
180
  loaded_voices = 0
181
  for voice_name, voice_id in VOICE_CHOICES.items():
182
- voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
183
  if os.path.exists(voice_file_path):
184
  try:
185
  print(f"[TTS Setup] Loading voice: {voice_id} ({voice_name})")
186
- # map_location ensures it loads to the correct device
187
  voicepacks[voice_id] = torch.load(voice_file_path, map_location=tts_device)
188
  loaded_voices += 1
189
  except Exception as e:
190
  print(f"[TTS Setup] Warning: Failed to load voice {voice_id}: {str(e)}")
191
  else:
192
- print(f"[TTS Setup] Info: Voice file {voice_file_path} not found, skipping.")
193
 
194
  if loaded_voices == 0:
195
  print("[TTS Setup] ERROR: No voicepacks could be loaded. TTS disabled.")
196
- tts_model = None # Unload model if no voices
197
  return
198
 
199
  TTS_ENABLED = True
200
  print(f"[TTS Setup] Initialization successful. {loaded_voices} voices loaded. TTS Enabled: {TTS_ENABLED}")
201
 
 
202
  except ImportError as ie:
203
- print(f"[TTS Setup] ERROR: Failed to import Kokoro modules: {ie}. Check clone and path. TTS disabled.")
 
 
204
  except Exception as load_err:
205
- print(f"[TTS Setup] ERROR: Failed loading TTS model/voices: {load_err}. TTS disabled.")
206
  print(traceback.format_exc())
207
  finally:
208
- # Clean up sys.path if modified
209
- if sys_path_updated and KOKORO_PATH in sys.path:
210
- sys.path.remove(KOKORO_PATH)
 
 
 
 
 
 
 
 
 
 
211
  else:
212
- print(f"[TTS Setup] ERROR: {KOKORO_PATH} directory not found. TTS disabled.")
213
 
214
  except Exception as e:
215
  print(f"[TTS Setup] ERROR: Unexpected error during setup: {str(e)}")
216
  print(traceback.format_exc())
217
- # Ensure TTS is marked as disabled
218
- TTS_ENABLED = False
219
  tts_model = None
220
  voicepacks.clear()
221
 
222
- # Start TTS setup in a background thread
223
  print("Starting TTS setup thread...")
224
  tts_setup_thread = threading.Thread(target=setup_tts_task, daemon=True)
225
  tts_setup_thread.start()
226
 
227
 
228
- # --- Core Functions ---
229
 
230
  @lru_cache(maxsize=128)
231
  def get_web_results_sync(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, Any]]:
@@ -244,33 +288,27 @@ def get_web_results_sync(query: str, max_results: int = MAX_SEARCH_RESULTS) -> L
244
  return formatted
245
  except Exception as e:
246
  print(f"[Web Search] Error: {e}")
247
- print(traceback.format_exc())
248
  return []
249
 
250
  def format_llm_prompt(query: str, context: List[Dict[str, Any]]) -> str:
251
  """Formats the prompt for the LLM, including context and instructions."""
252
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
253
  context_str = "\n\n".join(
254
- [f"[{res['id']}] {res['title']}\n{res['snippet']}" for res in context]
255
  ) if context else "No relevant web context found."
256
 
257
- return f"""You are a helpful AI assistant. Answer the user's query based *only* on the provided web search context.
258
- Instructions:
259
- - Synthesize information from the context to answer concisely.
260
- - Cite sources using bracket notation like [1], [2], etc., corresponding to the context IDs.
261
- - If the context is insufficient, state that clearly. Do not add external information.
262
- - Use markdown for formatting.
263
-
264
- Current Time: {current_time}
265
 
266
- Web Context:
267
  ---
268
  {context_str}
269
  ---
270
 
271
- User Query: {query}
272
 
273
- Answer:"""
274
 
275
  def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
276
  """Formats search results into HTML for display."""
@@ -280,7 +318,7 @@ def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
280
  for res in web_results:
281
  title_safe = html.escape(res.get("title", "Source"))
282
  snippet_safe = html.escape(res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else ""))
283
- url = res.get("url", "#")
284
  items_html += f"""
285
  <div class='source-item'>
286
  <div class='source-number'>[{res['id']}]</div>
@@ -295,7 +333,8 @@ def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
295
  async def generate_llm_answer(prompt: str) -> str:
296
  """Generates answer using the loaded LLM (Async Wrapper)."""
297
  if not llm_model or not llm_tokenizer:
298
- return "Error: LLM model is not available."
 
299
 
300
  print(f"[LLM Generate] Requesting generation (prompt length {len(prompt)})...")
301
  start_time = time.time()
@@ -305,12 +344,11 @@ async def generate_llm_answer(prompt: str) -> str:
305
  return_tensors="pt",
306
  padding=True,
307
  truncation=True,
308
- max_length=1024, # Consider model's actual max length
309
  return_attention_mask=True
310
- ).to(llm_model.device) # Ensure inputs are on the same device as model parts
311
 
312
  with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(llm_model.dtype == torch.float16)):
313
- # Run blocking model.generate in the executor thread pool
314
  outputs = await asyncio.get_event_loop().run_in_executor(
315
  executor,
316
  llm_model.generate,
@@ -325,20 +363,12 @@ async def generate_llm_answer(prompt: str) -> str:
325
  num_return_sequences=1
326
  )
327
 
328
- # Decode only newly generated tokens relative to input
329
  output_ids = outputs[0][inputs.input_ids.shape[1]:]
330
  answer_part = llm_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
331
 
332
- # Handle potential empty generation
333
  if not answer_part:
334
- # Sometimes the split method above is needed if the model includes the prompt
335
- full_output = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
336
- answer_marker = "Answer:"
337
- marker_index = full_output.rfind(answer_marker)
338
- if marker_index != -1:
339
- answer_part = full_output[marker_index + len(answer_marker):].strip()
340
- else:
341
- answer_part = "*Model generated an empty response.*" # Fallback message
342
 
343
  end_time = time.time()
344
  print(f"[LLM Generate] Generation complete in {end_time - start_time:.2f}s. Length: {len(answer_part)}")
@@ -347,22 +377,21 @@ async def generate_llm_answer(prompt: str) -> str:
347
  except Exception as e:
348
  print(f"[LLM Generate] Error: {e}")
349
  print(traceback.format_exc())
350
- return f"Error during answer generation: {str(e)}"
351
 
352
  async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
353
  """Generates speech using the loaded TTS model (Async Wrapper)."""
354
  if not TTS_ENABLED or not tts_model or 'generate_tts_internal' not in globals():
355
  print("[TTS Generate] Skipping: TTS not ready.")
356
  return None
357
- if not text or not text.strip():
358
- print("[TTS Generate] Skipping: Empty text.")
359
  return None
360
 
361
  print(f"[TTS Generate] Requesting speech (length {len(text)}, voice '{voice_id}')...")
362
  start_time = time.time()
363
 
364
  try:
365
- # Verify voicepack availability
366
  actual_voice_id = voice_id
367
  if voice_id not in voicepacks:
368
  print(f"[TTS Generate] Warning: Voice '{voice_id}' not loaded. Trying default 'af'.")
@@ -371,18 +400,23 @@ async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple
371
  print("[TTS Generate] Error: Default voice 'af' also not available.")
372
  return None
373
 
374
- # Clean text for TTS
375
- clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text) # Remove citations like [1], [2][3]
376
- clean_text = re.sub(r'[\*\#\`]', '', clean_text) # Remove markdown symbols
 
 
 
 
377
  clean_text = ' '.join(clean_text.split()) # Normalize whitespace
378
 
379
- if not clean_text: return None # Skip if empty after cleaning
 
 
380
 
381
- # Truncate if necessary
382
  if len(clean_text) > MAX_TTS_CHARS:
383
- print(f"[TTS Generate] Truncating text from {len(clean_text)} to {MAX_TTS_CHARS} chars.")
384
  clean_text = clean_text[:MAX_TTS_CHARS]
385
- last_punct = max(clean_text.rfind(p) for p in '.?! ')
386
  if last_punct != -1: clean_text = clean_text[:last_punct+1]
387
  clean_text += "..."
388
 
@@ -390,17 +424,13 @@ async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple
390
  gen_func = globals()['generate_tts_internal']
391
  voice_pack_data = voicepacks[actual_voice_id]
392
 
393
- # Run blocking TTS generation in the executor thread pool
394
- # Assuming 'afr' is the correct language code for Kokoro's default voices
395
  audio_data, _ = await asyncio.get_event_loop().run_in_executor(
396
- executor,
397
- gen_func,
398
- tts_model, # The loaded model object
399
- clean_text, # The cleaned text string
400
- voice_pack_data,# The loaded voice pack tensor/dict
401
- 'afr' # Language code (verify this is correct)
402
  )
403
 
 
404
  if isinstance(audio_data, torch.Tensor):
405
  audio_np = audio_data.detach().cpu().numpy()
406
  elif isinstance(audio_data, np.ndarray):
@@ -409,8 +439,7 @@ async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple
409
  print("[TTS Generate] Warning: Unexpected audio data type.")
410
  return None
411
 
412
- # Ensure audio is 1D float32
413
- audio_np = audio_np.flatten().astype(np.float32)
414
 
415
  end_time = time.time()
416
  print(f"[TTS Generate] Audio generated in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")
@@ -427,9 +456,7 @@ def get_voice_id_from_display(voice_display_name: str) -> str:
427
 
428
 
429
  # --- Gradio Interaction Logic ---
430
-
431
- # Define type for chat history using the 'messages' format
432
- ChatHistoryType = List[Dict[str, str]]
433
 
434
  async def handle_interaction(
435
  query: str,
@@ -438,94 +465,84 @@ async def handle_interaction(
438
  ):
439
  """Main async generator function to handle user queries and update Gradio UI."""
440
  print(f"\n--- Handling Query ---")
 
441
  print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")
442
 
443
- if not query or not query.strip():
444
  print("Empty query received.")
445
- # Need to yield the current state for all outputs
446
- yield history, "*Please enter a query.*", "<div class='no-sources'>Enter a query to search.</div>", None, gr.Button(value="Search", interactive=True)
447
  return
448
 
449
- # Append user message to history
450
- current_history = history + [{"role": "user", "content": query}]
451
  # Add placeholder for assistant response
452
- current_history.append({"role": "assistant", "content": "*Searching...*"})
 
 
 
 
 
 
 
453
 
454
  # 1. Initial State: Searching
455
- yield (
456
- current_history,
457
- "*Searching the web...*", # Update answer area
458
- "<div class='searching'><span>Searching the web...</span></div>", # Update sources area
459
- None, # No audio yet
460
- gr.Button(value="Searching...", interactive=False) # Update button state
461
- )
462
 
463
  # 2. Perform Web Search (in executor)
464
  web_results = await asyncio.get_event_loop().run_in_executor(
465
  executor, get_web_results_sync, query
466
  )
467
- sources_html = format_sources_html(web_results)
468
 
469
  # Update state: Generating Answer
470
- current_history[-1]["content"] = "*Generating answer...*" # Update assistant placeholder
471
- yield (
472
- current_history,
473
- "*Generating answer...*", # Update answer area
474
- sources_html, # Show sources
475
- None,
476
- gr.Button(value="Generating...", interactive=False)
477
- )
478
 
479
  # 3. Generate LLM Answer (async)
480
  llm_prompt = format_llm_prompt(query, web_results)
481
  final_answer = await generate_llm_answer(llm_prompt)
 
482
 
483
- # Update assistant message in history with the final answer
484
  current_history[-1]["content"] = final_answer
485
 
486
  # Update state: Generating Audio (if applicable)
487
- yield (
488
- current_history,
489
- final_answer, # Show final answer
490
- sources_html,
491
- None,
492
- gr.Button(value="Audio...", interactive=False) if TTS_ENABLED else gr.Button(value="Search", interactive=True) # Enable search if TTS disabled
493
- )
494
 
495
  # 4. Generate TTS Speech (async)
496
- audio_output_data = None
497
  tts_status_message = ""
498
  if not TTS_ENABLED:
499
  if tts_setup_thread.is_alive():
500
  tts_status_message = "\n\n*(TTS initializing...)*"
501
  else:
502
- tts_status_message = "\n\n*(TTS disabled or failed)*"
503
- elif final_answer and not final_answer.startswith("Error"):
 
 
504
  voice_id = get_voice_id_from_display(selected_voice_display_name)
505
- audio_output_data = await generate_tts_speech(final_answer, voice_id)
506
- if audio_output_data is None:
507
  tts_status_message = "\n\n*(Audio generation failed)*"
508
 
509
  # 5. Final State: Show all results
510
  final_answer_with_status = final_answer + tts_status_message
511
- current_history[-1]["content"] = final_answer_with_status # Update history with status msg too
 
 
 
512
 
513
  print("--- Query Handling Complete ---")
514
- yield (
515
- current_history,
516
- final_answer_with_status, # Show answer + TTS status
517
- sources_html,
518
- audio_output_data, # Output audio data (or None)
519
- gr.Button(value="Search", interactive=True) # Re-enable button
520
- )
521
 
522
 
523
  # --- Gradio UI Definition ---
524
- # (CSS remains largely the same - ensure it targets default Gradio classes if elem_classes was removed)
525
  css = """
526
- /* ... [Your existing refined CSS, but remove selectors using .gradio-examples if you were using it] ... */
527
- /* Example: Style examples container via its parent or default class if needed */
528
- /* .examples-container .gradio-examples { ... } */ /* This might still work depending on structure */
529
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
530
  #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
531
  #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
@@ -542,8 +559,8 @@ css = """
542
  .search-box button:hover { background: #1d4ed8 !important; }
543
  .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
544
  .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
545
- .answer-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; color: #1f2937; margin-bottom: 1.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.05); }
546
- .answer-box p { color: #374151; line-height: 1.7; }
547
  .answer-box code { background: #f3f4f6; border-radius: 4px; padding: 2px 4px; color: #4b5563; font-size: 0.9em; }
548
  .sources-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; }
549
  .sources-box h3 { margin-top: 0; margin-bottom: 1rem; color: #111827; font-size: 1.2rem; }
@@ -555,13 +572,12 @@ css = """
555
  .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
556
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
557
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
558
- .chat-history { /* Style the chatbot container */ max-height: 400px; overflow-y: auto; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; margin-top: 1rem; scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
559
  .chat-history > div { padding: 1rem; } /* Add padding inside the chatbot display area */
560
  .chat-history::-webkit-scrollbar { width: 6px; }
561
  .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
562
  .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
563
  .examples-container { background: #f9fafb; border-radius: 8px; padding: 1rem; margin-top: 1rem; border: 1px solid #e5e7eb; }
564
- /* Default styling for example buttons (since elem_classes might not work) */
565
  .examples-container button { background: white !important; border: 1px solid #d1d5db !important; color: #374151 !important; transition: all 0.2s; margin: 4px !important; font-size: 0.9em !important; padding: 6px 12px !important; border-radius: 4px !important; }
566
  .examples-container button:hover { background: #f3f4f6 !important; border-color: #adb5bd !important; }
567
  .markdown-content { color: #374151 !important; font-size: 1rem; line-height: 1.7; }
@@ -578,8 +594,8 @@ css = """
578
  .markdown-content table { border-collapse: collapse !important; width: 100% !important; margin: 1em 0; }
579
  .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
580
  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
581
- .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
582
- .accordion > .label-wrap { padding: 10px 15px !important; }
583
  .voice-selector { margin: 0; padding: 0; height: 100%; }
584
  .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
585
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
@@ -592,7 +608,7 @@ css = """
592
  .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
593
  @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
594
  .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; }
595
- /* Dark Mode Styles (Optional - keep if needed) */
596
  .dark .gradio-container { background-color: #111827 !important; }
597
  .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
598
  .dark #header h3 { color: #9ca3af; }
@@ -630,8 +646,8 @@ css = """
630
  .dark .markdown-content blockquote { border-left-color: #4b5563 !important; color: #9ca3af !important; }
631
  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
632
  .dark .markdown-content th { background: #374151 !important; }
633
- .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
634
- .dark .accordion > .label-wrap { color: #d1d5db !important; }
635
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
636
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
637
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
@@ -644,13 +660,11 @@ css = """
644
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
645
  """
646
 
647
- import sys # Needed for sys.path manipulation in TTS setup
648
-
649
  with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
650
- # Use gr.State to store the chat history in the 'messages' format
651
  chat_history_state = gr.State([])
652
 
653
- with gr.Column(): # Main container
654
  # Header
655
  with gr.Column(elem_id="header"):
656
  gr.Markdown("# 🔍 AI Search Assistant")
@@ -658,27 +672,25 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(pri
658
 
659
  # Search Area
660
  with gr.Column(elem_classes="search-container"):
661
- with gr.Row(elem_classes="search-box", equal_height=False):
662
  search_input = gr.Textbox(label="", placeholder="Ask anything...", scale=5, container=False)
663
  voice_select = gr.Dropdown(choices=list(VOICE_CHOICES.keys()), value=list(VOICE_CHOICES.keys())[0], label="", scale=1, min_width=180, container=False, elem_classes="voice-selector")
664
  search_btn = gr.Button("Search", variant="primary", scale=0, min_width=100)
665
 
666
  # Results Area
667
- with gr.Row(elem_classes="results-container", equal_height=False):
668
- # Left Column: Answer & History
669
  with gr.Column(scale=3):
670
- # Chatbot display (uses 'messages' format now)
671
  chatbot_display = gr.Chatbot(
672
  label="Conversation",
673
  bubble_full_width=True,
674
- height=500,
675
  elem_classes="chat-history",
676
- type="messages", # Use the recommended type
677
- avatar_images=(None, os.path.join(KOKORO_PATH, "icon.png") if os.path.exists(os.path.join(KOKORO_PATH, "icon.png")) else None) # Optional: Add avatar for assistant
 
678
  )
679
- # Separate Markdown for status/intermediate answer
680
  answer_status_output = gr.Markdown(value="*Enter a query to start.*", elem_classes="answer-box markdown-content")
681
- # Audio Output
682
  audio_player = gr.Audio(label="Voice Response", type="numpy", autoplay=False, show_label=False, elem_classes="audio-player")
683
 
684
  # Right Column: Sources
@@ -689,7 +701,6 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(pri
689
 
690
  # Examples Area
691
  with gr.Row(elem_classes="examples-container"):
692
- # REMOVED elem_classes from gr.Examples
693
  gr.Examples(
694
  examples=[
695
  "Latest news about renewable energy",
@@ -700,47 +711,54 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(pri
700
  ],
701
  inputs=search_input,
702
  label="Try these examples:",
 
703
  )
704
 
705
  # --- Event Handling Setup ---
706
- # Define the inputs and outputs for the Gradio event triggers
707
  event_inputs = [search_input, chat_history_state, voice_select]
708
  event_outputs = [
709
- chatbot_display, # Updated chat history
710
- answer_status_output, # Status or final answer text
711
- sources_output_html, # Formatted sources
712
- audio_player, # Audio data
713
- search_btn # Button state (enabled/disabled)
714
  ]
715
 
716
- # Create a wrapper to adapt the async generator for Gradio's streaming updates
717
  async def stream_interaction_updates(query, history, voice_display_name):
 
 
 
718
  try:
719
- # Iterate through the states yielded by the handler
720
- async for state_update in handle_interaction(query, history, voice_display_name):
721
- yield state_update # Yield the tuple of output values
 
 
722
  except Exception as e:
723
  print(f"[Gradio Stream] Error during interaction: {e}")
724
  print(traceback.format_exc())
725
- # Yield a final error state to the UI
726
- error_history = history + [{"role":"user", "content":query}, {"role":"assistant", "content":f"*Error: {e}*"}]
727
- yield (
728
  error_history,
729
  f"An error occurred: {e}",
730
  "<div class='error'>Request failed.</div>",
731
  None,
732
- gr.Button(value="Search", interactive=True)
733
  )
734
- finally:
735
- # Clear the text input after processing is complete (or errored out)
736
- # We need to yield the final state *plus* the cleared input
737
- # This requires adding search_input to the outputs list for the event triggers
738
- # For now, let's not clear it automatically to avoid complexity.
739
- # yield (*final_state_tuple, gr.Textbox(value="")) # Example if clearing input
740
- print("[Gradio Stream] Interaction stream finished.")
741
 
 
 
 
 
 
 
 
742
 
743
- # Connect the streaming function to the button click and input submit events
 
744
  search_btn.click(
745
  fn=stream_interaction_updates,
746
  inputs=event_inputs,
@@ -752,10 +770,15 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(pri
752
  outputs=event_outputs
753
  )
754
 
 
755
  if __name__ == "__main__":
756
  print("Starting Gradio application...")
 
 
757
  demo.queue(max_size=20).launch(
758
  debug=True,
759
- share=True,
760
- # server_name="0.0.0.0" # Optional: Bind to all interfaces
761
- )
 
 
 
1
+ # --- Imports ---
2
  import gradio as gr
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  from duckduckgo_search import DDGS
5
  import time
6
  import torch
 
16
  import warnings
17
  import traceback # For detailed error logging
18
  import re # For text cleaning
19
+ import shutil # For checking sudo/file operations
20
  import html # For escaping HTML
21
+ import sys # For sys.path manipulation
22
 
23
  # --- Configuration ---
24
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 
28
  MAX_NEW_TOKENS = 300
29
  TEMPERATURE = 0.7
30
  TOP_P = 0.95
31
+ KOKORO_PATH = 'Kokoro-82M' # Relative path to TTS model directory
32
 
33
  # --- Initialization ---
34
+ # Thread Pool Executor for blocking tasks
35
+ executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
36
 
37
  # Suppress specific warnings
38
  warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
 
41
  # --- LLM Initialization ---
42
  llm_model: Optional[AutoModelForCausalLM] = None
43
  llm_tokenizer: Optional[AutoTokenizer] = None
44
+ llm_device = "cpu"
45
 
46
  try:
47
+ print("[LLM Init] Initializing Language Model...")
48
  llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
49
  llm_tokenizer.pad_token = llm_tokenizer.eos_token
50
 
51
  if torch.cuda.is_available():
52
  llm_device = "cuda"
53
  torch_dtype = torch.float16
54
+ device_map = "auto"
55
+ print(f"[LLM Init] CUDA detected. Loading model with device_map='{device_map}', dtype={torch_dtype}")
56
  else:
57
  llm_device = "cpu"
58
+ torch_dtype = torch.float32
59
  device_map = {"": "cpu"}
60
+ print(f"[LLM Init] CUDA not found. Loading model on CPU with dtype={torch_dtype}")
61
 
62
  llm_model = AutoModelForCausalLM.from_pretrained(
63
  MODEL_NAME,
64
  device_map=device_map,
65
  low_cpu_mem_usage=True,
66
  torch_dtype=torch_dtype,
67
+ # attn_implementation="flash_attention_2" # Optional
68
  )
69
+ # Get the actual device map if using 'auto'
70
+ effective_device_map = llm_model.hf_device_map if hasattr(llm_model, 'hf_device_map') else device_map
71
+ print(f"[LLM Init] LLM loaded successfully. Device map: {effective_device_map}")
72
+ llm_model.eval()
73
 
74
  except Exception as e:
75
+ print(f"[LLM Init] FATAL: Error initializing LLM model: {str(e)}")
76
  print(traceback.format_exc())
 
77
  llm_model = None
78
  llm_tokenizer = None
79
+ print("[LLM Init] LLM features will be unavailable.")
80
 
81
 
82
  # --- TTS Initialization ---
 
87
  '🇺🇸 Nicole': 'af_nicole'
88
  }
89
  TTS_ENABLED = False
90
+ tts_model: Optional[Any] = None
91
+ voicepacks: Dict[str, Any] = {}
92
+ tts_device = "cpu"
93
 
94
+ # Helper for running subprocesses
95
+ def _run_subprocess(cmd: List[str], check: bool = True, cwd: Optional[str] = None, timeout: int = 300) -> subprocess.CompletedProcess:
96
+ """Runs a subprocess command, captures output, and handles errors."""
 
 
97
  print(f"Running command: {' '.join(cmd)}")
98
  try:
99
+ result = subprocess.run(cmd, check=check, capture_output=True, text=True, cwd=cwd, timeout=timeout)
100
+ # Only print output details if check failed or for specific successful commands
101
+ if not check or result.returncode != 0:
102
+ if result.stdout: print(f" Stdout: {result.stdout.strip()}")
103
+ if result.stderr: print(f" Stderr: {result.stderr.strip()}")
104
+ elif result.returncode == 0 and ('clone' in cmd or 'pull' in cmd or 'install' in cmd):
105
+ print(f" Command successful.") # Concise success message
106
  return result
107
  except FileNotFoundError:
108
+ print(f" Error: Command not found - {cmd[0]}")
109
+ raise
110
+ except subprocess.TimeoutExpired:
111
+ print(f" Error: Command timed out - {' '.join(cmd)}")
112
  raise
113
  except subprocess.CalledProcessError as e:
114
+ print(f" Error running command: {' '.join(e.cmd)} (Code: {e.returncode})")
115
+ if e.stdout: print(f" Stdout: {e.stdout.strip()}")
116
+ if e.stderr: print(f" Stderr: {e.stderr.strip()}")
117
  raise
118
 
119
+ # TTS Setup Task (runs in background thread)
def _ensure_kokoro_repo(repo_path: str) -> None:
    """Clone the Kokoro-82M repository into ``repo_path`` unless it already exists.

    Failures of ``git lfs install`` / ``git lfs pull`` are only logged; a
    failing ``git clone`` propagates to the caller.
    """
    if os.path.exists(repo_path):
        print(f"[TTS Setup] Directory {repo_path} already exists.")
        return

    print(f"[TTS Setup] Cloning repository to {repo_path}...")
    try:
        _run_subprocess(['git', 'lfs', 'install', '--system', '--skip-repo'])
    except Exception as lfs_err:
        print(f"[TTS Setup] Warning: git lfs install failed: {lfs_err}. Continuing...")

    _run_subprocess(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', repo_path])

    try:
        print("[TTS Setup] Running git lfs pull...")
        _run_subprocess(['git', 'lfs', 'pull'], cwd=repo_path)
    except Exception as lfs_pull_err:
        print(f"[TTS Setup] Warning: git lfs pull failed: {lfs_pull_err}")


def _ensure_espeak(apt_cmd_prefix: List[str]) -> bool:
    """Make sure an espeak binary is available, installing one via apt if needed.

    Args:
        apt_cmd_prefix: Command prefix for apt-get (``['sudo']`` or ``[]``).

    Returns:
        True when espeak-ng/espeak is already on PATH or was installed,
        False when both installation attempts failed.
    """
    print("[TTS Setup] Checking/Installing espeak...")

    # Skip the privileged apt-get calls when a binary is already installed;
    # otherwise a non-root or offline host would disable TTS for no reason.
    for binary in ('espeak-ng', 'espeak'):
        if shutil.which(binary):
            print(f"[TTS Setup] {binary} installed or already present.")
            return True

    try:
        _run_subprocess(apt_cmd_prefix + ['apt-get', 'update', '-qq'])
        _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak-ng'])
        print("[TTS Setup] espeak-ng installed or already present.")
        return True
    except Exception:
        print("[TTS Setup] espeak-ng installation failed, trying espeak...")

    try:
        # Fallback to legacy espeak
        _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak'])
        print("[TTS Setup] espeak installed or already present.")
        return True
    except Exception as espeak_err:
        print(f"[TTS Setup] ERROR: Failed to install both espeak-ng and espeak: {espeak_err}. TTS disabled.")
        return False


def _load_kokoro_assets(repo_path: str, device: str) -> Optional[Tuple[Any, Dict[str, Any]]]:
    """Import the Kokoro modules from ``repo_path`` and load model and voicepacks.

    ``repo_path`` is put on ``sys.path`` only for the duration of the import.
    On success the imported ``build_model`` and ``generate_tts_internal``
    callables are published in this module's globals.

    Returns:
        ``(model, voices)`` where ``voices`` maps voice id to loaded voicepack,
        or None when the import, the model file, or every voicepack is missing
        or fails to load (the reason is logged).
    """
    print(f"[TTS Setup] Checking contents of: {repo_path}")
    try:
        dir_contents = os.listdir(repo_path)
        print(f"[TTS Setup] Contents: {dir_contents}")
        if 'models.py' not in dir_contents or 'kokoro.py' not in dir_contents:
            print("[TTS Setup] Warning: Core Kokoro python files ('models.py', 'kokoro.py') might be missing!")
    except OSError as list_err:
        print(f"[TTS Setup] Warning: Could not list directory contents: {list_err}")

    sys_path_updated = False
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)  # Front of the list so the repo's modules win
        sys_path_updated = True
        print(f"[TTS Setup] Temporarily added {repo_path} to sys.path.")

    try:
        print("[TTS Setup] Attempting to import Kokoro modules...")
        from models import build_model
        from kokoro import generate as generate_tts_internal
        print("[TTS Setup] Kokoro modules imported successfully.")

        # generate_tts_speech looks these up through globals()
        globals()['build_model'] = build_model
        globals()['generate_tts_internal'] = generate_tts_internal

        model_file = os.path.join(repo_path, 'kokoro-v0_19.pth')
        if not os.path.exists(model_file):
            print(f"[TTS Setup] ERROR: Model file {model_file} not found. TTS disabled.")
            return None

        print(f"[TTS Setup] Loading TTS model from {model_file} onto {device}...")
        model = build_model(model_file, device)
        # NOTE(review): build_model's return type is not visible in this file;
        # only switch to eval mode when the returned object supports it.
        eval_fn = getattr(model, 'eval', None)
        if callable(eval_fn):
            eval_fn()
        print("[TTS Setup] TTS model loaded.")

        voices: Dict[str, Any] = {}
        for voice_name, voice_id in VOICE_CHOICES.items():
            voice_file_path = os.path.join(repo_path, 'voices', f'{voice_id}.pt')
            if not os.path.exists(voice_file_path):
                print(f"[TTS Setup] Info: Voice file {voice_file_path} not found.")
                continue
            try:
                print(f"[TTS Setup] Loading voice: {voice_id} ({voice_name})")
                # weights_only=True: the file was downloaded, so refuse to
                # unpickle arbitrary objects from it.
                voices[voice_id] = torch.load(voice_file_path, map_location=device, weights_only=True)
            except Exception as e:
                print(f"[TTS Setup] Warning: Failed to load voice {voice_id}: {str(e)}")

        if not voices:
            print("[TTS Setup] ERROR: No voicepacks could be loaded. TTS disabled.")
            return None  # The local model reference is dropped with this frame

        return model, voices

    except ImportError as ie:
        print(f"[TTS Setup] ERROR: Failed to import Kokoro modules: {ie}.")
        print(f" Please ensure '{repo_path}' contains 'models.py' and 'kokoro.py'.")
        print(traceback.format_exc())
        return None
    except Exception as load_err:
        print(f"[TTS Setup] ERROR: Exception during TTS model/voice loading: {load_err}. TTS disabled.")
        print(traceback.format_exc())
        return None
    finally:
        # Always undo the temporary sys.path change, wherever the entry ended up.
        if sys_path_updated and repo_path in sys.path:
            sys.path.remove(repo_path)
            print(f"[TTS Setup] Removed {repo_path} from sys.path.")


def setup_tts_task():
    """Initialize the Kokoro TTS model and its dependencies.

    Intended to run in a background thread. Steps: clone the Kokoro repository
    if missing, make sure espeak is available, then import Kokoro and load the
    model and voicepacks. The module globals ``tts_model`` and ``voicepacks``
    are populated and ``TTS_ENABLED`` is set to True only when every step
    succeeded; on any failure TTS stays disabled and the reason is logged.
    """
    global TTS_ENABLED, tts_model, voicepacks, tts_device
    print("[TTS Setup] Starting background initialization...")

    tts_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[TTS Setup] Target device: {tts_device}")

    can_sudo = shutil.which('sudo') is not None
    apt_cmd_prefix = ['sudo'] if can_sudo else []
    absolute_kokoro_path = os.path.abspath(KOKORO_PATH)  # Independent of later cwd changes

    try:
        # 1. Clone Kokoro repo if needed
        _ensure_kokoro_repo(absolute_kokoro_path)

        # 2. espeak dependency
        if not _ensure_espeak(apt_cmd_prefix):
            return  # Cannot proceed

        # 3. Load Kokoro model and voices
        if not os.path.exists(absolute_kokoro_path):
            print(f"[TTS Setup] ERROR: Directory {absolute_kokoro_path} not found. TTS disabled.")
            return

        loaded = _load_kokoro_assets(absolute_kokoro_path, tts_device)
        if loaded is None:
            return

        # Publish results only after everything loaded, and flip the flag last
        # so readers never see TTS_ENABLED without a model and voices.
        model, voices = loaded
        tts_model = model
        voicepacks.update(voices)
        TTS_ENABLED = True
        print(f"[TTS Setup] Initialization successful. {len(voices)} voices loaded. TTS Enabled: {TTS_ENABLED}")

    except Exception as e:
        print(f"[TTS Setup] ERROR: Unexpected error during setup: {str(e)}")
        print(traceback.format_exc())
        TTS_ENABLED = False  # Ensure disabled on any top-level error
        tts_model = None
        voicepacks.clear()
265
 
266
+ # Start TTS setup in background
267
  print("Starting TTS setup thread...")
268
  tts_setup_thread = threading.Thread(target=setup_tts_task, daemon=True)
269
  tts_setup_thread.start()
270
 
271
 
272
+ # --- Core Logic Functions ---
273
 
274
  @lru_cache(maxsize=128)
275
  def get_web_results_sync(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, Any]]:
 
288
  return formatted
289
  except Exception as e:
290
  print(f"[Web Search] Error: {e}")
291
+ # Avoid printing full traceback repeatedly for common network errors maybe
292
  return []
293
 
def format_llm_prompt(query: str, context: List[Dict[str, Any]]) -> str:
    """Build the LLM prompt from the user query and the web search results.

    Args:
        query: The user's question; HTML-escaped before insertion.
        context: Search results as dicts with ``id``, ``title`` and ``snippet``
            keys. Missing or empty fields fall back to the entry's 1-based
            position, ``"Source"`` and ``""`` respectively instead of raising.

    Returns:
        The prompt text: a SYSTEM instruction line with the current local
        time, the CONTEXT section, the USER query and a trailing
        ``ASSISTANT:`` marker.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    entries = []
    for position, res in enumerate(context or [], start=1):
        source_id = res.get("id", position)
        # `or` also covers explicit None values coming back from the search API
        title = html.escape(str(res.get("title") or "Source"))
        snippet = html.escape(str(res.get("snippet") or ""))
        entries.append(f"[{source_id}] {title}\n{snippet}")

    context_str = "\n\n".join(entries) if entries else "No relevant web context found."

    # Using a clear, structured prompt; the ASSISTANT: marker might help some models
    return f"""SYSTEM: You are a helpful AI assistant. Answer the user's query based *only* on the provided web search context. Cite sources using bracket notation like [1], [2]. If the context is insufficient, state that clearly. Use markdown for formatting. Do not add external information. Current Time: {current_time}

CONTEXT:
---
{context_str}
---

USER: {html.escape(query)}

ASSISTANT:"""
312
 
313
  def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
314
  """Formats search results into HTML for display."""
 
318
  for res in web_results:
319
  title_safe = html.escape(res.get("title", "Source"))
320
  snippet_safe = html.escape(res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else ""))
321
+ url = html.escape(res.get("url", "#")) # Escape URL too
322
  items_html += f"""
323
  <div class='source-item'>
324
  <div class='source-number'>[{res['id']}]</div>
 
333
  async def generate_llm_answer(prompt: str) -> str:
334
  """Generates answer using the loaded LLM (Async Wrapper)."""
335
  if not llm_model or not llm_tokenizer:
336
+ print("[LLM Generate] LLM model or tokenizer not available.")
337
+ return "Error: Language Model is not available."
338
 
339
  print(f"[LLM Generate] Requesting generation (prompt length {len(prompt)})...")
340
  start_time = time.time()
 
344
  return_tensors="pt",
345
  padding=True,
346
  truncation=True,
347
+ max_length=1024, # Adjust based on model limits
348
  return_attention_mask=True
349
+ ).to(llm_model.device)
350
 
351
  with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(llm_model.dtype == torch.float16)):
 
352
  outputs = await asyncio.get_event_loop().run_in_executor(
353
  executor,
354
  llm_model.generate,
 
363
  num_return_sequences=1
364
  )
365
 
366
+ # Decode only newly generated tokens
367
  output_ids = outputs[0][inputs.input_ids.shape[1]:]
368
  answer_part = llm_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
369
 
 
370
  if not answer_part:
371
+ answer_part = "*Model generated an empty response.*"
 
 
 
 
 
 
 
372
 
373
  end_time = time.time()
374
  print(f"[LLM Generate] Generation complete in {end_time - start_time:.2f}s. Length: {len(answer_part)}")
 
377
  except Exception as e:
378
  print(f"[LLM Generate] Error: {e}")
379
  print(traceback.format_exc())
380
+ return f"Error during answer generation: Check logs for details." # User-friendly error
381
 
382
  async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
383
  """Generates speech using the loaded TTS model (Async Wrapper)."""
384
  if not TTS_ENABLED or not tts_model or 'generate_tts_internal' not in globals():
385
  print("[TTS Generate] Skipping: TTS not ready.")
386
  return None
387
+ if not text or not text.strip() or text.startswith("Error:") or text.startswith("*Model generated"):
388
+ print("[TTS Generate] Skipping: Invalid or empty text.")
389
  return None
390
 
391
  print(f"[TTS Generate] Requesting speech (length {len(text)}, voice '{voice_id}')...")
392
  start_time = time.time()
393
 
394
  try:
 
395
  actual_voice_id = voice_id
396
  if voice_id not in voicepacks:
397
  print(f"[TTS Generate] Warning: Voice '{voice_id}' not loaded. Trying default 'af'.")
 
400
  print("[TTS Generate] Error: Default voice 'af' also not available.")
401
  return None
402
 
403
+ # Clean text more thoroughly for TTS
404
+ clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text) # Remove citations [1], [2][3]
405
+ clean_text = re.sub(r'```.*?```', '', clean_text, flags=re.DOTALL) # Remove code blocks
406
+ clean_text = re.sub(r'`[^`]*`', '', clean_text) # Remove inline code
407
+ clean_text = re.sub(r'^\s*[\*->]\s*', '', clean_text, flags=re.MULTILINE) # Remove list markers/blockquotes at line start
408
+ clean_text = re.sub(r'[\*#_]', '', clean_text) # Remove remaining markdown emphasis/headers
409
+ clean_text = html.unescape(clean_text) # Decode HTML entities
410
  clean_text = ' '.join(clean_text.split()) # Normalize whitespace
411
 
412
+ if not clean_text:
413
+ print("[TTS Generate] Skipping: Text empty after cleaning.")
414
+ return None
415
 
 
416
  if len(clean_text) > MAX_TTS_CHARS:
417
+ print(f"[TTS Generate] Truncating cleaned text from {len(clean_text)} to {MAX_TTS_CHARS} chars.")
418
  clean_text = clean_text[:MAX_TTS_CHARS]
419
+ last_punct = max(clean_text.rfind(p) for p in '.?!; ') # Find reasonable cut-off
420
  if last_punct != -1: clean_text = clean_text[:last_punct+1]
421
  clean_text += "..."
422
 
 
424
  gen_func = globals()['generate_tts_internal']
425
  voice_pack_data = voicepacks[actual_voice_id]
426
 
427
+ # Execute in thread pool
428
+ # Verify the expected language code ('afr', 'eng', etc.) for Kokoro
429
  audio_data, _ = await asyncio.get_event_loop().run_in_executor(
430
+ executor, gen_func, tts_model, clean_text, voice_pack_data, 'afr'
 
 
 
 
 
431
  )
432
 
433
+ # Process output
434
  if isinstance(audio_data, torch.Tensor):
435
  audio_np = audio_data.detach().cpu().numpy()
436
  elif isinstance(audio_data, np.ndarray):
 
439
  print("[TTS Generate] Warning: Unexpected audio data type.")
440
  return None
441
 
442
+ audio_np = audio_np.flatten().astype(np.float32) # Ensure 1D float32
 
443
 
444
  end_time = time.time()
445
  print(f"[TTS Generate] Audio generated in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")
 
456
 
457
 
458
  # --- Gradio Interaction Logic ---
459
+ ChatHistoryType = List[Dict[str, Optional[str]]] # Allow None for content during streaming
 
 
460
 
461
  async def handle_interaction(
462
  query: str,
 
465
  ):
466
  """Main async generator function to handle user queries and update Gradio UI."""
467
  print(f"\n--- Handling Query ---")
468
+ query = query.strip() # Clean input query
469
  print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")
470
 
471
+ if not query:
472
  print("Empty query received.")
473
+ yield history, "*Please enter a non-empty query.*", "<div class='no-sources'>Enter a query to search.</div>", None, gr.Button(value="Search", interactive=True)
 
474
  return
475
 
476
+ # Use 'messages' format: List of {'role': 'user'/'assistant', 'content': '...'}
477
+ current_history: ChatHistoryType = history + [{"role": "user", "content": query}]
478
  # Add placeholder for assistant response
479
+ current_history.append({"role": "assistant", "content": None}) # Content starts as None
480
+
481
+ # Define states to yield
482
+ chatbot_state = current_history
483
+ status_state = "*Searching...*"
484
+ sources_state = "<div class='searching'><span>Searching the web...</span></div>"
485
+ audio_state = None
486
+ button_state = gr.Button(value="Searching...", interactive=False)
487
 
488
  # 1. Initial State: Searching
489
+ current_history[-1]["content"] = status_state # Update placeholder
490
+ yield chatbot_state, status_state, sources_state, audio_state, button_state
 
 
 
 
 
491
 
492
  # 2. Perform Web Search (in executor)
493
  web_results = await asyncio.get_event_loop().run_in_executor(
494
  executor, get_web_results_sync, query
495
  )
496
+ sources_state = format_sources_html(web_results)
497
 
498
  # Update state: Generating Answer
499
+ status_state = "*Generating answer...*"
500
+ button_state = gr.Button(value="Generating...", interactive=False)
501
+ current_history[-1]["content"] = status_state # Update placeholder
502
+ yield chatbot_state, status_state, sources_state, audio_state, button_state
 
 
 
 
503
 
504
  # 3. Generate LLM Answer (async)
505
  llm_prompt = format_llm_prompt(query, web_results)
506
  final_answer = await generate_llm_answer(llm_prompt)
507
+ status_state = final_answer # Now status holds the actual answer
508
 
509
+ # Update assistant message in history fully
510
  current_history[-1]["content"] = final_answer
511
 
512
  # Update state: Generating Audio (if applicable)
513
+ button_state = gr.Button(value="Audio...", interactive=False) if TTS_ENABLED else gr.Button(value="Search", interactive=True)
514
+ yield chatbot_state, status_state, sources_state, audio_state, button_state
 
 
 
 
 
515
 
516
  # 4. Generate TTS Speech (async)
 
517
  tts_status_message = ""
518
  if not TTS_ENABLED:
519
  if tts_setup_thread.is_alive():
520
  tts_status_message = "\n\n*(TTS initializing...)*"
521
  else:
522
+ # Check if setup failed vs just disabled
523
+ # This info isn't easily available here, assume failed/disabled
524
+ tts_status_message = "\n\n*(TTS unavailable)*"
525
+ else:
526
  voice_id = get_voice_id_from_display(selected_voice_display_name)
527
+ audio_state = await generate_tts_speech(final_answer, voice_id) # Returns (rate, data) or None
528
+ if audio_state is None and not final_answer.startswith("Error"): # Don't show TTS fail if LLM failed
529
  tts_status_message = "\n\n*(Audio generation failed)*"
530
 
531
  # 5. Final State: Show all results
532
  final_answer_with_status = final_answer + tts_status_message
533
+ status_state = final_answer_with_status # Update status display
534
+ current_history[-1]["content"] = final_answer_with_status # Update history *again* with status msg
535
+
536
+ button_state = gr.Button(value="Search", interactive=True) # Re-enable button
537
 
538
  print("--- Query Handling Complete ---")
539
+ yield chatbot_state, status_state, sources_state, audio_state, button_state
 
 
 
 
 
 
540
 
541
 
542
  # --- Gradio UI Definition ---
543
+ # (CSS from previous response)
544
  css = """
545
+ /* ... [Your existing refined CSS] ... */
 
 
546
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
547
  #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
548
  #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
 
559
  .search-box button:hover { background: #1d4ed8 !important; }
560
  .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
561
  .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
562
+ .answer-box { /* Now used for status/interim text */ background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1rem; color: #1f2937; margin-bottom: 0.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.05); min-height: 50px;}
563
+ .answer-box p { color: #374151; line-height: 1.7; margin:0;}
564
  .answer-box code { background: #f3f4f6; border-radius: 4px; padding: 2px 4px; color: #4b5563; font-size: 0.9em; }
565
  .sources-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; }
566
  .sources-box h3 { margin-top: 0; margin-bottom: 1rem; color: #111827; font-size: 1.2rem; }
 
572
  .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
573
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
574
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
575
+ .chat-history { /* Style the chatbot container */ max-height: 500px; overflow-y: auto; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; /* margin-top: 1rem; */ scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
576
  .chat-history > div { padding: 1rem; } /* Add padding inside the chatbot display area */
577
  .chat-history::-webkit-scrollbar { width: 6px; }
578
  .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
579
  .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
580
  .examples-container { background: #f9fafb; border-radius: 8px; padding: 1rem; margin-top: 1rem; border: 1px solid #e5e7eb; }
 
581
  .examples-container button { background: white !important; border: 1px solid #d1d5db !important; color: #374151 !important; transition: all 0.2s; margin: 4px !important; font-size: 0.9em !important; padding: 6px 12px !important; border-radius: 4px !important; }
582
  .examples-container button:hover { background: #f3f4f6 !important; border-color: #adb5bd !important; }
583
  .markdown-content { color: #374151 !important; font-size: 1rem; line-height: 1.7; }
 
594
  .markdown-content table { border-collapse: collapse !important; width: 100% !important; margin: 1em 0; }
595
  .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
596
  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
597
+ /* .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; } */
598
+ /* .accordion > .label-wrap { padding: 10px 15px !important; } */
599
  .voice-selector { margin: 0; padding: 0; height: 100%; }
600
  .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
601
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
 
608
  .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
609
  @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
610
  .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; }
611
+ /* Dark Mode Styles */
612
  .dark .gradio-container { background-color: #111827 !important; }
613
  .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
614
  .dark #header h3 { color: #9ca3af; }
 
646
  .dark .markdown-content blockquote { border-left-color: #4b5563 !important; color: #9ca3af !important; }
647
  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
648
  .dark .markdown-content th { background: #374151 !important; }
649
+ /* .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; } */
650
+ /* .dark .accordion > .label-wrap { color: #d1d5db !important; } */
651
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
652
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
653
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
 
660
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
661
  """
662
 
 
 
663
  with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
664
+ # Use gr.State for chat history in 'messages' format
665
  chat_history_state = gr.State([])
666
 
667
+ with gr.Column():
668
  # Header
669
  with gr.Column(elem_id="header"):
670
  gr.Markdown("# 🔍 AI Search Assistant")
 
672
 
673
  # Search Area
674
  with gr.Column(elem_classes="search-container"):
675
+ with gr.Row(elem_classes="search-box"):
676
  search_input = gr.Textbox(label="", placeholder="Ask anything...", scale=5, container=False)
677
  voice_select = gr.Dropdown(choices=list(VOICE_CHOICES.keys()), value=list(VOICE_CHOICES.keys())[0], label="", scale=1, min_width=180, container=False, elem_classes="voice-selector")
678
  search_btn = gr.Button("Search", variant="primary", scale=0, min_width=100)
679
 
680
  # Results Area
681
+ with gr.Row(elem_classes="results-container"):
682
+ # Left Column: Chatbot, Status, Audio
683
  with gr.Column(scale=3):
 
684
  chatbot_display = gr.Chatbot(
685
  label="Conversation",
686
  bubble_full_width=True,
687
+ height=500, # Adjusted height
688
  elem_classes="chat-history",
689
+ type="messages", # IMPORTANT: Use 'messages' format
690
+ show_label=False,
691
+ avatar_images=(None, os.path.join(KOKORO_PATH, "icon.png") if os.path.exists(os.path.join(KOKORO_PATH, "icon.png")) else "https://huggingface.co/spaces/gradio/chatbot-streaming/resolve/main/avatar.png") # User/Assistant avatars
692
  )
 
693
  answer_status_output = gr.Markdown(value="*Enter a query to start.*", elem_classes="answer-box markdown-content")
 
694
  audio_player = gr.Audio(label="Voice Response", type="numpy", autoplay=False, show_label=False, elem_classes="audio-player")
695
 
696
  # Right Column: Sources
 
701
 
702
  # Examples Area
703
  with gr.Row(elem_classes="examples-container"):
 
704
  gr.Examples(
705
  examples=[
706
  "Latest news about renewable energy",
 
711
  ],
712
  inputs=search_input,
713
  label="Try these examples:",
714
+ # elem_classes removed
715
  )
716
 
717
  # --- Event Handling Setup ---
 
718
  event_inputs = [search_input, chat_history_state, voice_select]
719
  event_outputs = [
720
+ chatbot_display, # Output 1: Updated chat history
721
+ answer_status_output, # Output 2: Status/final text
722
+ sources_output_html, # Output 3: Sources HTML
723
+ audio_player, # Output 4: Audio data
724
+ search_btn # Output 5: Button state
725
  ]
726
 
 
async def stream_interaction_updates(query, history, voice_display_name):
    """Relay UI state updates from ``handle_interaction`` and report failures.

    Args:
        query: Text from the search box.
        history: Current chat history in 'messages' format.
        voice_display_name: Display name of the selected TTS voice.

    Yields:
        5-tuples matching ``event_outputs``: chat history, status/answer text,
        sources HTML, audio data and the search button state. If the handler
        raises, one final error tuple is yielded with the button re-enabled.
    """
    print("[Gradio Stream] Starting interaction...")
    try:
        async for state_update_tuple in handle_interaction(query, history, voice_display_name):
            yield state_update_tuple  # Forward each state to Gradio as it arrives
        print("[Gradio Stream] Interaction completed successfully.")

    except Exception as e:
        print(f"[Gradio Stream] Error during interaction: {e}")
        print(traceback.format_exc())
        # Build the error state from the history as it was before this query.
        error_history = (history or []) + [
            {"role": "user", "content": query},
            {"role": "assistant", "content": "*An error occurred. Please check logs.*"},
        ]
        yield (
            error_history,
            f"An error occurred: {e}",
            "<div class='error'>Request failed.</div>",
            None,
            gr.Button(value="Search", interactive=True),  # Ensure button is re-enabled
        )

    # NOTE: the search box is intentionally not cleared here; doing so would
    # require adding search_input to event_outputs and yielding a sixth value.
759
 
760
+
761
+ # Connect the streaming function
762
  search_btn.click(
763
  fn=stream_interaction_updates,
764
  inputs=event_inputs,
 
770
  outputs=event_outputs
771
  )
772
 
773
+ # --- Main Execution ---
774
  if __name__ == "__main__":
775
  print("Starting Gradio application...")
776
+ # Optional: Wait a moment for TTS setup thread to start and potentially print messages
777
+ # time.sleep(1)
778
  demo.queue(max_size=20).launch(
779
  debug=True,
780
+ share=True, # Set to False if not running on Spaces or don't need public link
781
+ # server_name="0.0.0.0", # Uncomment to bind to all network interfaces
782
+ # server_port=7860 # Optional: Specify port
783
+ )
784
+ print("Gradio application stopped.")