sagar007 committed
Commit ffc273f · verified · 1 Parent(s): 8652f53

Update app.py

Files changed (1)
  1. app.py +404 -472
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import spaces # Keep for potential future use or other decorators
4
  from duckduckgo_search import DDGS
5
  import time
6
  import torch
@@ -8,64 +8,76 @@ from datetime import datetime
8
  import os
9
  import subprocess
10
  import numpy as np
11
- from typing import List, Dict, Tuple, Any
12
  from functools import lru_cache
13
  import asyncio
14
  import threading
15
  from concurrent.futures import ThreadPoolExecutor
16
  import warnings
17
  import traceback # For detailed error logging
18
-
19
- # Suppress specific warnings if needed (optional)
20
- warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
21
- # Suppress another common warning with torch.compile backend
22
- # warnings.filterwarnings("ignore", message="Backend 'inductor' is not available.")
23
 
24
  # --- Configuration ---
25
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
26
  MAX_SEARCH_RESULTS = 5
27
  TTS_SAMPLE_RATE = 24000
28
  MAX_TTS_CHARS = 1000 # Max characters for a single TTS chunk
29
- # GPU_DURATION = 60 # Informational only now, decorator is removed
30
- MAX_NEW_TOKENS = 300 # Increased slightly
31
  TEMPERATURE = 0.7
32
  TOP_P = 0.95
33
  KOKORO_PATH = 'Kokoro-82M' # Path to TTS model directory
34
 
35
  # --- Initialization ---
36
- # Use a ThreadPoolExecutor for potentially blocking I/O or CPU-bound tasks
37
- executor = ThreadPoolExecutor(max_workers=4)
38
 
39
- # Initialize model and tokenizer with better error handling
40
- try:
41
- print("Loading tokenizer...")
42
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
43
- tokenizer.pad_token = tokenizer.eos_token
44
 
45
- print("Loading model...")
46
- # Determine device map based on CUDA availability
47
- device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
48
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Use float32 on CPU
49
- print(f"Attempting to load model with device_map='{device_map}' and dtype={torch_dtype}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- model = AutoModelForCausalLM.from_pretrained(
52
  MODEL_NAME,
53
  device_map=device_map,
54
- # offload_folder="offload", # Enable if needed for large models and disk space is available
55
- low_cpu_mem_usage=True, # Important for faster loading
56
  torch_dtype=torch_dtype,
57
- # attn_implementation="flash_attention_2" # Optional: requires flash-attn installed, use if available for speedup on compatible GPUs
58
  )
59
- print(f"Model loaded successfully. Device map: {model.hf_device_map}")
60
- # Ensure model is in evaluation mode
61
- model.eval()
62
 
63
  except Exception as e:
64
  print(f"FATAL: Error initializing LLM model: {str(e)}")
65
  print(traceback.format_exc())
66
- raise # Stop execution if model loading fails
 
 
 
67
 
68
- # --- TTS Setup ---
 
69
  VOICE_CHOICES = {
70
  '🇺🇸 Female (Default)': 'af',
71
  '🇺🇸 Bella': 'af_bella',
@@ -73,204 +85,181 @@ VOICE_CHOICES = {
73
  '🇺🇸 Nicole': 'af_nicole'
74
  }
75
  TTS_ENABLED = False
76
- TTS_MODEL = None
77
- VOICEPACKS = {} # Cache voice packs
 
 
 
 
78
 
79
- # Initialize Kokoro TTS in a separate thread to avoid blocking startup
80
- def setup_tts():
81
- global TTS_ENABLED, TTS_MODEL, VOICEPACKS
82
 
83
- # Check privileges for apt-get
84
  can_sudo = shutil.which('sudo') is not None
 
85
 
86
  try:
87
- # Check if Kokoro already exists
88
  if not os.path.exists(KOKORO_PATH):
89
- print("Cloning Kokoro-82M repository...")
90
- # Install git-lfs if not present (might need sudo/apt)
91
  try:
92
- lfs_install_cmd = ['git', 'lfs', 'install']
93
- subprocess.run(lfs_install_cmd, check=True, capture_output=True, text=True)
94
- except (FileNotFoundError, subprocess.CalledProcessError) as lfs_err:
95
- print(f"Warning: git-lfs command failed: {lfs_err}. Cloning might be slow or incomplete.")
96
-
97
- clone_cmd = ['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', KOKORO_PATH]
98
- result = subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
99
- print("Kokoro cloned successfully.")
100
- # print(result.stdout) # Can be verbose
101
- # Optionally pull LFS files again (sometimes clone doesn't get them all)
102
  try:
103
- print("Running git lfs pull...")
104
- lfs_pull_cmd = ['git', 'lfs', 'pull']
105
- subprocess.run(lfs_pull_cmd, cwd=KOKORO_PATH, check=True, capture_output=True, text=True)
106
- print("git lfs pull completed.")
107
- except (FileNotFoundError, subprocess.CalledProcessError) as lfs_pull_err:
108
- print(f"Warning: git lfs pull failed: {lfs_pull_err}")
109
-
110
  else:
111
- print(f"{KOKORO_PATH} directory already exists.")
112
-
113
- # Install espeak (essential for phonemization)
114
- print("Attempting to install espeak-ng or espeak...")
115
- apt_update_cmd = ['apt-get', 'update', '-qq']
116
- install_cmd_ng = ['apt-get', 'install', '-y', '-qq', 'espeak-ng']
117
- install_cmd_legacy = ['apt-get', 'install', '-y', '-qq', 'espeak']
118
-
119
- if can_sudo:
120
- apt_update_cmd.insert(0, 'sudo')
121
- install_cmd_ng.insert(0, 'sudo')
122
- install_cmd_legacy.insert(0, 'sudo')
123
 
 
 
124
  try:
125
- print(f"Running: {' '.join(apt_update_cmd)}")
126
- subprocess.run(apt_update_cmd, check=True, capture_output=True)
127
- print(f"Running: {' '.join(install_cmd_ng)}")
128
- subprocess.run(install_cmd_ng, check=True, capture_output=True)
129
- print("espeak-ng installed successfully.")
130
- except (FileNotFoundError, subprocess.CalledProcessError) as ng_err:
131
- print(f"espeak-ng installation failed ({ng_err}), trying espeak...")
132
  try:
133
- print(f"Running: {' '.join(install_cmd_legacy)}")
134
- subprocess.run(install_cmd_legacy, check=True, capture_output=True)
135
- print("espeak installed successfully.")
136
- except (FileNotFoundError, subprocess.CalledProcessError) as legacy_err:
137
- print(f"ERROR: Could not install espeak-ng or espeak: {legacy_err}. TTS functionality will be disabled.")
138
- return # Cannot proceed without espeak
139
-
140
- # Set up Kokoro TTS
141
  if os.path.exists(KOKORO_PATH):
142
- import sys
143
  if KOKORO_PATH not in sys.path:
144
  sys.path.append(KOKORO_PATH)
 
145
  try:
146
  from models import build_model
147
- from kokoro import generate as generate_tts_internal # Avoid name clash
148
 
149
- # Make these functions accessible globally if needed
150
- globals()['build_model'] = build_model
151
  globals()['generate_tts_internal'] = generate_tts_internal
152
 
153
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
154
- print(f"Loading TTS model onto device: {device}")
155
  model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
156
-
157
  if not os.path.exists(model_file):
158
- print(f"Error: TTS model file not found at {model_file}. Attempting git lfs pull again...")
159
- try:
160
- lfs_pull_cmd = ['git', 'lfs', 'pull']
161
- subprocess.run(lfs_pull_cmd, cwd=KOKORO_PATH, check=True, capture_output=True, text=True)
162
- if not os.path.exists(model_file):
163
- print(f"ERROR: TTS model file STILL not found at {model_file} after lfs pull. TTS disabled.")
164
- return
165
- except Exception as lfs_pull_err:
166
- print(f"Error during git lfs pull: {lfs_pull_err}. TTS disabled.")
167
- return
168
-
169
- TTS_MODEL = build_model(model_file, device)
170
- print("TTS model loaded.")
171
-
172
- # Preload voices
173
  for voice_name, voice_id in VOICE_CHOICES.items():
174
  voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
175
  if os.path.exists(voice_file_path):
176
  try:
177
- print(f"Loading voice: {voice_id} ({voice_name})")
178
- # Load using torch.load, map_location handles device placement
179
- VOICEPACKS[voice_id] = torch.load(voice_file_path, map_location=device)
 
180
  except Exception as e:
181
- print(f"Warning: Could not load voice {voice_id}: {str(e)}")
182
  else:
183
- print(f"Info: Voice file {voice_file_path} for '{voice_name}' not found, skipping.")
184
 
185
- if not VOICEPACKS:
186
- print("ERROR: No voicepacks could be loaded. TTS disabled.")
 
187
  return
188
 
189
- # Ensure default 'af' is loaded if possible, even if not explicitly in choices sometimes
190
- if 'af' not in VOICEPACKS:
191
- voice_file_path = os.path.join(KOKORO_PATH, 'voices', 'af.pt')
192
- if os.path.exists(voice_file_path):
193
- try:
194
- print(f"Loading fallback default voice: af")
195
- VOICEPACKS['af'] = torch.load(voice_file_path, map_location=device)
196
- except Exception as e:
197
- print(f"Warning: Could not load fallback default voice 'af': {str(e)}")
198
-
199
  TTS_ENABLED = True
200
- print("TTS setup completed successfully.")
201
 
202
  except ImportError as ie:
203
- print(f"ERROR: Importing Kokoro modules failed: {ie}. Check if {KOKORO_PATH} exists and dependencies are met.")
204
- except Exception as model_load_err:
205
- print(f"ERROR: Loading TTS model or voices failed: {model_load_err}")
206
  print(traceback.format_exc())
207
-
 
 
 
208
  else:
209
- print(f"ERROR: {KOKORO_PATH} directory not found. TTS disabled.")
210
- except subprocess.CalledProcessError as spe:
211
- print(f"ERROR: A subprocess command failed during TTS setup: {spe}")
212
- print(f"Command: {' '.join(spe.cmd)}")
213
- if spe.stderr: print(f"Stderr: {spe.stderr.strip()}")
214
- print("TTS setup failed.")
215
  except Exception as e:
216
- print(f"ERROR: An unexpected error occurred during TTS setup: {str(e)}")
217
  print(traceback.format_exc())
 
218
  TTS_ENABLED = False
 
 
219
 
220
- # Start TTS setup in a separate thread
221
- import shutil
222
- print("Starting TTS setup in background thread...")
223
- tts_thread = threading.Thread(target=setup_tts, daemon=True)
224
- tts_thread.start()
225
 
226
- # --- Search and Generation Functions ---
 
227
 
228
  @lru_cache(maxsize=128)
229
- def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
230
- """Get web search results using DuckDuckGo with caching."""
231
- print(f"[Web Search] Searching for: '{query}' (max_results={max_results})")
232
  try:
233
- # Use DDGS context manager for cleanup
234
  with DDGS() as ddgs:
235
- # Fetch results using ddgs.text()
236
- results = list(ddgs.text(query, max_results=max_results, safesearch='moderate', timelimit='y')) # Limit to past year
237
  print(f"[Web Search] Found {len(results)} results.")
238
- formatted_results = []
239
- for i, result in enumerate(results):
240
- formatted_results.append({
241
- "id": i + 1, # Add simple ID for citation
242
- "title": result.get("title", "No Title Available"),
243
- "snippet": result.get("body", "No Snippet Available"),
244
- "url": result.get("href", "#"),
245
- })
246
- return formatted_results
247
  except Exception as e:
248
  print(f"[Web Search] Error: {e}")
249
  print(traceback.format_exc())
250
  return []
251
 
252
- def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
253
- """Format the prompt with web context for the LLM."""
254
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
 
 
255
 
256
- # Format context with IDs for citation
257
- context_lines = []
258
- if context:
259
- for res in context:
260
- context_lines.append(f"[{res['id']}] {res['title']}\n{res['snippet']}")
261
- context_str = "\n\n".join(context_lines)
262
- else:
263
- context_str = "No web context available."
264
-
265
- # Clear instructions for the model
266
- prompt = f"""You are a helpful AI assistant. Your task is to answer the user's query based *only* on the provided web search context.
267
- Follow these instructions carefully:
268
- 1. Synthesize the information from the context to provide a comprehensive answer.
269
- 2. Cite the sources used in your answer using bracket notation with the source ID, like [1], [2], etc.
270
- 3. If multiple sources support a point, you can cite them together, e.g., [1][3].
271
- 4. Do *not* add information that is not present in the context.
272
- 5. If the context does not contain relevant information to answer the query, clearly state that you cannot answer based on the provided context.
273
- 6. Format the answer clearly using markdown.
274
 
275
  Current Time: {current_time}
276
 
@@ -282,24 +271,17 @@ Web Context:
282
  User Query: {query}
283
 
284
  Answer:"""
285
- # print(f"--- Formatted Prompt ---\n{prompt[:1000]}...\n--- End Prompt ---") # Debugging: Print start of prompt
286
- return prompt
287
 
288
- def format_sources(web_results: List[Dict[str, str]]) -> str:
289
- """Format sources into HTML for display."""
290
  if not web_results:
291
  return "<div class='no-sources'>No sources found for this query.</div>"
292
-
293
- sources_html = "<div class='sources-container'>"
294
  for res in web_results:
295
- title = res.get("title", "Source")
 
296
  url = res.get("url", "#")
297
- snippet = res.get("snippet", "")
298
- # Basic HTML escaping for snippet and title
299
- title_safe = gr. gradio.utils.escape_html(title)
300
- snippet_safe = gr. gradio.utils.escape_html(snippet[:150] + ("..." if len(snippet) > 150 else ""))
301
-
302
- sources_html += f"""
303
  <div class='source-item'>
304
  <div class='source-number'>[{res['id']}]</div>
305
  <div class='source-content'>
@@ -308,154 +290,130 @@ def format_sources(web_results: List[Dict[str, str]]) -> str:
308
  </div>
309
  </div>
310
  """
311
- sources_html += "</div>"
312
- return sources_html
313
 
314
- # --- Core Async Logic ---
 
 
 
315
 
316
- # NOTE: @spaces.GPU decorator is REMOVED because it's incompatible with async def
317
- async def generate_answer(prompt: str) -> str:
318
- """Generate answer using the DeepSeek model (Async Wrapper)."""
319
- print(f"[LLM Generate] Generating answer for prompt (length {len(prompt)})...")
320
  start_time = time.time()
321
  try:
322
- # Tokenize input - ensure it runs on the correct device implicitly via model.device
323
- inputs = tokenizer(
324
  prompt,
325
  return_tensors="pt",
326
  padding=True,
327
  truncation=True,
328
- max_length=1024, # Model's context window might be larger, adjust if known
329
  return_attention_mask=True
330
- ).to(model.device)
331
-
332
- # Use torch.inference_mode() for efficiency
333
- with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(model.dtype == torch.float16)):
334
- # Run model.generate in a separate thread to avoid blocking asyncio event loop
335
- outputs = await asyncio.to_thread(
336
- model.generate,
337
- input_ids=inputs.input_ids,
338
  attention_mask=inputs.attention_mask,
339
  max_new_tokens=MAX_NEW_TOKENS,
340
  temperature=TEMPERATURE,
341
  top_p=TOP_P,
342
- pad_token_id=tokenizer.eos_token_id,
343
- eos_token_id=tokenizer.eos_token_id, # Explicitly set EOS token
344
  do_sample=True,
345
  num_return_sequences=1
346
  )
347
 
348
- # Decode only the newly generated tokens
349
- # output_ids = outputs[0][inputs.input_ids.shape[1]:] # Slice generated part
350
- # answer_part = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
351
-
352
- # Alternative: Decode full output and split (can be less reliable if prompt has "Answer:")
353
- full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
354
- answer_marker = "Answer:"
355
- marker_index = full_output.rfind(answer_marker) # Use rfind to find the last occurrence
356
- if marker_index != -1:
357
- answer_part = full_output[marker_index + len(answer_marker):].strip()
358
- else:
359
- # Fallback: try to remove the prompt text (less reliable)
360
- prompt_decoded = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
361
- if full_output.startswith(prompt_decoded):
362
- answer_part = full_output[len(prompt_decoded):].strip()
363
- # Check if the marker is now at the beginning
364
- if answer_part.startswith(answer_marker):
365
- answer_part = answer_part[len(answer_marker):].strip()
366
- else:
367
- print("[LLM Generate] Warning: 'Answer:' marker not found and prompt prefix mismatch. Using full output.")
368
- answer_part = full_output # Use full output as last resort
369
 
370
  end_time = time.time()
371
- print(f"[LLM Generate] Answer generated successfully in {end_time - start_time:.2f}s. Length: {len(answer_part)}")
372
- return answer_part if answer_part else "*Model did not generate a response.*"
373
 
374
  except Exception as e:
375
  print(f"[LLM Generate] Error: {e}")
376
  print(traceback.format_exc())
377
- return f"Error generating answer: {str(e)}"
378
-
379
- # NOTE: @spaces.GPU decorator is REMOVED because it's incompatible with async def
380
- async def generate_speech(text: str, voice_id: str = 'af') -> Tuple[int, np.ndarray] | None:
381
- """Generate speech from text using Kokoro TTS model (Async Wrapper)."""
382
- global TTS_MODEL, TTS_ENABLED, VOICEPACKS
383
 
384
- if not TTS_ENABLED or TTS_MODEL is None:
385
- print("[TTS Generate] Skipping: TTS not enabled or model not loaded.")
386
- return None
387
- if 'generate_tts_internal' not in globals():
388
- print("[TTS Generate] Skipping: TTS generation function not found.")
389
  return None
390
  if not text or not text.strip():
391
- print("[TTS Generate] Skipping: Empty text provided.")
392
  return None
393
 
394
- print(f"[TTS Generate] Requesting speech for text (length {len(text)}) with voice '{voice_id}'")
395
  start_time = time.time()
396
 
397
  try:
398
- device = TTS_MODEL.device
399
-
400
- # Ensure voicepack is loaded
401
- if voice_id not in VOICEPACKS:
402
- print(f"[TTS Generate] Warning: Voice '{voice_id}' not preloaded. Attempting fallback.")
403
- # Attempt fallback to default 'af' if available
404
- voice_id = 'af'
405
- if 'af' not in VOICEPACKS:
406
- print("[TTS Generate] Error: Default voice 'af' also not available. Cannot generate audio.")
407
  return None
408
- print("[TTS Generate] Using default voice 'af'.")
409
-
410
- # Clean the text (simple cleaning)
411
- # Remove markdown citations like [1], [2][3] etc.
412
- clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text)
413
- # Remove other common markdown artifacts
414
- clean_text = clean_text.replace('*', '').replace('#', '').replace('`', '')
415
- # Remove excessive whitespace
416
- clean_text = ' '.join(clean_text.split())
417
-
418
- if not clean_text.strip():
419
- print("[TTS Generate] Skipping: Text is empty after cleaning.")
420
- return None
421
 
422
- # Truncate if too long
 
 
 
 
 
 
 
423
  if len(clean_text) > MAX_TTS_CHARS:
424
- print(f"[TTS Generate] Warning: Text too long ({len(clean_text)} chars), truncating to {MAX_TTS_CHARS}.")
425
  clean_text = clean_text[:MAX_TTS_CHARS]
426
- # Find last punctuation or space for cleaner cut
427
- cut_off = max(clean_text.rfind('.'), clean_text.rfind('?'), clean_text.rfind('!'), clean_text.rfind(' '))
428
- if cut_off != -1:
429
- clean_text = clean_text[:cut_off+1]
430
- clean_text += "..." # Indicate truncation
431
 
432
  print(f"[TTS Generate] Generating audio for: '{clean_text[:100]}...'")
433
  gen_func = globals()['generate_tts_internal']
 
434
 
435
- # Run the blocking TTS generation in the thread pool executor
 
436
  audio_data, _ = await asyncio.get_event_loop().run_in_executor(
437
  executor,
438
  gen_func,
439
- TTS_MODEL,
440
- clean_text,
441
- VOICEPACKS[voice_id],
442
- 'afr' # Language code for Kokoro (check if 'afr' or 'eng' or other is correct for your voices)
443
  )
444
 
445
  if isinstance(audio_data, torch.Tensor):
446
- # Move tensor to CPU before converting to numpy if it's not already
447
  audio_np = audio_data.detach().cpu().numpy()
448
  elif isinstance(audio_data, np.ndarray):
449
  audio_np = audio_data
450
  else:
451
- print("[TTS Generate] Warning: Unexpected audio data type received.")
452
  return None
453
 
 
 
 
454
  end_time = time.time()
455
- print(f"[TTS Generate] Audio generated successfully in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")
456
- # Ensure it's 1D array
457
- if audio_np.ndim > 1:
458
- audio_np = audio_np.flatten()
459
  return (TTS_SAMPLE_RATE, audio_np)
460
 
461
  except Exception as e:
@@ -463,108 +421,111 @@ async def generate_speech(text: str, voice_id: str = 'af') -> Tuple[int, np.ndar
463
  print(traceback.format_exc())
464
  return None
465
 
466
- # Helper to get voice ID from display name
467
- def get_voice_id(voice_display_name: str) -> str:
468
  """Maps the user-friendly voice name to the internal voice ID."""
469
- return VOICE_CHOICES.get(voice_display_name, 'af') # Default to 'af' if not found
 
 
 
470
 
471
- # --- Main Processing Logic (Async Generator) ---
472
- import re # Import regex for cleaning
473
 
474
- async def process_query_async(query: str, history: List[List[str]], selected_voice_display_name: str):
475
- """Asynchronously process user query: search -> generate answer -> generate speech"""
476
- print(f"\n--- New Query Processing ---")
 
 
 
 
477
  print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")
478
 
479
  if not query or not query.strip():
480
  print("Empty query received.")
481
- yield (
482
- "Please enter a query.", "", gr.Button(value="Search", interactive=True), history, None
483
- )
484
  return
485
 
486
- if history is None: history = []
487
- # Append user query to history immediately for display
488
- current_history = history + [[query, None]] # Placeholder for assistant response
 
489
 
490
- # 1. Initial state: Searching
491
  yield (
492
- "*Searching the web...*",
493
- "<div class='searching'><span>Searching the web...</span></div>", # Added span for CSS animation
494
- gr.Button(value="Searching...", interactive=False), # Disable button
495
  current_history,
496
- None
 
 
 
497
  )
498
 
499
- # 2. Perform Web Search (non-blocking)
500
- loop = asyncio.get_event_loop()
501
- web_results = await loop.run_in_executor(executor, get_web_results, query)
502
- sources_html = format_sources(web_results)
 
503
 
504
- # Update state: Analyzing results
 
505
  yield (
506
- "*Analyzing search results and generating answer...*",
507
- sources_html,
508
- gr.Button(value="Generating...", interactive=False),
509
- current_history, # History still shows user query, assistant response is pending
510
- None
511
  )
512
 
513
- # 3. Generate Answer (non-blocking, potentially on GPU)
514
- prompt = format_prompt(query, web_results)
515
- final_answer = await generate_answer(prompt) # This is already async
516
 
517
- # Update history with the final answer BEFORE generating audio
518
- current_history[-1][1] = final_answer
519
 
520
- # Update state: Answer generated, preparing audio
521
  yield (
522
- final_answer,
 
523
  sources_html,
524
- gr.Button(value="Audio...", interactive=False),
525
- current_history, # Now history includes the answer
526
- None
527
  )
528
 
529
- # 4. Generate Speech (non-blocking, potentially on GPU)
530
- audio = None
531
- tts_message = ""
532
- if not tts_thread.is_alive() and not TTS_ENABLED:
533
- print("[TTS Status] TTS setup failed or is disabled.")
534
- tts_message = "\n\n*(TTS is disabled or failed to initialize)*"
535
- elif tts_thread.is_alive():
536
- print("[TTS Status] TTS is still initializing in the background.")
537
- tts_message = "\n\n*(TTS is still initializing, audio may be delayed or unavailable)*"
538
- elif TTS_ENABLED:
539
- voice_id = get_voice_id(selected_voice_display_name)
540
- # Only generate audio if the answer generation was successful
541
- if not final_answer.startswith("Error"):
542
- audio = await generate_speech(final_answer, voice_id) # This is already async
543
- if audio is None:
544
- print(f"[TTS Status] Audio generation failed for voice '{voice_id}'.")
545
- tts_message = f"\n\n*(Audio generation failed)*"
546
- else:
547
- print("[TTS Status] Audio generated successfully.")
548
  else:
549
- print("[TTS Status] Skipping audio generation due to answer error.")
550
- tts_message = "\n\n*(Audio skipped due to answer generation error)*"
551
-
552
-
553
- # 5. Final state: Show everything
554
- print("--- Query Processing Complete ---")
 
 
 
 
 
 
555
  yield (
556
- final_answer + tts_message,
 
557
  sources_html,
558
- gr.Button(value="Search", interactive=True), # Re-enable button
559
- current_history, # Final history state
560
- audio
561
  )
562
 
563
 
564
- # --- Gradio Interface ---
565
- # (CSS remains the same as your previous version)
566
  css = """
567
- /* ... [Your existing refined CSS] ... */
 
 
568
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
569
  #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
570
  #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
@@ -589,20 +550,19 @@ css = """
589
  .sources-container { margin-top: 0; }
590
  .source-item { display: flex; padding: 10px 0; margin: 0; border-bottom: 1px solid #f3f4f6; transition: background-color 0.2s; }
591
  .source-item:last-child { border-bottom: none; }
592
- /* .source-item:hover { background-color: #f9fafb; } */
593
  .source-number { font-weight: bold; margin-right: 12px; color: #6b7280; width: 20px; text-align: right; flex-shrink: 0;}
594
  .source-content { flex: 1; min-width: 0;} /* Allow content to shrink */
595
  .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
596
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
597
- .source-date { color: #6b7280; font-size: 0.8em; margin-left: 8px; }
598
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
599
- .chat-history { max-height: 400px; overflow-y: auto; padding: 1rem; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; margin-top: 1rem; scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
 
600
  .chat-history::-webkit-scrollbar { width: 6px; }
601
  .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
602
  .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
603
  .examples-container { background: #f9fafb; border-radius: 8px; padding: 1rem; margin-top: 1rem; border: 1px solid #e5e7eb; }
604
- .examples-container .gradio-examples { gap: 8px !important; } /* Target examples component */
605
- .examples-container button { background: white !important; border: 1px solid #d1d5db !important; color: #374151 !important; transition: all 0.2s; margin: 0 !important; font-size: 0.9em !important; padding: 6px 12px !important; }
606
  .examples-container button:hover { background: #f3f4f6 !important; border-color: #adb5bd !important; }
607
  .markdown-content { color: #374151 !important; font-size: 1rem; line-height: 1.7; }
608
  .markdown-content h1, .markdown-content h2, .markdown-content h3 { color: #111827 !important; margin-top: 1.2em !important; margin-bottom: 0.6em !important; font-weight: 600; }
@@ -619,7 +579,7 @@ css = """
619
  .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
620
  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
621
  .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
622
- .accordion > .label-wrap { padding: 10px 15px !important; } /* Style accordion header */
623
  .voice-selector { margin: 0; padding: 0; height: 100%; }
624
  .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
625
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
@@ -632,7 +592,7 @@ css = """
632
  .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
633
  @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
634
  .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; }
635
- /* Dark Mode Styles */
636
  .dark .gradio-container { background-color: #111827 !important; }
637
  .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
638
  .dark #header h3 { color: #9ca3af; }
@@ -654,7 +614,7 @@ css = """
654
  .dark .source-title { color: #60a5fa; }
655
  .dark .source-title:hover { color: #93c5fd; }
656
  .dark .source-snippet { color: #d1d5db; }
657
- .dark .chat-history { background: #374151; border-color: #4b5563; scrollbar-color: #4b5563 #374151; color: #d1d5db;} /* Ensure chat text is visible */
658
  .dark .chat-history::-webkit-scrollbar-track { background: #374151; }
659
  .dark .chat-history::-webkit-scrollbar-thumb { background-color: #4b5563; }
660
  .dark .examples-container { background: #374151; border-color: #4b5563; }
@@ -671,11 +631,11 @@ css = """
671
  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
672
  .dark .markdown-content th { background: #374151 !important; }
673
  .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
674
- .dark .accordion > .label-wrap { color: #d1d5db !important; } /* Accordion label color */
675
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
676
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
677
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
678
- .dark .audio-player audio::-webkit-media-controls-panel { background-color: #374151; } /* Style audio player controls */
679
  .dark .audio-player audio::-webkit-media-controls-play-button { color: #d1d5db; }
680
  .dark .audio-player audio::-webkit-media-controls-current-time-display { color: #9ca3af; }
681
  .dark .audio-player audio::-webkit-media-controls-time-remaining-display { color: #9ca3af; }
@@ -684,146 +644,118 @@ css = """
684
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
685
  """
686
 
 
 
687
  with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
688
- # chat_history state persists across interactions for a single user session
689
- chat_history = gr.State([])
690
 
691
- with gr.Column(): # Main container for vertical layout
692
- # Header Section
693
  with gr.Column(elem_id="header"):
694
  gr.Markdown("# 🔍 AI Search Assistant")
695
  gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
696
 
697
- # Search Input and Controls Section
698
  with gr.Column(elem_classes="search-container"):
699
- with gr.Row(elem_classes="search-box", equal_height=False): # Use Row for horizontal elements
700
- search_input = gr.Textbox(
701
- label="",
702
- placeholder="Ask anything...",
703
- scale=5, # Takes more horizontal space
704
- container=False, # Important for direct styling within Row
705
- elem_classes="gradio-textbox"
706
- )
707
- voice_select = gr.Dropdown(
708
- choices=list(VOICE_CHOICES.keys()),
709
- value=list(VOICE_CHOICES.keys())[0], # Default voice display name
710
- label="", # Visually hidden label
711
- scale=1, # Takes less space
712
- min_width=180, # Fixed width for dropdown
713
- container=False, # Important
714
- elem_classes="voice-selector gradio-dropdown"
715
- )
716
- search_btn = gr.Button(
717
- "Search",
718
- variant="primary",
719
- scale=0, # Minimal width needed for text
720
- min_width=100,
721
- elem_classes="gradio-button"
722
- )
723
 
724
- # Results Display Section (using Columns for side-by-side layout)
725
  with gr.Row(elem_classes="results-container", equal_height=False):
726
- # Left Column: Answer and Chat History
727
- with gr.Column(scale=3): # Takes 3 parts of the width
728
- with gr.Column(elem_classes="answer-box"):
729
- answer_output = gr.Markdown(value="*Your answer will appear here...*", elem_classes="markdown-content")
730
- # Audio player below the answer text
731
- audio_output = gr.Audio(
732
- label="Voice Response",
733
- type="numpy", # Expects (rate, numpy_array) tuple
734
- autoplay=False, # Don't autoplay by default
735
- show_label=False, # Hide the "Voice Response" label visually
736
- elem_classes="audio-player"
737
- )
738
-
739
- with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
740
- chat_history_display = gr.Chatbot(
741
- label="Conversation",
742
- bubble_full_width=True, # Bubbles take full width
743
- height=400,
744
- elem_classes="chat-history"
745
- )
746
 
747
  # Right Column: Sources
748
- with gr.Column(scale=2): # Takes 2 parts of the width
749
- with gr.Column(elem_classes="sources-box"):
750
  gr.Markdown("### Sources")
751
- sources_output = gr.HTML(value="<div class='no-sources'>Sources will appear here after searching.</div>")
752
 
753
- # Example Prompts Section
754
  with gr.Row(elem_classes="examples-container"):
 
755
  gr.Examples(
756
  examples=[
757
  "Latest news about renewable energy",
758
- "Explain the concept of Large Language Models (LLMs)",
759
- "What are the symptoms and prevention tips for the flu?",
760
  "Compare Python and JavaScript for web development",
761
- "Summarize the main points of the Paris Agreement on climate change",
762
  ],
763
- inputs=search_input, # Clicking example populates this input
764
  label="Try these examples:",
765
- elem_classes="gradio-examples" # Add class for potential styling
766
  )
767
 
768
- # --- Event Handling ---
769
- async def handle_interaction(query, history, voice_display_name):
770
- """Wrapper to handle the async generator and update outputs."""
771
- print(f"[Interaction] Handling query: '{query}'")
772
- outputs = { # Dictionary to hold the latest state of outputs
773
- "answer": "...",
774
- "sources": "...",
775
- "button": gr.Button(value="Search", interactive=True),
776
- "history": history,
777
- "audio": None
778
- }
779
- try:
780
- # Iterate through the updates yielded by the async generator
781
- async for update_tuple in process_query_async(query, history, voice_display_name):
782
- # Unpack the tuple
783
- ans_out, src_out, btn_state, hist_display, aud_out = update_tuple
784
- # Update the outputs dictionary
785
- outputs["answer"] = ans_out
786
- outputs["sources"] = src_out
787
- outputs["button"] = btn_state # Can be a gr.Button update dict or object
788
- outputs["history"] = hist_display
789
- outputs["audio"] = aud_out
790
- # Yield the current state of all outputs
791
- yield outputs["answer"], outputs["sources"], outputs["button"], outputs["history"], outputs["audio"]
792
- except Exception as e:
793
- print(f"[Interaction] Error: {e}")
794
  print(traceback.format_exc())
795
- error_message = f"An unexpected error occurred: {e}"
796
- # Provide a final error state update
797
- final_error_history = history + [[query, f"*Error: {error_message}*"]] if query else history
798
  yield (
799
- error_message,
800
- "<div class='error'>Error processing request. Please check logs or try again.</div>",
801
- gr.Button(value="Search", interactive=True), # Re-enable button on error
802
- final_error_history,
803
- None
804
  )
805
 
806
- # Connect the handle_interaction function to the button click and input submit events
807
- outputs_list = [answer_output, sources_output, search_btn, chat_history_display, audio_output]
808
- inputs_list = [search_input, chat_history, voice_select] # Pass the dropdown component itself
809
 
 
810
  search_btn.click(
811
- fn=handle_interaction,
812
- inputs=inputs_list,
813
- outputs=outputs_list
814
  )
815
-
816
  search_input.submit(
817
- fn=handle_interaction,
818
- inputs=inputs_list,
819
- outputs=outputs_list
820
  )
821
 
822
  if __name__ == "__main__":
823
  print("Starting Gradio application...")
824
- # Launch the app with queuing enabled for handling multiple users
825
  demo.queue(max_size=20).launch(
826
- debug=True, # Enable Gradio debug mode for more logs
827
- share=True, # Create a public link (useful for Spaces)
828
- # server_name="0.0.0.0" # Bind to all interfaces if running locally and need external access
829
  )
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ # import spaces # Removed as @spaces.GPU is not used with async
4
  from duckduckgo_search import DDGS
5
  import time
6
  import torch
 
8
  import os
9
  import subprocess
10
  import numpy as np
11
+ from typing import List, Dict, Tuple, Any, Optional, Union
12
  from functools import lru_cache
13
  import asyncio
14
  import threading
15
  from concurrent.futures import ThreadPoolExecutor
16
  import warnings
17
  import traceback # For detailed error logging
18
+ import re # For text cleaning
19
+ import shutil # For checking sudo
20
+ import html # For escaping HTML
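One import looks to be missing from this block: setup_tts_task below still manipulates sys.path, but the old in-function import sys is removed by this commit and no top-level replacement is added. A minimal fix sketch, assuming sys is not imported anywhere else in the file:

import sys  # used by setup_tts_task for sys.path handling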
 
 
21
 
22
  # --- Configuration ---
23
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
24
  MAX_SEARCH_RESULTS = 5
25
  TTS_SAMPLE_RATE = 24000
26
  MAX_TTS_CHARS = 1000 # Max characters for a single TTS chunk
27
+ MAX_NEW_TOKENS = 300
 
28
  TEMPERATURE = 0.7
29
  TOP_P = 0.95
30
  KOKORO_PATH = 'Kokoro-82M' # Path to TTS model directory
31
 
32
  # --- Initialization ---
33
+ # Use a ThreadPoolExecutor for blocking I/O or CPU-bound tasks
34
+ executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4) # Use available cores
35
 
36
+ # Suppress specific warnings
37
+ warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
38
+ warnings.filterwarnings("ignore", message="Backend 'inductor' is not available.")
 
 
39
 
40
+ # --- LLM Initialization ---
41
+ llm_model: Optional[AutoModelForCausalLM] = None
42
+ llm_tokenizer: Optional[AutoTokenizer] = None
43
+ llm_device = "cpu" # Default device
44
+
45
+ try:
46
+ print("Initializing LLM...")
47
+ llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
48
+ llm_tokenizer.pad_token = llm_tokenizer.eos_token
49
+
50
+ if torch.cuda.is_available():
51
+ llm_device = "cuda"
52
+ torch_dtype = torch.float16
53
+ device_map = "auto" # Let accelerate handle distribution
54
+ print(f"CUDA detected. Loading model with device_map='{device_map}', dtype={torch_dtype}")
55
+ else:
56
+ llm_device = "cpu"
57
+ torch_dtype = torch.float32 # float32 for CPU
58
+ device_map = {"": "cpu"}
59
+ print(f"CUDA not found. Loading model on CPU with dtype={torch_dtype}")
60
 
61
+ llm_model = AutoModelForCausalLM.from_pretrained(
62
  MODEL_NAME,
63
  device_map=device_map,
64
+ low_cpu_mem_usage=True,
 
65
  torch_dtype=torch_dtype,
66
+ # attn_implementation="flash_attention_2" # Optional: Uncomment if flash-attn is installed and compatible GPU
67
  )
68
+ print(f"LLM loaded successfully. Device map: {llm_model.hf_device_map if hasattr(llm_model, 'hf_device_map') else 'N/A'}")
69
+ llm_model.eval() # Set to evaluation mode
 
70
 
71
  except Exception as e:
72
  print(f"FATAL: Error initializing LLM model: {str(e)}")
73
  print(traceback.format_exc())
74
+ # Depending on environment, you might exit or just disable LLM features
75
+ llm_model = None
76
+ llm_tokenizer = None
77
+ print("LLM features will be unavailable.")
78
 
79
+
80
+ # --- TTS Initialization ---
81
  VOICE_CHOICES = {
82
  '🇺🇸 Female (Default)': 'af',
83
  '🇺🇸 Bella': 'af_bella',
 
85
  '🇺🇸 Nicole': 'af_nicole'
86
  }
87
  TTS_ENABLED = False
88
+ tts_model: Optional[Any] = None # Define type more specifically if Kokoro provides it
89
+ voicepacks: Dict[str, Any] = {} # Cache voice packs
90
+ tts_device = "cpu" # Default device for TTS model
91
+
92
+ # Use a lock for thread-safe access during initialization if needed, though Thread ensures sequential execution
93
+ # tts_init_lock = threading.Lock()
94
 
95
+ def _run_subprocess(cmd: List[str], check: bool = True, cwd: Optional[str] = None) -> subprocess.CompletedProcess:
96
+ """Helper to run subprocess and capture output."""
97
+ print(f"Running command: {' '.join(cmd)}")
98
+ try:
99
+ result = subprocess.run(cmd, check=check, capture_output=True, text=True, cwd=cwd)
100
+ if result.stdout: print(f"Stdout: {result.stdout.strip()}")
101
+ if result.stderr: print(f"Stderr: {result.stderr.strip()}")
102
+ return result
103
+ except FileNotFoundError:
104
+ print(f"Error: Command not found - {cmd[0]}")
105
+ raise
106
+ except subprocess.CalledProcessError as e:
107
+ print(f"Error running command: {' '.join(e.cmd)}")
108
+ if e.stdout: print(f"Stdout: {e.stdout.strip()}")
109
+ if e.stderr: print(f"Stderr: {e.stderr.strip()}")
110
+ raise
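For reference, a minimal usage sketch of the helper above; the probe command is hypothetical and only illustrates the check / exception behaviour:

# Probe for git-lfs without aborting setup when it is absent.
try:
    _run_subprocess(['git', 'lfs', 'version'], check=False)
except FileNotFoundError:
    print('[TTS Setup] git not found on PATH')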
111
+
112
+ def setup_tts_task():
113
+ """Initializes Kokoro TTS model and dependencies."""
114
+ global TTS_ENABLED, tts_model, voicepacks, tts_device
115
+ print("[TTS Setup] Starting background initialization...")
116
+
117
+ # Determine TTS device
118
+ tts_device = "cuda" if torch.cuda.is_available() else "cpu"
119
+ print(f"[TTS Setup] Target device: {tts_device}")
120
 
 
121
  can_sudo = shutil.which('sudo') is not None
122
+ apt_cmd_prefix = ['sudo'] if can_sudo else []
123
 
124
  try:
125
+ # 1. Clone Kokoro Repo if needed
126
  if not os.path.exists(KOKORO_PATH):
127
+ print(f"[TTS Setup] Cloning repository to {KOKORO_PATH}...")
 
128
  try:
129
+ _run_subprocess(['git', 'lfs', 'install', '--system', '--skip-repo'])
130
+ except Exception as lfs_err:
131
+ print(f"[TTS Setup] Warning: git lfs install command failed: {lfs_err}. Continuing clone...")
132
+ _run_subprocess(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', KOKORO_PATH])
 
 
 
 
 
 
133
  try:
134
+ print("[TTS Setup] Running git lfs pull...")
135
+ _run_subprocess(['git', 'lfs', 'pull'], cwd=KOKORO_PATH)
136
+ except Exception as lfs_pull_err:
137
+ print(f"[TTS Setup] Warning: git lfs pull failed: {lfs_pull_err}")
 
 
 
138
  else:
139
+ print(f"[TTS Setup] Directory {KOKORO_PATH} already exists.")
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ # 2. Install espeak dependency
142
+ print("[TTS Setup] Checking/Installing espeak...")
143
  try:
144
+ _run_subprocess(apt_cmd_prefix + ['apt-get', 'update', '-qq'])
145
+ _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak-ng'])
146
+ print("[TTS Setup] espeak-ng installed or already present.")
147
+ except Exception:
148
+ print("[TTS Setup] espeak-ng failed, trying espeak...")
 
 
149
  try:
150
+ _run_subprocess(apt_cmd_prefix + ['apt-get', 'install', '-y', '-qq', 'espeak'])
151
+ print("[TTS Setup] espeak installed or already present.")
152
+ except Exception as espeak_err:
153
+ print(f"[TTS Setup] ERROR: Failed to install both espeak-ng and espeak: {espeak_err}. TTS disabled.")
154
+ return # Critical dependency missing
155
+
156
+ # 3. Load Kokoro Model and Voices
 
157
  if os.path.exists(KOKORO_PATH):
158
+ sys_path_updated = False
159
  if KOKORO_PATH not in sys.path:
160
  sys.path.append(KOKORO_PATH)
161
+ sys_path_updated = True
162
  try:
163
  from models import build_model
164
+ from kokoro import generate as generate_tts_internal
165
 
166
+ globals()['build_model'] = build_model # Make available globally
 
167
  globals()['generate_tts_internal'] = generate_tts_internal
168
 
 
 
169
  model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
 
170
  if not os.path.exists(model_file):
171
+ print(f"[TTS Setup] ERROR: Model file {model_file} not found. TTS disabled.")
172
+ return
173
+
174
+ print(f"[TTS Setup] Loading TTS model from {model_file} onto {tts_device}...")
175
+ tts_model = build_model(model_file, tts_device)
176
+ tts_model.eval() # Set to eval mode
177
+ print("[TTS Setup] TTS model loaded.")
178
+
179
+ # Load voices
180
+ loaded_voices = 0
 
 
 
 
 
181
  for voice_name, voice_id in VOICE_CHOICES.items():
182
  voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
183
  if os.path.exists(voice_file_path):
184
  try:
185
+ print(f"[TTS Setup] Loading voice: {voice_id} ({voice_name})")
186
+ # map_location ensures it loads to the correct device
187
+ voicepacks[voice_id] = torch.load(voice_file_path, map_location=tts_device)
188
+ loaded_voices += 1
189
  except Exception as e:
190
+ print(f"[TTS Setup] Warning: Failed to load voice {voice_id}: {str(e)}")
191
  else:
192
+ print(f"[TTS Setup] Info: Voice file {voice_file_path} not found, skipping.")
193
 
194
+ if loaded_voices == 0:
195
+ print("[TTS Setup] ERROR: No voicepacks could be loaded. TTS disabled.")
196
+ tts_model = None # Unload model if no voices
197
  return
198
 
 
 
 
 
 
 
 
 
 
 
199
  TTS_ENABLED = True
200
+ print(f"[TTS Setup] Initialization successful. {loaded_voices} voices loaded. TTS Enabled: {TTS_ENABLED}")
201
 
202
  except ImportError as ie:
203
+ print(f"[TTS Setup] ERROR: Failed to import Kokoro modules: {ie}. Check clone and path. TTS disabled.")
204
+ except Exception as load_err:
205
+ print(f"[TTS Setup] ERROR: Failed loading TTS model/voices: {load_err}. TTS disabled.")
206
  print(traceback.format_exc())
207
+ finally:
208
+ # Clean up sys.path if modified
209
+ if sys_path_updated and KOKORO_PATH in sys.path:
210
+ sys.path.remove(KOKORO_PATH)
211
  else:
212
+ print(f"[TTS Setup] ERROR: {KOKORO_PATH} directory not found. TTS disabled.")
213
+
 
 
 
 
214
  except Exception as e:
215
+ print(f"[TTS Setup] ERROR: Unexpected error during setup: {str(e)}")
216
  print(traceback.format_exc())
217
+ # Ensure TTS is marked as disabled
218
  TTS_ENABLED = False
219
+ tts_model = None
220
+ voicepacks.clear()
221
 
222
+ # Start TTS setup in a background thread
223
+ print("Starting TTS setup thread...")
224
+ tts_setup_thread = threading.Thread(target=setup_tts_task, daemon=True)
225
+ tts_setup_thread.start()
 
226
 
227
+
228
+ # --- Core Functions ---
229
 
230
  @lru_cache(maxsize=128)
231
+ def get_web_results_sync(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, Any]]:
232
+ """Synchronous web search function with caching."""
233
+ print(f"[Web Search] Searching (sync): '{query}' (max_results={max_results})")
234
  try:
 
235
  with DDGS() as ddgs:
236
+ results = list(ddgs.text(query, max_results=max_results, safesearch='moderate', timelimit='y'))
 
237
  print(f"[Web Search] Found {len(results)} results.")
238
+ formatted = [{
239
+ "id": i + 1,
240
+ "title": res.get("title", "No Title"),
241
+ "snippet": res.get("body", "No Snippet"),
242
+ "url": res.get("href", "#"),
243
+ } for i, res in enumerate(results)]
244
+ return formatted
 
 
245
  except Exception as e:
246
  print(f"[Web Search] Error: {e}")
247
  print(traceback.format_exc())
248
  return []
249
 
250
+ def format_llm_prompt(query: str, context: List[Dict[str, Any]]) -> str:
251
+ """Formats the prompt for the LLM, including context and instructions."""
252
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
253
+ context_str = "\n\n".join(
254
+ [f"[{res['id']}] {res['title']}\n{res['snippet']}" for res in context]
255
+ ) if context else "No relevant web context found."
256
 
257
+ return f"""You are a helpful AI assistant. Answer the user's query based *only* on the provided web search context.
258
+ Instructions:
259
+ - Synthesize information from the context to answer concisely.
260
+ - Cite sources using bracket notation like [1], [2], etc., corresponding to the context IDs.
261
+ - If the context is insufficient, state that clearly. Do not add external information.
262
+ - Use markdown for formatting.
263
 
264
  Current Time: {current_time}
265
 
 
271
  User Query: {query}
272
 
273
  Answer:"""
 
 
274
 
275
+ def format_sources_html(web_results: List[Dict[str, Any]]) -> str:
276
+ """Formats search results into HTML for display."""
277
  if not web_results:
278
  return "<div class='no-sources'>No sources found for this query.</div>"
279
+ items_html = ""
 
280
  for res in web_results:
281
+ title_safe = html.escape(res.get("title", "Source"))
282
+ snippet_safe = html.escape(res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else ""))
283
  url = res.get("url", "#")
284
+ items_html += f"""
 
 
 
 
 
285
  <div class='source-item'>
286
  <div class='source-number'>[{res['id']}]</div>
287
  <div class='source-content'>
 
290
  </div>
291
  </div>
292
  """
293
+ return f"<div class='sources-container'>{items_html}</div>"
 
294
 
295
+ async def generate_llm_answer(prompt: str) -> str:
296
+ """Generates answer using the loaded LLM (Async Wrapper)."""
297
+ if not llm_model or not llm_tokenizer:
298
+ return "Error: LLM model is not available."
299
 
300
+ print(f"[LLM Generate] Requesting generation (prompt length {len(prompt)})...")
 
 
 
301
  start_time = time.time()
302
  try:
303
+ inputs = llm_tokenizer(
 
304
  prompt,
305
  return_tensors="pt",
306
  padding=True,
307
  truncation=True,
308
+ max_length=1024, # Consider model's actual max length
309
  return_attention_mask=True
310
+ ).to(llm_model.device) # Ensure inputs are on the same device as model parts
311
+
312
+ with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(llm_model.dtype == torch.float16)):
313
+ # Run blocking model.generate in the executor thread pool
314
+ outputs = await asyncio.get_event_loop().run_in_executor(
315
+ executor,
316
+ llm_model.generate,
317
+ inputs.input_ids,
318
  attention_mask=inputs.attention_mask,
319
  max_new_tokens=MAX_NEW_TOKENS,
320
  temperature=TEMPERATURE,
321
  top_p=TOP_P,
322
+ pad_token_id=llm_tokenizer.eos_token_id,
323
+ eos_token_id=llm_tokenizer.eos_token_id,
324
  do_sample=True,
325
  num_return_sequences=1
326
  )
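A caveat on the call above: loop.run_in_executor only forwards positional arguments, so the keyword arguments passed here would raise a TypeError at runtime. A possible workaround, sketched with functools.partial and the names already defined in this file:

from functools import partial

generate_fn = partial(
    llm_model.generate,
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    pad_token_id=llm_tokenizer.eos_token_id,
    eos_token_id=llm_tokenizer.eos_token_id,
    do_sample=True,
    num_return_sequences=1,
)
outputs = await asyncio.get_event_loop().run_in_executor(executor, generate_fn)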
327
 
328
+ # Decode only newly generated tokens relative to input
329
+ output_ids = outputs[0][inputs.input_ids.shape[1]:]
330
+ answer_part = llm_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
331
+
332
+ # Handle potential empty generation
333
+ if not answer_part:
334
+ # Sometimes the split method above is needed if the model includes the prompt
335
+ full_output = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
336
+ answer_marker = "Answer:"
337
+ marker_index = full_output.rfind(answer_marker)
338
+ if marker_index != -1:
339
+ answer_part = full_output[marker_index + len(answer_marker):].strip()
340
+ else:
341
+ answer_part = "*Model generated an empty response.*" # Fallback message
 
 
 
 
 
 
 
342
 
343
  end_time = time.time()
344
+ print(f"[LLM Generate] Generation complete in {end_time - start_time:.2f}s. Length: {len(answer_part)}")
345
+ return answer_part
346
 
347
  except Exception as e:
348
  print(f"[LLM Generate] Error: {e}")
349
  print(traceback.format_exc())
350
+ return f"Error during answer generation: {str(e)}"
 
 
 
 
 
351
 
352
+ async def generate_tts_speech(text: str, voice_id: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
353
+ """Generates speech using the loaded TTS model (Async Wrapper)."""
354
+ if not TTS_ENABLED or not tts_model or 'generate_tts_internal' not in globals():
355
+ print("[TTS Generate] Skipping: TTS not ready.")
 
356
  return None
357
  if not text or not text.strip():
358
+ print("[TTS Generate] Skipping: Empty text.")
359
  return None
360
 
361
+ print(f"[TTS Generate] Requesting speech (length {len(text)}, voice '{voice_id}')...")
362
  start_time = time.time()
363
 
364
  try:
365
+ # Verify voicepack availability
366
+ actual_voice_id = voice_id
367
+ if voice_id not in voicepacks:
368
+ print(f"[TTS Generate] Warning: Voice '{voice_id}' not loaded. Trying default 'af'.")
369
+ actual_voice_id = 'af'
370
+ if 'af' not in voicepacks:
371
+ print("[TTS Generate] Error: Default voice 'af' also not available.")
 
 
372
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
+ # Clean text for TTS
375
+ clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text) # Remove citations like [1], [2][3]
376
+ clean_text = re.sub(r'[\*\#\`]', '', clean_text) # Remove markdown symbols
377
+ clean_text = ' '.join(clean_text.split()) # Normalize whitespace
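A quick illustration of the three cleaning steps above on a made-up answer string:

# '**Solar** adoption is rising [1][2].  See `report`.'
#  -> citations removed:    '**Solar** adoption is rising .  See `report`.'
#  -> markdown stripped:    'Solar adoption is rising .  See report.'
#  -> whitespace normalized: 'Solar adoption is rising . See report.'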
378
+
379
+ if not clean_text: return None # Skip if empty after cleaning
380
+
381
+ # Truncate if necessary
382
  if len(clean_text) > MAX_TTS_CHARS:
383
+ print(f"[TTS Generate] Truncating text from {len(clean_text)} to {MAX_TTS_CHARS} chars.")
384
  clean_text = clean_text[:MAX_TTS_CHARS]
385
+ last_punct = max(clean_text.rfind(p) for p in '.?! ')
386
+ if last_punct != -1: clean_text = clean_text[:last_punct+1]
387
+ clean_text += "..."
 
 
388
 
389
  print(f"[TTS Generate] Generating audio for: '{clean_text[:100]}...'")
390
  gen_func = globals()['generate_tts_internal']
391
+ voice_pack_data = voicepacks[actual_voice_id]
392
 
393
+ # Run blocking TTS generation in the executor thread pool
394
+ # Assuming 'afr' is the correct language code for Kokoro's default voices
395
  audio_data, _ = await asyncio.get_event_loop().run_in_executor(
396
  executor,
397
  gen_func,
398
+ tts_model, # The loaded model object
399
+ clean_text, # The cleaned text string
400
+ voice_pack_data,# The loaded voice pack tensor/dict
401
+ 'afr' # Language code (verify this is correct)
402
  )
403
 
404
  if isinstance(audio_data, torch.Tensor):
 
405
  audio_np = audio_data.detach().cpu().numpy()
406
  elif isinstance(audio_data, np.ndarray):
407
  audio_np = audio_data
408
  else:
409
+ print("[TTS Generate] Warning: Unexpected audio data type.")
410
  return None
411
 
412
+ # Ensure audio is 1D float32
413
+ audio_np = audio_np.flatten().astype(np.float32)
414
+
415
  end_time = time.time()
416
+ print(f"[TTS Generate] Audio generated in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")
 
 
 
417
  return (TTS_SAMPLE_RATE, audio_np)
418
 
419
  except Exception as e:
 
421
  print(traceback.format_exc())
422
  return None
423
 
424
+ def get_voice_id_from_display(voice_display_name: str) -> str:
 
425
  """Maps the user-friendly voice name to the internal voice ID."""
426
+ return VOICE_CHOICES.get(voice_display_name, 'af') # Default to 'af'
427
+
428
+
429
+ # --- Gradio Interaction Logic ---
430
 
431
+ # Define type for chat history using the 'messages' format
432
+ ChatHistoryType = List[Dict[str, str]]
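Note that history entries in this role/content shape only render if the Chatbot component is created in messages mode (supported in recent Gradio releases); the previous UI used pair-style history. A minimal sketch, reusing the old variable name as an assumption about the new layout:

chat_history_display = gr.Chatbot(label="Conversation", type="messages", height=400)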
433
 
434
+ async def handle_interaction(
435
+ query: str,
436
+ history: ChatHistoryType,
437
+ selected_voice_display_name: str
438
+ ):
439
+ """Main async generator function to handle user queries and update Gradio UI."""
440
+ print(f"\n--- Handling Query ---")
441
  print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")
442
 
443
  if not query or not query.strip():
444
  print("Empty query received.")
445
+ # Need to yield the current state for all outputs
446
+ yield history, "*Please enter a query.*", "<div class='no-sources'>Enter a query to search.</div>", None, gr.Button(value="Search", interactive=True)
 
447
  return
448
 
449
+ # Append user message to history
450
+ current_history = history + [{"role": "user", "content": query}]
451
+ # Add placeholder for assistant response
452
+ current_history.append({"role": "assistant", "content": "*Searching...*"})
453
 
454
+ # 1. Initial State: Searching
455
  yield (
 
 
 
456
  current_history,
457
+ "*Searching the web...*", # Update answer area
458
+ "<div class='searching'><span>Searching the web...</span></div>", # Update sources area
459
+ None, # No audio yet
460
+ gr.Button(value="Searching...", interactive=False) # Update button state
461
  )
462
 
463
+ # 2. Perform Web Search (in executor)
464
+ web_results = await asyncio.get_event_loop().run_in_executor(
465
+ executor, get_web_results_sync, query
466
+ )
467
+ sources_html = format_sources_html(web_results)
468
 
469
+ # Update state: Generating Answer
470
+ current_history[-1]["content"] = "*Generating answer...*" # Update assistant placeholder
471
  yield (
472
+ current_history,
473
+ "*Generating answer...*", # Update answer area
474
+ sources_html, # Show sources
475
+ None,
476
+ gr.Button(value="Generating...", interactive=False)
477
  )
478
 
479
+ # 3. Generate LLM Answer (async)
480
+ llm_prompt = format_llm_prompt(query, web_results)
481
+ final_answer = await generate_llm_answer(llm_prompt)
482
 
483
+ # Update assistant message in history with the final answer
484
+ current_history[-1]["content"] = final_answer
485
 
486
+ # Update state: Generating Audio (if applicable)
487
  yield (
488
+ current_history,
489
+ final_answer, # Show final answer
490
  sources_html,
491
+ None,
492
+ gr.Button(value="Audio...", interactive=False) if TTS_ENABLED else gr.Button(value="Search", interactive=True) # Enable search if TTS disabled
 
493
  )
494
 
495
+ # 4. Generate TTS Speech (async)
496
+ audio_output_data = None
497
+ tts_status_message = ""
498
+ if not TTS_ENABLED:
499
+ if tts_setup_thread.is_alive():
500
+ tts_status_message = "\n\n*(TTS initializing...)*"
501
  else:
502
+ tts_status_message = "\n\n*(TTS disabled or failed)*"
503
+ elif final_answer and not final_answer.startswith("Error"):
504
+ voice_id = get_voice_id_from_display(selected_voice_display_name)
505
+ audio_output_data = await generate_tts_speech(final_answer, voice_id)
506
+ if audio_output_data is None:
507
+ tts_status_message = "\n\n*(Audio generation failed)*"
508
+
509
+ # 5. Final State: Show all results
510
+ final_answer_with_status = final_answer + tts_status_message
511
+ current_history[-1]["content"] = final_answer_with_status # Update history with status msg too
512
+
513
+ print("--- Query Handling Complete ---")
514
  yield (
515
+ current_history,
516
+ final_answer_with_status, # Show answer + TTS status
517
  sources_html,
518
+ audio_output_data, # Output audio data (or None)
519
+ gr.Button(value="Search", interactive=True) # Re-enable button
 
520
  )
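# Stand-alone consumption sketch (for local debugging only; the Gradio wiring below is what
# actually drives this generator). The unpacking assumes the five-element tuples yielded above.
#   async def _debug_run():
#       async for update in handle_interaction("latest AI news", [], '🇺🇸 Female (Default)'):
#           history, answer, sources, audio, button = update
#           print(answer[:80])
#   # asyncio.run(_debug_run())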
521
 
522
 
523
+ # --- Gradio UI Definition ---
524
+ # (CSS remains largely the same - ensure it targets default Gradio classes if elem_classes was removed)
525
  css = """
526
+ /* ... [Your existing refined CSS, but remove selectors using .gradio-examples if you were using it] ... */
527
+ /* Example: Style examples container via its parent or default class if needed */
528
+ /* .examples-container .gradio-examples { ... } */ /* This might still work depending on structure */
529
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
530
  #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
531
  #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
 
550
  .sources-container { margin-top: 0; }
551
  .source-item { display: flex; padding: 10px 0; margin: 0; border-bottom: 1px solid #f3f4f6; transition: background-color 0.2s; }
552
  .source-item:last-child { border-bottom: none; }
 
553
  .source-number { font-weight: bold; margin-right: 12px; color: #6b7280; width: 20px; text-align: right; flex-shrink: 0;}
554
  .source-content { flex: 1; min-width: 0;} /* Allow content to shrink */
555
  .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
556
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
 
557
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
558
+ .chat-history { /* Style the chatbot container */ max-height: 400px; overflow-y: auto; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; margin-top: 1rem; scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
559
+ .chat-history > div { padding: 1rem; } /* Add padding inside the chatbot display area */
560
  .chat-history::-webkit-scrollbar { width: 6px; }
561
  .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
562
  .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
563
  .examples-container { background: #f9fafb; border-radius: 8px; padding: 1rem; margin-top: 1rem; border: 1px solid #e5e7eb; }
564
+ /* Default styling for example buttons (since elem_classes might not work) */
565
+ .examples-container button { background: white !important; border: 1px solid #d1d5db !important; color: #374151 !important; transition: all 0.2s; margin: 4px !important; font-size: 0.9em !important; padding: 6px 12px !important; border-radius: 4px !important; }
566
  .examples-container button:hover { background: #f3f4f6 !important; border-color: #adb5bd !important; }
567
  .markdown-content { color: #374151 !important; font-size: 1rem; line-height: 1.7; }
568
  .markdown-content h1, .markdown-content h2, .markdown-content h3 { color: #111827 !important; margin-top: 1.2em !important; margin-bottom: 0.6em !important; font-weight: 600; }
 
579
  .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
580
  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
581
  .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
582
+ .accordion > .label-wrap { padding: 10px 15px !important; }
583
  .voice-selector { margin: 0; padding: 0; height: 100%; }
584
  .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
585
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
 
592
  .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
593
  @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
594
  .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; }
595
+ /* Dark Mode Styles (Optional - keep if needed) */
596
  .dark .gradio-container { background-color: #111827 !important; }
597
  .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
598
  .dark #header h3 { color: #9ca3af; }
 
614
  .dark .source-title { color: #60a5fa; }
615
  .dark .source-title:hover { color: #93c5fd; }
616
  .dark .source-snippet { color: #d1d5db; }
617
+ .dark .chat-history { background: #374151; border-color: #4b5563; scrollbar-color: #4b5563 #374151; color: #d1d5db;}
618
  .dark .chat-history::-webkit-scrollbar-track { background: #374151; }
619
  .dark .chat-history::-webkit-scrollbar-thumb { background-color: #4b5563; }
620
  .dark .examples-container { background: #374151; border-color: #4b5563; }
 
631
  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
632
  .dark .markdown-content th { background: #374151 !important; }
633
  .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
634
+ .dark .accordion > .label-wrap { color: #d1d5db !important; }
635
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
636
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
637
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
638
+ .dark .audio-player audio::-webkit-media-controls-panel { background-color: #374151; }
639
  .dark .audio-player audio::-webkit-media-controls-play-button { color: #d1d5db; }
640
  .dark .audio-player audio::-webkit-media-controls-current-time-display { color: #9ca3af; }
641
  .dark .audio-player audio::-webkit-media-controls-time-remaining-display { color: #9ca3af; }
 
644
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
645
  """
646
 
647
+ import sys # Needed for sys.path manipulation in TTS setup
648
+
649
  with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
650
+ # Use gr.State to store the chat history in the 'messages' format
651
+ chat_history_state = gr.State([])
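# gr.State is session-scoped: each browser session gets its own copy of this history list,
# which is passed back in as an input on every search event below.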
652
 
653
+ with gr.Column(): # Main container
654
+ # Header
655
  with gr.Column(elem_id="header"):
656
  gr.Markdown("# 🔍 AI Search Assistant")
657
  gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
658
 
659
+ # Search Area
660
  with gr.Column(elem_classes="search-container"):
661
+ with gr.Row(elem_classes="search-box", equal_height=False):
662
+ search_input = gr.Textbox(label="", placeholder="Ask anything...", scale=5, container=False)
663
+ voice_select = gr.Dropdown(choices=list(VOICE_CHOICES.keys()), value=list(VOICE_CHOICES.keys())[0], label="", scale=1, min_width=180, container=False, elem_classes="voice-selector")
664
+ search_btn = gr.Button("Search", variant="primary", scale=0, min_width=100)
665
 
666
+ # Results Area
667
  with gr.Row(elem_classes="results-container", equal_height=False):
668
+ # Left Column: Answer & History
669
+ with gr.Column(scale=3):
670
+ # Chatbot display (uses 'messages' format now)
671
+ chatbot_display = gr.Chatbot(
672
+ label="Conversation",
673
+ bubble_full_width=True,
674
+ height=500,
675
+ elem_classes="chat-history",
676
+ type="messages", # Use the recommended type
677
+ avatar_images=(None, os.path.join(KOKORO_PATH, "icon.png") if os.path.exists(os.path.join(KOKORO_PATH, "icon.png")) else None) # Optional: Add avatar for assistant
678
+ )
679
+ # Separate Markdown for status/intermediate answer
680
+ answer_status_output = gr.Markdown(value="*Enter a query to start.*", elem_classes="answer-box markdown-content")
681
+ # Audio Output
682
+ audio_player = gr.Audio(label="Voice Response", type="numpy", autoplay=False, show_label=False, elem_classes="audio-player")
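# With type="numpy", gr.Audio expects a (sample_rate, np.ndarray) tuple, which matches the
# (TTS_SAMPLE_RATE, audio_np) value returned by generate_tts_speech.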
683
 
684
  # Right Column: Sources
685
+ with gr.Column(scale=2):
686
+ with gr.Column(elem_classes="sources-box"):
687
  gr.Markdown("### Sources")
688
+ sources_output_html = gr.HTML(value="<div class='no-sources'>Sources will appear here.</div>")
689
 
690
+ # Examples Area
691
  with gr.Row(elem_classes="examples-container"):
692
+ # REMOVED elem_classes from gr.Examples
693
  gr.Examples(
694
  examples=[
695
  "Latest news about renewable energy",
696
+ "Explain Large Language Models (LLMs)",
697
+ "Symptoms and prevention tips for the flu",
698
  "Compare Python and JavaScript for web development",
699
+ "Summarize the main points of the Paris Agreement",
700
  ],
701
+ inputs=search_input,
702
  label="Try these examples:",
 
703
  )
704
 
705
+ # --- Event Handling Setup ---
706
+ # Define the inputs and outputs for the Gradio event triggers
707
+ event_inputs = [search_input, chat_history_state, voice_select]
708
+ event_outputs = [
709
+ chatbot_display, # Updated chat history
710
+ answer_status_output, # Status or final answer text
711
+ sources_output_html, # Formatted sources
712
+ audio_player, # Audio data
713
+ search_btn # Button state (enabled/disabled)
714
+ ]
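# Each tuple yielded by stream_interaction_updates maps positionally onto this list:
# (chatbot_display, answer_status_output, sources_output_html, audio_player, search_btn).
# Note that chat_history_state is read as an input but is not listed here, so the stored
# history is never written back between turns; add it to the outputs (and yield it) if
# multi-turn memory across searches is wanted.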
715
+
716
+ # Create a wrapper to adapt the async generator for Gradio's streaming updates
717
+ async def stream_interaction_updates(query, history, voice_display_name):
718
+ try:
719
+ # Iterate through the states yielded by the handler
720
+ async for state_update in handle_interaction(query, history, voice_display_name):
721
+ yield state_update # Yield the tuple of output values
722
+ except Exception as e:
723
+ print(f"[Gradio Stream] Error during interaction: {e}")
724
  print(traceback.format_exc())
725
+ # Yield a final error state to the UI
726
+ error_history = history + [{"role":"user", "content":query}, {"role":"assistant", "content":f"*Error: {e}*"}]
 
727
  yield (
728
+ error_history,
729
+ f"An error occurred: {e}",
730
+ "<div class='error'>Request failed.</div>",
731
+ None,
732
+ gr.Button(value="Search", interactive=True)
733
  )
734
+ finally:
735
+ # Clear the text input after processing is complete (or errored out)
736
+ # We need to yield the final state *plus* the cleared input
737
+ # This requires adding search_input to the outputs list for the event triggers
738
+ # For now, let's not clear it automatically to avoid complexity.
739
+ # yield (*final_state_tuple, gr.Textbox(value="")) # Example if clearing input
740
+ print("[Gradio Stream] Interaction stream finished.")
741
 
742
 
743
+ # Connect the streaming function to the button click and input submit events
744
  search_btn.click(
745
+ fn=stream_interaction_updates,
746
+ inputs=event_inputs,
747
+ outputs=event_outputs
748
  )
 
749
  search_input.submit(
750
+ fn=stream_interaction_updates,
751
+ inputs=event_inputs,
752
+ outputs=event_outputs
753
  )
754
 
755
  if __name__ == "__main__":
756
  print("Starting Gradio application...")
 
757
  demo.queue(max_size=20).launch(
758
+ debug=True,
759
+ share=True,
760
+ # server_name="0.0.0.0" # Optional: Bind to all interfaces
761
  )
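# demo.queue() enables the request queue that Gradio uses to stream each intermediate yield
# from the async generator to the browser; max_size=20 caps how many requests may wait.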