sagar007 committed
Commit 8652f53 · verified · 1 Parent(s): 3d63694

Update app.py

Files changed (1)
  1. app.py +376 -250
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import spaces
4
  from duckduckgo_search import DDGS
5
  import time
6
  import torch
@@ -14,21 +14,28 @@ import asyncio
14
  import threading
15
  from concurrent.futures import ThreadPoolExecutor
16
  import warnings
 
17
 
18
  # Suppress specific warnings if needed (optional)
19
  warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
 
 
20
 
21
  # --- Configuration ---
22
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
23
  MAX_SEARCH_RESULTS = 5
24
  TTS_SAMPLE_RATE = 24000
25
- MAX_TTS_CHARS = 1000 # Reduced for faster testing, adjust as needed
26
- GPU_DURATION = 60 # Increased duration for longer tasks like TTS
27
- MAX_NEW_TOKENS = 256
28
  TEMPERATURE = 0.7
29
  TOP_P = 0.95
 
30
 
31
  # --- Initialization ---
 
 
 
32
  # Initialize model and tokenizer with better error handling
33
  try:
34
  print("Loading tokenizer...")
@@ -39,21 +46,24 @@ try:
39
  # Determine device map based on CUDA availability
40
  device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
41
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Use float32 on CPU
 
42
 
43
  model = AutoModelForCausalLM.from_pretrained(
44
  MODEL_NAME,
45
  device_map=device_map,
46
- # offload_folder="offload", # Only use offload if really needed and configured
47
- low_cpu_mem_usage=True,
48
- torch_dtype=torch_dtype
 
49
  )
50
- print(f"Model loaded on device map: {model.hf_device_map}")
51
- print("Model and tokenizer loaded successfully")
 
 
52
  except Exception as e:
53
- print(f"Error initializing model: {str(e)}")
54
- # If running in Spaces, maybe try loading to CPU as fallback?
55
- # For now, just raise the error.
56
- raise
57
 
58
  # --- TTS Setup ---
59
  VOICE_CHOICES = {
@@ -65,47 +75,66 @@ VOICE_CHOICES = {
65
  TTS_ENABLED = False
66
  TTS_MODEL = None
67
  VOICEPACKS = {} # Cache voice packs
68
- KOKORO_PATH = 'Kokoro-82M'
69
 
70
  # Initialize Kokoro TTS in a separate thread to avoid blocking startup
71
  def setup_tts():
72
  global TTS_ENABLED, TTS_MODEL, VOICEPACKS
73
 
 
 
 
74
  try:
75
  # Check if Kokoro already exists
76
  if not os.path.exists(KOKORO_PATH):
77
  print("Cloning Kokoro-82M repository...")
78
  # Install git-lfs if not present (might need sudo/apt)
79
  try:
80
- subprocess.run(['git', 'lfs', 'install'], check=True, capture_output=True)
 
81
  except (FileNotFoundError, subprocess.CalledProcessError) as lfs_err:
82
- print(f"Warning: git-lfs might not be installed or failed: {lfs_err}. Cloning might be slow or incomplete.")
83
 
84
- clone_cmd = ['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M']
85
  result = subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
86
  print("Kokoro cloned successfully.")
87
- print(result.stdout)
88
- # Optionally pull LFS files if needed (sometimes clone doesn't get them all)
89
- # subprocess.run(['git', 'lfs', 'pull'], cwd=KOKORO_PATH, check=True)
90
 
91
  else:
92
- print("Kokoro-82M directory already exists.")
93
 
94
  # Install espeak (essential for phonemization)
95
  print("Attempting to install espeak-ng or espeak...")
 
 
 
 
 
 
 
 
 
96
  try:
97
- # Try installing espeak-ng first (often preferred)
98
- subprocess.run(['sudo', 'apt-get', 'update'], check=True, capture_output=True)
99
- subprocess.run(['sudo', 'apt-get', 'install', '-y', 'espeak-ng'], check=True, capture_output=True)
 
100
  print("espeak-ng installed successfully.")
101
- except (FileNotFoundError, subprocess.CalledProcessError):
102
- print("espeak-ng installation failed, trying espeak...")
103
  try:
104
- # Fallback to espeak
105
- subprocess.run(['sudo', 'apt-get', 'install', '-y', 'espeak'], check=True, capture_output=True)
106
  print("espeak installed successfully.")
107
- except (FileNotFoundError, subprocess.CalledProcessError) as espeak_err:
108
- print(f"Warning: Could not install espeak-ng or espeak: {espeak_err}. TTS functionality will be disabled.")
109
  return # Cannot proceed without espeak
110
 
111
  # Set up Kokoro TTS
@@ -117,293 +146,321 @@ def setup_tts():
117
  from models import build_model
118
  from kokoro import generate as generate_tts_internal # Avoid name clash
119
 
120
- # Make these functions accessible globally if needed, but better to keep scoped
121
  globals()['build_model'] = build_model
122
  globals()['generate_tts_internal'] = generate_tts_internal
123
 
124
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
125
  print(f"Loading TTS model onto device: {device}")
126
- # Ensure model path is correct
127
  model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
 
128
  if not os.path.exists(model_file):
129
- print(f"Error: TTS model file not found at {model_file}")
130
- # Attempt to pull LFS files again
131
  try:
132
- print("Attempting git lfs pull...")
133
- subprocess.run(['git', 'lfs', 'pull'], cwd=KOKORO_PATH, check=True, capture_output=True)
134
  if not os.path.exists(model_file):
135
- print(f"Error: TTS model file STILL not found at {model_file} after lfs pull.")
136
  return
137
  except Exception as lfs_pull_err:
138
- print(f"Error during git lfs pull: {lfs_pull_err}")
139
  return
140
 
141
  TTS_MODEL = build_model(model_file, device)
 
142
 
143
- # Preload default voice
144
- default_voice_id = 'af'
145
- voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{default_voice_id}.pt')
146
- if os.path.exists(voice_file_path):
147
- print(f"Loading default voice: {default_voice_id}")
148
- VOICEPACKS[default_voice_id] = torch.load(voice_file_path,
149
- map_location=device) # Removed weights_only=True
150
- else:
151
- print(f"Warning: Default voice file {voice_file_path} not found.")
152
-
153
-
154
- # Preload other common voices to reduce latency
155
  for voice_name, voice_id in VOICE_CHOICES.items():
156
- if voice_id != default_voice_id: # Avoid reloading default
157
- voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
158
- if os.path.exists(voice_file_path):
159
- try:
160
- print(f"Preloading voice: {voice_id}")
161
- VOICEPACKS[voice_id] = torch.load(voice_file_path,
162
- map_location=device) # Removed weights_only=True
163
- except Exception as e:
164
- print(f"Warning: Could not preload voice {voice_id}: {str(e)}")
165
- else:
166
- print(f"Info: Voice file {voice_file_path} for '{voice_name}' not found, will skip preloading.")
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  TTS_ENABLED = True
169
- print("TTS setup completed successfully")
 
170
  except ImportError as ie:
171
- print(f"Error importing Kokoro modules: {ie}. Check if Kokoro-82M is correctly cloned and in sys.path.")
172
  except Exception as model_load_err:
173
- print(f"Error loading TTS model or voices: {model_load_err}")
 
174
 
175
  else:
176
- print(f"Warning: {KOKORO_PATH} directory not found after clone attempt. TTS disabled.")
177
  except subprocess.CalledProcessError as spe:
178
- print(f"Warning: A subprocess command failed during TTS setup: {spe}")
179
  print(f"Command: {' '.join(spe.cmd)}")
180
- print(f"Stderr: {spe.stderr}")
181
- print("TTS may be disabled.")
182
  except Exception as e:
183
- print(f"Warning: An unexpected error occurred during TTS setup: {str(e)}")
 
184
  TTS_ENABLED = False
185
 
186
  # Start TTS setup in a separate thread
 
187
  print("Starting TTS setup in background thread...")
188
  tts_thread = threading.Thread(target=setup_tts, daemon=True)
189
  tts_thread.start()
190
 
191
  # --- Search and Generation Functions ---
 
192
  @lru_cache(maxsize=128)
193
  def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
194
- """Get web search results using DuckDuckGo with caching for improved performance"""
195
- print(f"Performing web search for: '{query}'")
196
  try:
 
197
  with DDGS() as ddgs:
198
- # Using safe='off' potentially gives more results but use cautiously
199
- results = list(ddgs.text(query, max_results=max_results, safesearch='moderate'))
200
- print(f"Found {len(results)} results.")
201
  formatted_results = []
202
- for result in results:
203
  formatted_results.append({
204
- "title": result.get("title", "No Title"),
 
205
  "snippet": result.get("body", "No Snippet Available"),
206
  "url": result.get("href", "#"),
207
- # Attempt to extract date - DDGS doesn't reliably provide it
208
- # "date": result.get("published", "") # Placeholder
209
  })
210
  return formatted_results
211
  except Exception as e:
212
- print(f"Error in web search: {e}")
 
213
  return []
214
 
215
  def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
216
- """Format the prompt with web context"""
217
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
218
- context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for i, res in enumerate(context)]) # No need for index here
219
  prompt = f"""You are a helpful AI assistant. Your task is to answer the user's query based *only* on the provided web search context.
220
- Do not add information not present in the context.
221
- Cite the sources used in your answer using bracket notation, e.g., [Source Title]. Use the titles from the context.
222
- If the context does not contain relevant information to answer the query, state that clearly.
223
  Current Time: {current_time}
224
 
225
  Web Context:
226
- {context_lines if context else "No web context available."}
 
 
227
 
228
  User Query: {query}
229
 
230
  Answer:"""
231
- # print(f"Formatted Prompt:\n{prompt}") # Debugging
232
  return prompt
233
 
234
  def format_sources(web_results: List[Dict[str, str]]) -> str:
235
- """Format sources with more details"""
236
  if not web_results:
237
- return "<div class='no-sources'>No sources found for the query.</div>"
238
 
239
  sources_html = "<div class='sources-container'>"
240
- for i, res in enumerate(web_results, 1):
241
  title = res.get("title", "Source")
242
  url = res.get("url", "#")
243
- # date = f"<span class='source-date'>{res['date']}</span>" if res.get('date') else "" # DDG date is unreliable
244
- snippet = res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else "")
 
 
 
245
  sources_html += f"""
246
  <div class='source-item'>
247
- <div class='source-number'>[{i}]</div>
248
  <div class='source-content'>
249
- <a href="{url}" target="_blank" class='source-title' title="{url}">{title}</a>
250
- <div class='source-snippet'>{snippet}</div>
251
  </div>
252
  </div>
253
  """
254
  sources_html += "</div>"
255
  return sources_html
256
 
257
- # Use a ThreadPoolExecutor for potentially blocking I/O or CPU-bound tasks
258
- # Keep GPU tasks separate if possible, or ensure thread safety if sharing GPU resources
259
- executor = ThreadPoolExecutor(max_workers=4)
260
 
261
- @spaces.GPU(duration=GPU_DURATION, cancellable=True)
262
  async def generate_answer(prompt: str) -> str:
263
- """Generate answer using the DeepSeek model with optimized settings (Async Wrapper)"""
264
- print("Generating answer...")
 
265
  try:
 
266
  inputs = tokenizer(
267
  prompt,
268
  return_tensors="pt",
269
  padding=True,
270
  truncation=True,
271
- max_length=1024, # Increased context length
272
  return_attention_mask=True
273
  ).to(model.device)
274
 
275
- # Ensure generation runs on the correct device
276
- with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available() and torch_dtype == torch.float16):
277
- outputs = await asyncio.to_thread( # Use asyncio.to_thread for potentially blocking calls
 
278
  model.generate,
279
- inputs.input_ids,
280
  attention_mask=inputs.attention_mask,
281
  max_new_tokens=MAX_NEW_TOKENS,
282
  temperature=TEMPERATURE,
283
  top_p=TOP_P,
284
  pad_token_id=tokenizer.eos_token_id,
 
285
  do_sample=True,
286
- early_stopping=True,
287
  num_return_sequences=1
288
  )
289
 
290
- # Decode output
 
 
 
 
291
  full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
292
- # Extract only the generated part after "Answer:"
293
- answer_part = full_output.split("Answer:")[-1].strip()
294
- print(f"Generated Answer Raw Length: {len(outputs[0])}, Decoded Answer Part Length: {len(answer_part)}")
295
- if not answer_part: # Handle cases where split might fail or answer is empty
296
- print("Warning: Could not extract answer after 'Answer:'. Returning full output.")
297
- return full_output # Fallback
298
- return answer_part
299
  except Exception as e:
300
- print(f"Error during answer generation: {e}")
301
- # You might want to return a specific error message here
302
  return f"Error generating answer: {str(e)}"
303
 
304
- # Ensure this function runs potentially long tasks in a thread using the executor
305
- # @spaces.GPU(duration=GPU_DURATION, cancellable=True) # Keep GPU decorator if TTS uses GPU heavily
306
  async def generate_speech(text: str, voice_id: str = 'af') -> Tuple[int, np.ndarray] | None:
307
  """Generate speech from text using Kokoro TTS model (Async Wrapper)."""
308
  global TTS_MODEL, TTS_ENABLED, VOICEPACKS
309
- print(f"Attempting to generate speech for text (length {len(text)}) with voice '{voice_id}'")
310
 
311
  if not TTS_ENABLED or TTS_MODEL is None:
312
- print("TTS is not enabled or model not loaded.")
313
  return None
314
  if 'generate_tts_internal' not in globals():
315
- print("TTS generation function 'generate_tts_internal' not found.")
 
 
 
316
  return None
317
 
318
- try:
319
- device = TTS_MODEL.device # Get device from the loaded TTS model
320
 
321
- # Load voicepack if needed (handle potential errors)
322
- if voice_id not in VOICEPACKS:
323
- voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
324
- if os.path.exists(voice_file_path):
325
- print(f"Loading voice '{voice_id}' on demand...")
326
- try:
327
- VOICEPACKS[voice_id] = await asyncio.to_thread(
328
- torch.load, voice_file_path, map_location=device # Removed weights_only=True
329
- )
330
- except Exception as load_err:
331
- print(f"Error loading voicepack {voice_id}: {load_err}. Falling back to default 'af'.")
332
- voice_id = 'af' # Fallback to default
333
- # Ensure default is loaded if fallback occurs
334
- if 'af' not in VOICEPACKS:
335
- default_voice_file = os.path.join(KOKORO_PATH, 'voices', 'af.pt')
336
- if os.path.exists(default_voice_file):
337
- VOICEPACKS['af'] = await asyncio.to_thread(
338
- torch.load, default_voice_file, map_location=device
339
- )
340
- else:
341
- print("Default voice 'af' also not found. Cannot generate audio.")
342
- return None
343
- else:
344
- print(f"Voicepack {voice_id}.pt not found. Falling back to default 'af'.")
345
- voice_id = 'af' # Fallback to default
346
- if 'af' not in VOICEPACKS: # Check again if default is needed now
347
- default_voice_file = os.path.join(KOKORO_PATH, 'voices', 'af.pt')
348
- if os.path.exists(default_voice_file):
349
- VOICEPACKS['af'] = await asyncio.to_thread(
350
- torch.load, default_voice_file, map_location=device
351
- )
352
- else:
353
- print("Default voice 'af' also not found. Cannot generate audio.")
354
- return None
355
 
 
356
  if voice_id not in VOICEPACKS:
357
- print(f"Error: Voice '{voice_id}' could not be loaded.")
358
- return None
 
 
 
 
 
359
 
360
  # Clean the text (simple cleaning)
361
- clean_text = ' '.join(text.split()) # Remove extra whitespace
362
- clean_text = clean_text.replace('*', '').replace('[', '').replace(']', '') # Remove markdown chars
 
 
 
 
363
 
364
- # Ensure text isn't empty
365
  if not clean_text.strip():
366
- print("Warning: Empty text provided for TTS.")
367
  return None
368
 
369
- # Limit text length
370
  if len(clean_text) > MAX_TTS_CHARS:
371
- print(f"Warning: Text too long ({len(clean_text)} chars), truncating to {MAX_TTS_CHARS}.")
372
- # Simple truncation, could be smarter (split by sentence)
373
  clean_text = clean_text[:MAX_TTS_CHARS]
374
- last_space = clean_text.rfind(' ')
375
- if last_space != -1:
376
- clean_text = clean_text[:last_space] + "..." # Truncate at last space
 
 
377
 
378
- # Run the potentially blocking TTS generation in a thread
379
- print(f"Generating audio for: '{clean_text[:100]}...'")
380
  gen_func = globals()['generate_tts_internal']
381
- loop = asyncio.get_event_loop()
382
- audio_data, _ = await loop.run_in_executor(
383
- executor, # Use the thread pool executor
 
384
  gen_func,
385
  TTS_MODEL,
386
  clean_text,
387
  VOICEPACKS[voice_id],
388
- 'a' # Language code (assuming 'a' is appropriate)
389
  )
390
 
391
  if isinstance(audio_data, torch.Tensor):
392
  # Move tensor to CPU before converting to numpy if it's not already
393
- audio_np = audio_data.cpu().numpy()
394
  elif isinstance(audio_data, np.ndarray):
395
  audio_np = audio_data
396
  else:
397
- print("Warning: Unexpected audio data type from TTS.")
398
  return None
399
 
400
- print(f"Audio generated successfully, shape: {audio_np.shape}")
 
 
 
 
401
  return (TTS_SAMPLE_RATE, audio_np)
402
 
403
  except Exception as e:
404
- import traceback
405
- print(f"Error generating speech: {str(e)}")
406
- print(traceback.format_exc()) # Print full traceback for debugging
407
  return None
408
 
409
  # Helper to get voice ID from display name
@@ -411,22 +468,29 @@ def get_voice_id(voice_display_name: str) -> str:
411
  """Maps the user-friendly voice name to the internal voice ID."""
412
  return VOICE_CHOICES.get(voice_display_name, 'af') # Default to 'af' if not found
413
 
414
- # --- Main Processing Logic (Async) ---
 
 
415
  async def process_query_async(query: str, history: List[List[str]], selected_voice_display_name: str):
416
  """Asynchronously process user query: search -> generate answer -> generate speech"""
417
- if not query:
418
  yield (
419
- "Please enter a query.", "", "Search", history, None
420
  )
421
  return
422
 
423
  if history is None: history = []
424
- current_history = history + [[query, "*Searching...*"]]
 
425
 
426
  # 1. Initial state: Searching
427
  yield (
428
- "*Searching & Thinking...*",
429
- "<div class='searching'>Searching the web...</div>",
430
  gr.Button(value="Searching...", interactive=False), # Disable button
431
  current_history,
432
  None
@@ -438,26 +502,27 @@ async def process_query_async(query: str, history: List[List[str]], selected_voi
438
  sources_html = format_sources(web_results)
439
 
440
  # Update state: Analyzing results
441
- current_history[-1][1] = "*Analyzing search results...*"
442
  yield (
443
- "*Analyzing search results...*",
444
  sources_html,
445
  gr.Button(value="Generating...", interactive=False),
446
- current_history,
447
  None
448
  )
449
 
450
  # 3. Generate Answer (non-blocking, potentially on GPU)
451
  prompt = format_prompt(query, web_results)
452
- final_answer = await generate_answer(prompt) # Already async
453
 
454
- # Update state: Answer generated
455
  current_history[-1][1] = final_answer
 
 
456
  yield (
457
  final_answer,
458
  sources_html,
459
  gr.Button(value="Audio...", interactive=False),
460
- current_history,
461
  None
462
  )
463
 
@@ -465,41 +530,54 @@ async def process_query_async(query: str, history: List[List[str]], selected_voi
465
  audio = None
466
  tts_message = ""
467
  if not tts_thread.is_alive() and not TTS_ENABLED:
468
- tts_message = "\n\n*(TTS setup failed or is disabled)*"
 
469
  elif tts_thread.is_alive():
470
- tts_message = "\n\n*(TTS is still initializing, audio may be delayed)*"
 
471
  elif TTS_ENABLED:
472
  voice_id = get_voice_id(selected_voice_display_name)
473
- audio = await generate_speech(final_answer, voice_id) # Already async
474
- if audio is None:
475
- tts_message = f"\n\n*(Audio generation failed for voice '{voice_id}')*"
 
 
 
 
 
 
 
 
 
476
 
477
  # 5. Final state: Show everything
 
478
  yield (
479
  final_answer + tts_message,
480
  sources_html,
481
  gr.Button(value="Search", interactive=True), # Re-enable button
482
- current_history,
483
  audio
484
  )
485
 
486
 
487
  # --- Gradio Interface ---
 
488
  css = """
489
- /* ... [Your existing CSS remains unchanged] ... */
490
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
491
  #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
492
  #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
493
  #header h3 { color: #a8a9ab; }
494
  .search-container { background: #ffffff; border: 1px solid #e0e0e0; border-radius: 12px; box-shadow: 0 4px 16px rgba(0,0,0,0.05); padding: 1.5rem; margin-bottom: 1.5rem; }
495
- .search-box { padding: 0; margin-bottom: 1rem; }
496
- .search-box .gradio-textbox { border-radius: 8px 0 0 8px !important; } /* Style textbox specifically */
497
- .search-box .gradio-dropdown { border-radius: 0 !important; margin-left: -1px; margin-right: -1px;} /* Style dropdown */
498
- .search-box .gradio-button { border-radius: 0 8px 8px 0 !important; } /* Style button */
499
- .search-box input[type="text"] { background: #f7f7f8 !important; border: 1px solid #d1d5db !important; color: #1f2937 !important; transition: all 0.3s ease; height: 42px !important; }
500
- .search-box input[type="text"]:focus { border-color: #2563eb !important; box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2) !important; background: white !important; }
501
  .search-box input[type="text"]::placeholder { color: #9ca3af !important; }
502
- .search-box button { background: #2563eb !important; border: none !important; color: white !important; box-shadow: 0 1px 2px rgba(0,0,0,0.05) !important; transition: all 0.3s ease !important; height: 44px !important; }
503
  .search-box button:hover { background: #1d4ed8 !important; }
504
  .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
505
  .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
@@ -513,8 +591,8 @@ css = """
513
  .source-item:last-child { border-bottom: none; }
514
  /* .source-item:hover { background-color: #f9fafb; } */
515
  .source-number { font-weight: bold; margin-right: 12px; color: #6b7280; width: 20px; text-align: right; flex-shrink: 0;}
516
- .source-content { flex: 1; }
517
- .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; }
518
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
519
  .source-date { color: #6b7280; font-size: 0.8em; margin-left: 8px; }
520
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
@@ -542,10 +620,10 @@ css = """
542
  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
543
  .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
544
  .accordion > .label-wrap { padding: 10px 15px !important; } /* Style accordion header */
545
- .voice-selector { margin: 0; padding: 0; }
546
- .voice-selector div[data-testid="dropdown"] { /* Target the specific dropdown container */ height: 44px !important; }
547
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
548
- .voice-selector select:focus { border-color: #2563eb !important; box-shadow: none !important; }
549
  .audio-player { margin-top: 1rem; background: #f9fafb !important; border-radius: 8px !important; padding: 0.5rem !important; border: 1px solid #e5e7eb;}
550
  .audio-player audio { width: 100% !important; }
551
  .searching, .error { padding: 1rem; border-radius: 8px; text-align: center; margin: 1rem 0; border: 1px dashed; }
@@ -553,7 +631,8 @@ css = """
553
  .error { background: #fef2f2; color: #ef4444; border-color: #fecaca; }
554
  .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
555
  @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
556
- .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; } /* Add span for animation */
 
557
  .dark .gradio-container { background-color: #111827 !important; }
558
  .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
559
  .dark #header h3 { color: #9ca3af; }
@@ -575,7 +654,7 @@ css = """
575
  .dark .source-title { color: #60a5fa; }
576
  .dark .source-title:hover { color: #93c5fd; }
577
  .dark .source-snippet { color: #d1d5db; }
578
- .dark .chat-history { background: #374151; border-color: #4b5563; scrollbar-color: #4b5563 #374151; }
579
  .dark .chat-history::-webkit-scrollbar-track { background: #374151; }
580
  .dark .chat-history::-webkit-scrollbar-thumb { background-color: #4b5563; }
581
  .dark .examples-container { background: #374151; border-color: #4b5563; }
@@ -592,112 +671,159 @@ css = """
592
  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
593
  .dark .markdown-content th { background: #374151 !important; }
594
  .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
 
595
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
596
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
597
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
598
  .dark .searching { background: #1e3a8a; color: #93c5fd; border-color: #3b82f6; }
599
  .dark .error { background: #7f1d1d; color: #fca5a5; border-color: #ef4444; }
600
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
601
-
602
  """
603
 
604
  with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
 
605
  chat_history = gr.State([])
606
 
607
- with gr.Column(): # Main container
 
608
  with gr.Column(elem_id="header"):
609
  gr.Markdown("# 🔍 AI Search Assistant")
610
  gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
611
 
 
612
  with gr.Column(elem_classes="search-container"):
613
- with gr.Row(elem_classes="search-box", equal_height=True):
614
  search_input = gr.Textbox(
615
  label="",
616
  placeholder="Ask anything...",
617
- scale=5,
618
- container=False, # Important for direct styling
619
  elem_classes="gradio-textbox"
620
  )
621
  voice_select = gr.Dropdown(
622
  choices=list(VOICE_CHOICES.keys()),
623
- value=list(VOICE_CHOICES.keys())[0],
624
- label="", # No label needed here
625
- scale=2,
 
626
  container=False, # Important
627
  elem_classes="voice-selector gradio-dropdown"
628
  )
629
  search_btn = gr.Button(
630
  "Search",
631
  variant="primary",
632
- scale=1,
 
633
  elem_classes="gradio-button"
634
  )
635
 
 
636
  with gr.Row(elem_classes="results-container", equal_height=False):
637
- with gr.Column(scale=3): # Wider column for answer + history
 
638
  with gr.Column(elem_classes="answer-box"):
639
- answer_output = gr.Markdown(elem_classes="markdown-content", value="*Your answer will appear here...*")
640
- # Audio player below the answer
641
- audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player", type="numpy") # Expect numpy array
 
 
 
 
 
 
642
 
643
  with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
644
- chat_history_display = gr.Chatbot(elem_classes="chat-history", label="History", height=300)
645
-
646
- with gr.Column(scale=2): # Narrower column for sources
 
 
 
 
 
 
647
  with gr.Column(elem_classes="sources-box"):
648
  gr.Markdown("### Sources")
649
  sources_output = gr.HTML(value="<div class='no-sources'>Sources will appear here after searching.</div>")
650
 
 
651
  with gr.Row(elem_classes="examples-container"):
652
  gr.Examples(
653
  examples=[
654
  "Latest news about renewable energy",
655
  "Explain the concept of Large Language Models (LLMs)",
656
  "What are the symptoms and prevention tips for the flu?",
657
- "Compare Python and JavaScript for web development"
 
658
  ],
659
- inputs=search_input,
660
  label="Try these examples:",
661
  elem_classes="gradio-examples" # Add class for potential styling
662
  )
663
 
664
  # --- Event Handling ---
665
- # Use the async function for processing
666
  async def handle_interaction(query, history, voice_display_name):
667
- """Wrapper to handle the async generator from process_query_async"""
 
 
 
 
 
 
 
 
668
  try:
669
- async for update in process_query_async(query, history, voice_display_name):
670
- # Ensure the button state is updated correctly
671
- ans_out, src_out, btn_state, hist_display, aud_out = update
672
- yield ans_out, src_out, btn_state, hist_display, aud_out
673
  except Exception as e:
674
- print(f"Error in handle_interaction: {e}")
675
- import traceback
676
- traceback.print_exc()
677
  error_message = f"An unexpected error occurred: {e}"
678
  # Provide a final error state update
 
679
  yield (
680
  error_message,
681
- "<div class='error'>Error processing request.</div>",
682
  gr.Button(value="Search", interactive=True), # Re-enable button on error
683
- history + [[query, f"*Error: {error_message}*"]],
684
  None
685
  )
686
 
 
 
 
687
 
688
- # Corrected event listeners: Pass the voice_select component directly
689
  search_btn.click(
690
  fn=handle_interaction,
691
- inputs=[search_input, chat_history, voice_select], # Pass voice_select component
692
- outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
693
  )
694
 
695
  search_input.submit(
696
  fn=handle_interaction,
697
- inputs=[search_input, chat_history, voice_select], # Pass voice_select component
698
- outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
699
  )
700
 
701
  if __name__ == "__main__":
702
- # Launch the app
703
- demo.queue(max_size=20).launch(debug=True, share=True) # Enable debug for more logs

1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import spaces # Keep for potential future use or other decorators
4
  from duckduckgo_search import DDGS
5
  import time
6
  import torch
 
14
  import threading
15
  from concurrent.futures import ThreadPoolExecutor
16
  import warnings
17
+ import traceback # For detailed error logging
+ import html # Standard-library HTML escaping for source titles/snippets
18
 
19
  # Suppress specific warnings if needed (optional)
20
  warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
21
+ # Suppress another common warning with torch.compile backend
22
+ # warnings.filterwarnings("ignore", message="Backend 'inductor' is not available.")
23
 
24
  # --- Configuration ---
25
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
26
  MAX_SEARCH_RESULTS = 5
27
  TTS_SAMPLE_RATE = 24000
28
+ MAX_TTS_CHARS = 1000 # Max characters for a single TTS chunk
29
+ # GPU_DURATION = 60 # Informational only now, decorator is removed
30
+ MAX_NEW_TOKENS = 300 # Increased slightly
31
  TEMPERATURE = 0.7
32
  TOP_P = 0.95
33
+ KOKORO_PATH = 'Kokoro-82M' # Path to TTS model directory
34
 
35
  # --- Initialization ---
36
+ # Use a ThreadPoolExecutor for potentially blocking I/O or CPU-bound tasks
37
+ executor = ThreadPoolExecutor(max_workers=4)
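+ # This pool is consumed later by generate_speech() via run_in_executor; max_workers bounds how
+ # many blocking TTS generations can run at once.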
38
+
39
  # Initialize model and tokenizer with better error handling
40
  try:
41
  print("Loading tokenizer...")
 
46
  # Determine device map based on CUDA availability
47
  device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
48
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Use float32 on CPU
49
+ print(f"Attempting to load model with device_map='{device_map}' and dtype={torch_dtype}")
50
 
51
  model = AutoModelForCausalLM.from_pretrained(
52
  MODEL_NAME,
53
  device_map=device_map,
54
+ # offload_folder="offload", # Enable if needed for large models and disk space is available
55
+ low_cpu_mem_usage=True, # Important for faster loading
56
+ torch_dtype=torch_dtype,
57
+ # attn_implementation="flash_attention_2" # Optional: requires flash-attn installed, use if available for speedup on compatible GPUs
58
  )
59
+ print(f"Model loaded successfully. Device map: {model.hf_device_map}")
60
+ # Ensure model is in evaluation mode
61
+ model.eval()
62
+
63
  except Exception as e:
64
+ print(f"FATAL: Error initializing LLM model: {str(e)}")
65
+ print(traceback.format_exc())
66
+ raise # Stop execution if model loading fails
 
67
 
68
  # --- TTS Setup ---
69
  VOICE_CHOICES = {
 
75
  TTS_ENABLED = False
76
  TTS_MODEL = None
77
  VOICEPACKS = {} # Cache voice packs
 
78
 
79
  # Initialize Kokoro TTS in a separate thread to avoid blocking startup
80
  def setup_tts():
81
  global TTS_ENABLED, TTS_MODEL, VOICEPACKS
82
 
83
+ # Check privileges for apt-get
84
+ can_sudo = shutil.which('sudo') is not None
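+ # Note: on a typical Hugging Face Space the container user has no sudo, so the apt-get calls
+ # below may fail; in that setup espeak-ng is usually provided via packages.txt instead.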
85
+
86
  try:
87
  # Check if Kokoro already exists
88
  if not os.path.exists(KOKORO_PATH):
89
  print("Cloning Kokoro-82M repository...")
90
  # Install git-lfs if not present (might need sudo/apt)
91
  try:
92
+ lfs_install_cmd = ['git', 'lfs', 'install']
93
+ subprocess.run(lfs_install_cmd, check=True, capture_output=True, text=True)
94
  except (FileNotFoundError, subprocess.CalledProcessError) as lfs_err:
95
+ print(f"Warning: git-lfs command failed: {lfs_err}. Cloning might be slow or incomplete.")
96
 
97
+ clone_cmd = ['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M', KOKORO_PATH]
98
  result = subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
99
  print("Kokoro cloned successfully.")
100
+ # print(result.stdout) # Can be verbose
101
+ # Optionally pull LFS files again (sometimes clone doesn't get them all)
102
+ try:
103
+ print("Running git lfs pull...")
104
+ lfs_pull_cmd = ['git', 'lfs', 'pull']
105
+ subprocess.run(lfs_pull_cmd, cwd=KOKORO_PATH, check=True, capture_output=True, text=True)
106
+ print("git lfs pull completed.")
107
+ except (FileNotFoundError, subprocess.CalledProcessError) as lfs_pull_err:
108
+ print(f"Warning: git lfs pull failed: {lfs_pull_err}")
109
 
110
  else:
111
+ print(f"{KOKORO_PATH} directory already exists.")
112
 
113
  # Install espeak (essential for phonemization)
114
  print("Attempting to install espeak-ng or espeak...")
115
+ apt_update_cmd = ['apt-get', 'update', '-qq']
116
+ install_cmd_ng = ['apt-get', 'install', '-y', '-qq', 'espeak-ng']
117
+ install_cmd_legacy = ['apt-get', 'install', '-y', '-qq', 'espeak']
118
+
119
+ if can_sudo:
120
+ apt_update_cmd.insert(0, 'sudo')
121
+ install_cmd_ng.insert(0, 'sudo')
122
+ install_cmd_legacy.insert(0, 'sudo')
123
+
124
  try:
125
+ print(f"Running: {' '.join(apt_update_cmd)}")
126
+ subprocess.run(apt_update_cmd, check=True, capture_output=True)
127
+ print(f"Running: {' '.join(install_cmd_ng)}")
128
+ subprocess.run(install_cmd_ng, check=True, capture_output=True)
129
  print("espeak-ng installed successfully.")
130
+ except (FileNotFoundError, subprocess.CalledProcessError) as ng_err:
131
+ print(f"espeak-ng installation failed ({ng_err}), trying espeak...")
132
  try:
133
+ print(f"Running: {' '.join(install_cmd_legacy)}")
134
+ subprocess.run(install_cmd_legacy, check=True, capture_output=True)
135
  print("espeak installed successfully.")
136
+ except (FileNotFoundError, subprocess.CalledProcessError) as legacy_err:
137
+ print(f"ERROR: Could not install espeak-ng or espeak: {legacy_err}. TTS functionality will be disabled.")
138
  return # Cannot proceed without espeak
139
 
140
  # Set up Kokoro TTS
 
146
  from models import build_model
147
  from kokoro import generate as generate_tts_internal # Avoid name clash
148
 
149
+ # Make these functions accessible globally if needed
150
  globals()['build_model'] = build_model
151
  globals()['generate_tts_internal'] = generate_tts_internal
152
 
153
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
154
  print(f"Loading TTS model onto device: {device}")
 
155
  model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
156
+
157
  if not os.path.exists(model_file):
158
+ print(f"Error: TTS model file not found at {model_file}. Attempting git lfs pull again...")
 
159
  try:
160
+ lfs_pull_cmd = ['git', 'lfs', 'pull']
161
+ subprocess.run(lfs_pull_cmd, cwd=KOKORO_PATH, check=True, capture_output=True, text=True)
162
  if not os.path.exists(model_file):
163
+ print(f"ERROR: TTS model file STILL not found at {model_file} after lfs pull. TTS disabled.")
164
  return
165
  except Exception as lfs_pull_err:
166
+ print(f"Error during git lfs pull: {lfs_pull_err}. TTS disabled.")
167
  return
168
 
169
  TTS_MODEL = build_model(model_file, device)
170
+ print("TTS model loaded.")
171
 
172
+ # Preload voices
 
 
 
 
 
 
 
 
 
 
 
173
  for voice_name, voice_id in VOICE_CHOICES.items():
174
+ voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
175
+ if os.path.exists(voice_file_path):
176
+ try:
177
+ print(f"Loading voice: {voice_id} ({voice_name})")
178
+ # Load using torch.load, map_location handles device placement
179
+ VOICEPACKS[voice_id] = torch.load(voice_file_path, map_location=device)
180
+ except Exception as e:
181
+ print(f"Warning: Could not load voice {voice_id}: {str(e)}")
182
+ else:
183
+ print(f"Info: Voice file {voice_file_path} for '{voice_name}' not found, skipping.")
184
+
185
+ if not VOICEPACKS:
186
+ print("ERROR: No voicepacks could be loaded. TTS disabled.")
187
+ return
188
+
189
+ # Ensure the default voice 'af' is loaded even if it is not listed in VOICE_CHOICES
190
+ if 'af' not in VOICEPACKS:
191
+ voice_file_path = os.path.join(KOKORO_PATH, 'voices', 'af.pt')
192
+ if os.path.exists(voice_file_path):
193
+ try:
194
+ print(f"Loading fallback default voice: af")
195
+ VOICEPACKS['af'] = torch.load(voice_file_path, map_location=device)
196
+ except Exception as e:
197
+ print(f"Warning: Could not load fallback default voice 'af': {str(e)}")
198
 
199
  TTS_ENABLED = True
200
+ print("TTS setup completed successfully.")
201
+
202
  except ImportError as ie:
203
+ print(f"ERROR: Importing Kokoro modules failed: {ie}. Check if {KOKORO_PATH} exists and dependencies are met.")
204
  except Exception as model_load_err:
205
+ print(f"ERROR: Loading TTS model or voices failed: {model_load_err}")
206
+ print(traceback.format_exc())
207
 
208
  else:
209
+ print(f"ERROR: {KOKORO_PATH} directory not found. TTS disabled.")
210
  except subprocess.CalledProcessError as spe:
211
+ print(f"ERROR: A subprocess command failed during TTS setup: {spe}")
212
  print(f"Command: {' '.join(spe.cmd)}")
213
+ if spe.stderr: print(f"Stderr: {spe.stderr.strip()}")
214
+ print("TTS setup failed.")
215
  except Exception as e:
216
+ print(f"ERROR: An unexpected error occurred during TTS setup: {str(e)}")
217
+ print(traceback.format_exc())
218
  TTS_ENABLED = False
219
 
220
  # Start TTS setup in a separate thread
221
+ import shutil # Used by setup_tts via shutil.which; imported before the TTS thread starts
222
  print("Starting TTS setup in background thread...")
223
  tts_thread = threading.Thread(target=setup_tts, daemon=True)
224
  tts_thread.start()
225
 
226
  # --- Search and Generation Functions ---
227
+
228
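+ # Note: lru_cache keys on the exact (query, max_results) argument pair, so identical queries are
+ # answered from memory for the lifetime of the process and never re-hit DuckDuckGo; cached
+ # results can therefore go stale for time-sensitive topics.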
  @lru_cache(maxsize=128)
229
  def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
230
+ """Get web search results using DuckDuckGo with caching."""
231
+ print(f"[Web Search] Searching for: '{query}' (max_results={max_results})")
232
  try:
233
+ # Use DDGS context manager for cleanup
234
  with DDGS() as ddgs:
235
+ # Fetch results using ddgs.text()
236
+ results = list(ddgs.text(query, max_results=max_results, safesearch='moderate', timelimit='y')) # Limit to past year
237
+ print(f"[Web Search] Found {len(results)} results.")
238
  formatted_results = []
239
+ for i, result in enumerate(results):
240
  formatted_results.append({
241
+ "id": i + 1, # Add simple ID for citation
242
+ "title": result.get("title", "No Title Available"),
243
  "snippet": result.get("body", "No Snippet Available"),
244
  "url": result.get("href", "#"),
 
 
245
  })
246
  return formatted_results
247
  except Exception as e:
248
+ print(f"[Web Search] Error: {e}")
249
+ print(traceback.format_exc())
250
  return []
251
 
252
  def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
253
+ """Format the prompt with web context for the LLM."""
254
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
255
+
256
+ # Format context with IDs for citation
257
+ context_lines = []
258
+ if context:
259
+ for res in context:
260
+ context_lines.append(f"[{res['id']}] {res['title']}\n{res['snippet']}")
261
+ context_str = "\n\n".join(context_lines)
262
+ else:
263
+ context_str = "No web context available."
264
+
265
+ # Clear instructions for the model
266
  prompt = f"""You are a helpful AI assistant. Your task is to answer the user's query based *only* on the provided web search context.
267
+ Follow these instructions carefully:
268
+ 1. Synthesize the information from the context to provide a comprehensive answer.
269
+ 2. Cite the sources used in your answer using bracket notation with the source ID, like [1], [2], etc.
270
+ 3. If multiple sources support a point, you can cite them together, e.g., [1][3].
271
+ 4. Do *not* add information that is not present in the context.
272
+ 5. If the context does not contain relevant information to answer the query, clearly state that you cannot answer based on the provided context.
273
+ 6. Format the answer clearly using markdown.
274
+
275
  Current Time: {current_time}
276
 
277
  Web Context:
278
+ ---
279
+ {context_str}
280
+ ---
281
 
282
  User Query: {query}
283
 
284
  Answer:"""
285
+ # print(f"--- Formatted Prompt ---\n{prompt[:1000]}...\n--- End Prompt ---") # Debugging: Print start of prompt
286
  return prompt
287
 
288
  def format_sources(web_results: List[Dict[str, str]]) -> str:
289
+ """Format sources into HTML for display."""
290
  if not web_results:
291
+ return "<div class='no-sources'>No sources found for this query.</div>"
292
 
293
  sources_html = "<div class='sources-container'>"
294
+ for res in web_results:
295
  title = res.get("title", "Source")
296
  url = res.get("url", "#")
297
+ snippet = res.get("snippet", "")
298
+ # Basic HTML escaping for snippet and title
299
+ title_safe = html.escape(title)
300
+ snippet_safe = html.escape(snippet[:150] + ("..." if len(snippet) > 150 else ""))
301
+
302
  sources_html += f"""
303
  <div class='source-item'>
304
+ <div class='source-number'>[{res['id']}]</div>
305
  <div class='source-content'>
306
+ <a href="{url}" target="_blank" class='source-title' title="{url}">{title_safe}</a>
307
+ <div class='source-snippet'>{snippet_safe}</div>
308
  </div>
309
  </div>
310
  """
311
  sources_html += "</div>"
312
  return sources_html
313
 
314
+ # --- Core Async Logic ---
 
 
315
 
316
+ # NOTE: @spaces.GPU decorator is REMOVED because it's incompatible with async def
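+ # (A possible alternative, sketched here rather than used in this file: keep the decorator on a
+ # small synchronous helper and await it from async code, e.g.
+ #     @spaces.GPU(duration=60)              # illustrative duration
+ #     def _generate_sync(**gen_kwargs):
+ #         with torch.inference_mode():
+ #             return model.generate(**gen_kwargs)
+ #     outputs = await asyncio.to_thread(_generate_sync, **gen_kwargs)
+ # This keeps GPU allocation on a plain function while the event loop stays responsive.)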
317
  async def generate_answer(prompt: str) -> str:
318
+ """Generate answer using the DeepSeek model (Async Wrapper)."""
319
+ print(f"[LLM Generate] Generating answer for prompt (length {len(prompt)})...")
320
+ start_time = time.time()
321
  try:
322
+ # Tokenize input - ensure it runs on the correct device implicitly via model.device
323
  inputs = tokenizer(
324
  prompt,
325
  return_tensors="pt",
326
  padding=True,
327
  truncation=True,
328
+ max_length=1024, # Model's context window might be larger, adjust if known
329
  return_attention_mask=True
330
  ).to(model.device)
331
 
332
+ # Use torch.inference_mode() for efficiency
333
+ with torch.inference_mode(), torch.cuda.amp.autocast(enabled=(model.dtype == torch.float16)):
334
+ # Run model.generate in a separate thread to avoid blocking asyncio event loop
335
+ outputs = await asyncio.to_thread(
336
  model.generate,
337
+ input_ids=inputs.input_ids,
338
  attention_mask=inputs.attention_mask,
339
  max_new_tokens=MAX_NEW_TOKENS,
340
  temperature=TEMPERATURE,
341
  top_p=TOP_P,
342
  pad_token_id=tokenizer.eos_token_id,
343
+ eos_token_id=tokenizer.eos_token_id, # Explicitly set EOS token
344
  do_sample=True,
 
345
  num_return_sequences=1
346
  )
347
 
348
+ # Decode only the newly generated tokens
349
+ # output_ids = outputs[0][inputs.input_ids.shape[1]:] # Slice generated part
350
+ # answer_part = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
351
+
352
+ # Alternative: Decode full output and split (can be less reliable if prompt has "Answer:")
353
  full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
354
+ answer_marker = "Answer:"
355
+ marker_index = full_output.rfind(answer_marker) # Use rfind to find the last occurrence
356
+ if marker_index != -1:
357
+ answer_part = full_output[marker_index + len(answer_marker):].strip()
358
+ else:
359
+ # Fallback: try to remove the prompt text (less reliable)
360
+ prompt_decoded = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
361
+ if full_output.startswith(prompt_decoded):
362
+ answer_part = full_output[len(prompt_decoded):].strip()
363
+ # Check if the marker is now at the beginning
364
+ if answer_part.startswith(answer_marker):
365
+ answer_part = answer_part[len(answer_marker):].strip()
366
+ else:
367
+ print("[LLM Generate] Warning: 'Answer:' marker not found and prompt prefix mismatch. Using full output.")
368
+ answer_part = full_output # Use full output as last resort
369
+
370
+ end_time = time.time()
371
+ print(f"[LLM Generate] Answer generated successfully in {end_time - start_time:.2f}s. Length: {len(answer_part)}")
372
+ return answer_part if answer_part else "*Model did not generate a response.*"
373
+
374
  except Exception as e:
375
+ print(f"[LLM Generate] Error: {e}")
376
+ print(traceback.format_exc())
377
  return f"Error generating answer: {str(e)}"
378
 
379
+ # NOTE: @spaces.GPU decorator is REMOVED because it's incompatible with async def
 
380
  async def generate_speech(text: str, voice_id: str = 'af') -> Tuple[int, np.ndarray] | None:
381
  """Generate speech from text using Kokoro TTS model (Async Wrapper)."""
382
  global TTS_MODEL, TTS_ENABLED, VOICEPACKS
 
383
 
384
  if not TTS_ENABLED or TTS_MODEL is None:
385
+ print("[TTS Generate] Skipping: TTS not enabled or model not loaded.")
386
  return None
387
  if 'generate_tts_internal' not in globals():
388
+ print("[TTS Generate] Skipping: TTS generation function not found.")
389
+ return None
390
+ if not text or not text.strip():
391
+ print("[TTS Generate] Skipping: Empty text provided.")
392
  return None
393
 
394
+ print(f"[TTS Generate] Requesting speech for text (length {len(text)}) with voice '{voice_id}'")
395
+ start_time = time.time()
396
 
397
+ try:
398
+ device = TTS_MODEL.device
399
 
400
+ # Ensure voicepack is loaded
401
  if voice_id not in VOICEPACKS:
402
+ print(f"[TTS Generate] Warning: Voice '{voice_id}' not preloaded. Attempting fallback.")
403
+ # Attempt fallback to default 'af' if available
404
+ voice_id = 'af'
405
+ if 'af' not in VOICEPACKS:
406
+ print("[TTS Generate] Error: Default voice 'af' also not available. Cannot generate audio.")
407
+ return None
408
+ print("[TTS Generate] Using default voice 'af'.")
409
 
410
  # Clean the text (simple cleaning)
411
+ # Remove markdown citations like [1], [2][3] etc.
412
+ clean_text = re.sub(r'\[\d+\](\[\d+\])*', '', text)
413
+ # Remove other common markdown artifacts
414
+ clean_text = clean_text.replace('*', '').replace('#', '').replace('`', '')
415
+ # Remove excessive whitespace
416
+ clean_text = ' '.join(clean_text.split())
417
 
 
418
  if not clean_text.strip():
419
+ print("[TTS Generate] Skipping: Text is empty after cleaning.")
420
  return None
421
 
422
+ # Truncate if too long
423
  if len(clean_text) > MAX_TTS_CHARS:
424
+ print(f"[TTS Generate] Warning: Text too long ({len(clean_text)} chars), truncating to {MAX_TTS_CHARS}.")
 
425
  clean_text = clean_text[:MAX_TTS_CHARS]
426
+ # Find last punctuation or space for cleaner cut
427
+ cut_off = max(clean_text.rfind('.'), clean_text.rfind('?'), clean_text.rfind('!'), clean_text.rfind(' '))
428
+ if cut_off != -1:
429
+ clean_text = clean_text[:cut_off+1]
430
+ clean_text += "..." # Indicate truncation
431
 
432
+ print(f"[TTS Generate] Generating audio for: '{clean_text[:100]}...'")
 
433
  gen_func = globals()['generate_tts_internal']
434
+
435
+ # Run the blocking TTS generation in the thread pool executor
436
+ audio_data, _ = await asyncio.get_event_loop().run_in_executor(
437
+ executor,
438
  gen_func,
439
  TTS_MODEL,
440
  clean_text,
441
  VOICEPACKS[voice_id],
442
+ voice_id[0] # Kokoro derives its language code from the first letter of the voice id (e.g. 'a' for 'af')
443
  )
444
 
445
  if isinstance(audio_data, torch.Tensor):
446
  # Move tensor to CPU before converting to numpy if it's not already
447
+ audio_np = audio_data.detach().cpu().numpy()
448
  elif isinstance(audio_data, np.ndarray):
449
  audio_np = audio_data
450
  else:
451
+ print("[TTS Generate] Warning: Unexpected audio data type received.")
452
  return None
453
 
454
+ end_time = time.time()
455
+ print(f"[TTS Generate] Audio generated successfully in {end_time - start_time:.2f}s. Shape: {audio_np.shape}")
456
+ # Ensure it's 1D array
457
+ if audio_np.ndim > 1:
458
+ audio_np = audio_np.flatten()
459
  return (TTS_SAMPLE_RATE, audio_np)
460
 
461
  except Exception as e:
462
+ print(f"[TTS Generate] Error: {str(e)}")
463
+ print(traceback.format_exc())
 
464
  return None
465
 
466
  # Helper to get voice ID from display name
 
468
  """Maps the user-friendly voice name to the internal voice ID."""
469
  return VOICE_CHOICES.get(voice_display_name, 'af') # Default to 'af' if not found
470
 
471
+ # --- Main Processing Logic (Async Generator) ---
472
+ import re # Regex used above in generate_speech; the module-level import runs before any call
473
+
474
  async def process_query_async(query: str, history: List[List[str]], selected_voice_display_name: str):
475
  """Asynchronously process user query: search -> generate answer -> generate speech"""
476
+ print(f"\n--- New Query Processing ---")
477
+ print(f"Query: '{query}', Voice: '{selected_voice_display_name}'")
478
+
479
+ if not query or not query.strip():
480
+ print("Empty query received.")
481
  yield (
482
+ "Please enter a query.", "", gr.Button(value="Search", interactive=True), history, None
483
  )
484
  return
485
 
486
  if history is None: history = []
487
+ # Append user query to history immediately for display
488
+ current_history = history + [[query, None]] # Placeholder for assistant response
489
 
490
  # 1. Initial state: Searching
491
  yield (
492
+ "*Searching the web...*",
493
+ "<div class='searching'><span>Searching the web...</span></div>", # Added span for CSS animation
494
  gr.Button(value="Searching...", interactive=False), # Disable button
495
  current_history,
496
  None
 
502
  sources_html = format_sources(web_results)
503
 
504
  # Update state: Analyzing results
 
505
  yield (
506
+ "*Analyzing search results and generating answer...*",
507
  sources_html,
508
  gr.Button(value="Generating...", interactive=False),
509
+ current_history, # History still shows user query, assistant response is pending
510
  None
511
  )
512
 
513
  # 3. Generate Answer (non-blocking, potentially on GPU)
514
  prompt = format_prompt(query, web_results)
515
+ final_answer = await generate_answer(prompt) # This is already async
516
 
517
+ # Update history with the final answer BEFORE generating audio
518
  current_history[-1][1] = final_answer
519
+
520
+ # Update state: Answer generated, preparing audio
521
  yield (
522
  final_answer,
523
  sources_html,
524
  gr.Button(value="Audio...", interactive=False),
525
+ current_history, # Now history includes the answer
526
  None
527
  )
528
 
 
530
  audio = None
531
  tts_message = ""
532
  if not tts_thread.is_alive() and not TTS_ENABLED:
533
+ print("[TTS Status] TTS setup failed or is disabled.")
534
+ tts_message = "\n\n*(TTS is disabled or failed to initialize)*"
535
  elif tts_thread.is_alive():
536
+ print("[TTS Status] TTS is still initializing in the background.")
537
+ tts_message = "\n\n*(TTS is still initializing, audio may be delayed or unavailable)*"
538
  elif TTS_ENABLED:
539
  voice_id = get_voice_id(selected_voice_display_name)
540
+ # Only generate audio if the answer generation was successful
541
+ if not final_answer.startswith("Error"):
542
+ audio = await generate_speech(final_answer, voice_id) # This is already async
543
+ if audio is None:
544
+ print(f"[TTS Status] Audio generation failed for voice '{voice_id}'.")
545
+ tts_message = f"\n\n*(Audio generation failed)*"
546
+ else:
547
+ print("[TTS Status] Audio generated successfully.")
548
+ else:
549
+ print("[TTS Status] Skipping audio generation due to answer error.")
550
+ tts_message = "\n\n*(Audio skipped due to answer generation error)*"
551
+
552
 
553
  # 5. Final state: Show everything
554
+ print("--- Query Processing Complete ---")
555
  yield (
556
  final_answer + tts_message,
557
  sources_html,
558
  gr.Button(value="Search", interactive=True), # Re-enable button
559
+ current_history, # Final history state
560
  audio
561
  )
562
 
563
 
564
  # --- Gradio Interface ---
565
+ # (CSS remains the same as your previous version)
566
  css = """
567
+ /* ... [Your existing refined CSS] ... */
568
  .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
569
  #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
570
  #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
571
  #header h3 { color: #a8a9ab; }
572
  .search-container { background: #ffffff; border: 1px solid #e0e0e0; border-radius: 12px; box-shadow: 0 4px 16px rgba(0,0,0,0.05); padding: 1.5rem; margin-bottom: 1.5rem; }
573
+ .search-box { padding: 0; margin-bottom: 1rem; display: flex; align-items: center; }
574
+ .search-box .gradio-textbox { border-radius: 8px 0 0 8px !important; height: 44px !important; flex-grow: 1; }
575
+ .search-box .gradio-dropdown { border-radius: 0 !important; margin-left: -1px; margin-right: -1px; height: 44px !important; width: 180px; flex-shrink: 0; }
576
+ .search-box .gradio-button { border-radius: 0 8px 8px 0 !important; height: 44px !important; flex-shrink: 0; }
577
+ .search-box input[type="text"] { background: #f7f7f8 !important; border: 1px solid #d1d5db !important; color: #1f2937 !important; transition: all 0.3s ease; height: 100% !important; padding: 0 12px !important;}
578
+ .search-box input[type="text"]:focus { border-color: #2563eb !important; box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2) !important; background: white !important; z-index: 1; }
579
  .search-box input[type="text"]::placeholder { color: #9ca3af !important; }
580
+ .search-box button { background: #2563eb !important; border: none !important; color: white !important; box-shadow: 0 1px 2px rgba(0,0,0,0.05) !important; transition: all 0.3s ease !important; height: 100% !important; }
581
  .search-box button:hover { background: #1d4ed8 !important; }
582
  .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
583
  .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
 
  .source-item:last-child { border-bottom: none; }
  /* .source-item:hover { background-color: #f9fafb; } */
  .source-number { font-weight: bold; margin-right: 12px; color: #6b7280; width: 20px; text-align: right; flex-shrink: 0;}
+ .source-content { flex: 1; min-width: 0;} /* Allow content to shrink */
+ .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;}
  .source-title:hover { color: #1d4ed8; text-decoration: underline; }
  .source-date { color: #6b7280; font-size: 0.8em; margin-left: 8px; }
  .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }

  .markdown-content th { background: #f9fafb !important; font-weight: 600; }
  .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
  .accordion > .label-wrap { padding: 10px 15px !important; } /* Style accordion header */
+ .voice-selector { margin: 0; padding: 0; height: 100%; }
+ .voice-selector div[data-testid="dropdown"] { height: 100% !important; border-radius: 0 !important;}
  .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
+ .voice-selector select:focus { border-color: #2563eb !important; box-shadow: none !important; z-index: 1; position: relative;}
  .audio-player { margin-top: 1rem; background: #f9fafb !important; border-radius: 8px !important; padding: 0.5rem !important; border: 1px solid #e5e7eb;}
  .audio-player audio { width: 100% !important; }
  .searching, .error { padding: 1rem; border-radius: 8px; text-align: center; margin: 1rem 0; border: 1px dashed; }

  .error { background: #fef2f2; color: #ef4444; border-color: #fecaca; }
  .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
  @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
+ .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; }
+ /* Dark Mode Styles */
  .dark .gradio-container { background-color: #111827 !important; }
  .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
  .dark #header h3 { color: #9ca3af; }

  .dark .source-title { color: #60a5fa; }
  .dark .source-title:hover { color: #93c5fd; }
  .dark .source-snippet { color: #d1d5db; }
+ .dark .chat-history { background: #374151; border-color: #4b5563; scrollbar-color: #4b5563 #374151; color: #d1d5db;} /* Ensure chat text is visible */
  .dark .chat-history::-webkit-scrollbar-track { background: #374151; }
  .dark .chat-history::-webkit-scrollbar-thumb { background-color: #4b5563; }
  .dark .examples-container { background: #374151; border-color: #4b5563; }

  .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
  .dark .markdown-content th { background: #374151 !important; }
  .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
+ .dark .accordion > .label-wrap { color: #d1d5db !important; } /* Accordion label color */
  .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
  .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
  .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
+ .dark .audio-player audio::-webkit-media-controls-panel { background-color: #374151; } /* Style audio player controls */
+ .dark .audio-player audio::-webkit-media-controls-play-button { color: #d1d5db; }
+ .dark .audio-player audio::-webkit-media-controls-current-time-display { color: #9ca3af; }
+ .dark .audio-player audio::-webkit-media-controls-time-remaining-display { color: #9ca3af; }
  .dark .searching { background: #1e3a8a; color: #93c5fd; border-color: #3b82f6; }
  .dark .error { background: #7f1d1d; color: #fca5a5; border-color: #ef4444; }
  .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
  """

  with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
+ # chat_history state persists across interactions for a single user session
  chat_history = gr.State([])

+ with gr.Column(): # Main container for vertical layout
+ # Header Section
  with gr.Column(elem_id="header"):
  gr.Markdown("# 🔍 AI Search Assistant")
  gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
 
+ # Search Input and Controls Section
  with gr.Column(elem_classes="search-container"):
+ with gr.Row(elem_classes="search-box", equal_height=False): # Use Row for horizontal elements
  search_input = gr.Textbox(
  label="",
  placeholder="Ask anything...",
+ scale=5, # Takes more horizontal space
+ container=False, # Important for direct styling within Row
  elem_classes="gradio-textbox"
  )
  voice_select = gr.Dropdown(
  choices=list(VOICE_CHOICES.keys()),
+ value=list(VOICE_CHOICES.keys())[0], # Default voice display name
+ label="", # Visually hidden label
+ scale=1, # Takes less space
+ min_width=180, # Fixed width for dropdown
  container=False, # Important
  elem_classes="voice-selector gradio-dropdown"
  )
  search_btn = gr.Button(
  "Search",
  variant="primary",
+ scale=0, # Minimal width needed for text
+ min_width=100,
  elem_classes="gradio-button"
  )

+ # Results Display Section (using Columns for side-by-side layout)
  with gr.Row(elem_classes="results-container", equal_height=False):
+ # Left Column: Answer and Chat History
+ with gr.Column(scale=3): # Takes 3 parts of the width
  with gr.Column(elem_classes="answer-box"):
+ answer_output = gr.Markdown(value="*Your answer will appear here...*", elem_classes="markdown-content")
+ # Audio player below the answer text
+ audio_output = gr.Audio(
+ label="Voice Response",
+ type="numpy", # Expects (rate, numpy_array) tuple
+ autoplay=False, # Don't autoplay by default
+ show_label=False, # Hide the "Voice Response" label visually
+ elem_classes="audio-player"
+ )

  with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
+ chat_history_display = gr.Chatbot(
+ label="Conversation",
+ bubble_full_width=True, # Bubbles take full width
+ height=400,
+ elem_classes="chat-history"
+ )
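+ # Note: gr.Chatbot here displays the history as [user, assistant] pairs, the same pair format chat_history accumulates.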
+
+ # Right Column: Sources
+ with gr.Column(scale=2): # Takes 2 parts of the width
  with gr.Column(elem_classes="sources-box"):
  gr.Markdown("### Sources")
  sources_output = gr.HTML(value="<div class='no-sources'>Sources will appear here after searching.</div>")

+ # Example Prompts Section
  with gr.Row(elem_classes="examples-container"):
  gr.Examples(
  examples=[
  "Latest news about renewable energy",
  "Explain the concept of Large Language Models (LLMs)",
  "What are the symptoms and prevention tips for the flu?",
+ "Compare Python and JavaScript for web development",
+ "Summarize the main points of the Paris Agreement on climate change",
  ],
+ inputs=search_input, # Clicking example populates this input
  label="Try these examples:",
  elem_classes="gradio-examples" # Add class for potential styling
  )

  # --- Event Handling ---

  async def handle_interaction(query, history, voice_display_name):
+ """Wrapper to handle the async generator and update outputs."""
+ print(f"[Interaction] Handling query: '{query}'")
+ outputs = { # Dictionary to hold the latest state of outputs
+ "answer": "...",
+ "sources": "...",
+ "button": gr.Button(value="Search", interactive=True),
+ "history": history,
+ "audio": None
+ }
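+ # The keys mirror outputs_list (answer, sources, button, history, audio); each yield below emits them in that order.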
  try:
+ # Iterate through the updates yielded by the async generator
+ async for update_tuple in process_query_async(query, history, voice_display_name):
+ # Unpack the tuple
+ ans_out, src_out, btn_state, hist_display, aud_out = update_tuple
+ # Update the outputs dictionary
+ outputs["answer"] = ans_out
+ outputs["sources"] = src_out
+ outputs["button"] = btn_state # Can be a gr.Button update dict or object
+ outputs["history"] = hist_display
+ outputs["audio"] = aud_out
+ # Yield the current state of all outputs
+ yield outputs["answer"], outputs["sources"], outputs["button"], outputs["history"], outputs["audio"]
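+ # Because this handler is an async generator, Gradio streams each yielded tuple to the UI as a partial update.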
  except Exception as e:
+ print(f"[Interaction] Error: {e}")
+ print(traceback.format_exc())
  error_message = f"An unexpected error occurred: {e}"
  # Provide a final error state update
+ final_error_history = history + [[query, f"*Error: {error_message}*"]] if query else history
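+ # Precedence note: this evaluates as (history + [[query, error]]) if query else history, so an empty query leaves history unchanged.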
  yield (
  error_message,
+ "<div class='error'>Error processing request. Please check logs or try again.</div>",
  gr.Button(value="Search", interactive=True), # Re-enable button on error
+ final_error_history,
  None
  )

+ # Connect the handle_interaction function to the button click and input submit events
+ outputs_list = [answer_output, sources_output, search_btn, chat_history_display, audio_output]
+ inputs_list = [search_input, chat_history, voice_select] # Pass the dropdown component itself

  search_btn.click(
  fn=handle_interaction,
+ inputs=inputs_list,
+ outputs=outputs_list
  )

  search_input.submit(
  fn=handle_interaction,
+ inputs=inputs_list,
+ outputs=outputs_list
  )
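+ # Clicking the Search button and pressing Enter in the textbox run the same handler with the same inputs/outputs.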

  if __name__ == "__main__":
+ print("Starting Gradio application...")
+ # Launch the app with queuing enabled for handling multiple users
+ demo.queue(max_size=20).launch(
+ debug=True, # Enable Gradio debug mode for more logs
+ share=True, # Create a public link (useful for Spaces)
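+ # share=True only matters for local runs; a deployed Space already serves a public URL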
+ # server_name="0.0.0.0" # Bind to all interfaces if running locally and need external access
+ )