sagar007 committed (verified)
Commit 3d63694 · 1 Parent(s): b8c63a2

Update app.py

Files changed (1): app.py (+569, -601)
app.py CHANGED
app.py, old version (left pane of the diff; removed lines are prefixed with -):

@@ -13,13 +13,17 @@ from functools import lru_cache
  import asyncio
  import threading
  from concurrent.futures import ThreadPoolExecutor

  # --- Configuration ---
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
  MAX_SEARCH_RESULTS = 5
  TTS_SAMPLE_RATE = 24000
- MAX_TTS_CHARS = 1000
- GPU_DURATION = 30  # for spaces.GPU decorator
  MAX_NEW_TOKENS = 256
  TEMPERATURE = 0.7
  TOP_P = 0.95
@@ -30,18 +34,25 @@ try:
      print("Loading tokenizer...")
      tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
      tokenizer.pad_token = tokenizer.eos_token
-
      print("Loading model...")
      model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,
-         device_map="auto",
-         offload_folder="offload",
          low_cpu_mem_usage=True,
-         torch_dtype=torch.float16
      )
      print("Model and tokenizer loaded successfully")
  except Exception as e:
      print(f"Error initializing model: {str(e)}")
      raise

  # --- TTS Setup ---
@@ -54,82 +65,149 @@ VOICE_CHOICES = {
  TTS_ENABLED = False
  TTS_MODEL = None
  VOICEPACKS = {}  # Cache voice packs

  # Initialize Kokoro TTS in a separate thread to avoid blocking startup
  def setup_tts():
      global TTS_ENABLED, TTS_MODEL, VOICEPACKS
-
      try:
-         # Install dependencies first
-         subprocess.run(['git', 'lfs', 'install'], check=True)
-         if not os.path.exists('Kokoro-82M'):
-             subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
-
-         # Install espeak
          try:
-             subprocess.run(['apt-get', 'update'], check=True)
-             subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
-         except subprocess.CalledProcessError:
              try:
-                 subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
-             except subprocess.CalledProcessError:
-                 print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
-
          # Set up Kokoro TTS
-         if os.path.exists('Kokoro-82M'):
              import sys
-             sys.path.append('Kokoro-82M')
-             from models import build_model
-             from kokoro import generate
-
-             # Make these functions accessible globally
-             globals()['build_model'] = build_model
-             globals()['generate_tts'] = generate
-
-             device = 'cuda' if torch.cuda.is_available() else 'cpu'
-             TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
-
-             # Preload default voice
-             default_voice = 'af'
-             VOICEPACKS[default_voice] = torch.load(f'Kokoro-82M/voices/{default_voice}.pt',
-                                                    map_location=device,
-                                                    weights_only=True)
-
-             # Preload other common voices to reduce latency
-             for voice_name in ['af_bella', 'af_sarah', 'af_nicole']:
-                 try:
-                     voice_path = f'Kokoro-82M/voices/{voice_name}.pt'
-                     if os.path.exists(voice_path):
-                         VOICEPACKS[voice_name] = torch.load(voice_path,
-                                                             map_location=device,
-                                                             weights_only=True)
-                 except Exception as e:
-                     print(f"Warning: Could not preload voice {voice_name}: {str(e)}")
-
-             TTS_ENABLED = True
-             print("TTS setup completed successfully")
          else:
-             print("Warning: Kokoro-82M directory not found. TTS disabled.")
      except Exception as e:
-         print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
          TTS_ENABLED = False

  # Start TTS setup in a separate thread
- threading.Thread(target=setup_tts, daemon=True).start()

  # --- Search and Generation Functions ---
  @lru_cache(maxsize=128)
  def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
      """Get web search results using DuckDuckGo with caching for improved performance"""
      try:
          with DDGS() as ddgs:
-             results = list(ddgs.text(query, max_results=max_results))
-             return [{
-                 "title": result.get("title", ""),
-                 "snippet": result.get("body", ""),
-                 "url": result.get("href", ""),
-                 "date": result.get("published", "")
-             } for result in results]
      except Exception as e:
          print(f"Error in web search: {e}")
          return []
@@ -137,32 +215,38 @@ def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[D
  def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
      """Format the prompt with web context"""
      current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
-     return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
  Current Time: {current_time}
- Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
- Query: {query}
  Web Context:
- {context_lines}
- Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
  Answer:"""

  def format_sources(web_results: List[Dict[str, str]]) -> str:
      """Format sources with more details"""
      if not web_results:
-         return "<div class='no-sources'>No sources available</div>"

      sources_html = "<div class='sources-container'>"
      for i, res in enumerate(web_results, 1):
-         title = res["title"] or "Source"
-         date = f"<span class='source-date'>{res['date']}</span>" if res.get('date') else ""
-         snippet = res.get("snippet", "")[:150] + "..." if res.get("snippet") else ""
          sources_html += f"""
          <div class='source-item'>
              <div class='source-number'>[{i}]</div>
              <div class='source-content'>
-                 <a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
-                 {date}
                  <div class='source-snippet'>{snippet}</div>
              </div>
          </div>
@@ -170,566 +254,450 @@ def format_sources(web_results: List[Dict[str, str]]) -> str:
      sources_html += "</div>"
      return sources_html

- @spaces.GPU(duration=GPU_DURATION)
- def generate_answer(prompt: str) -> str:
-     """Generate answer using the DeepSeek model with optimized settings"""
-     inputs = tokenizer(
-         prompt,
-         return_tensors="pt",
-         padding=True,
-         truncation=True,
-         max_length=512,
-         return_attention_mask=True
-     ).to(model.device)
-
-     with torch.no_grad():  # Disable gradient calculation for inference
-         outputs = model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             max_new_tokens=MAX_NEW_TOKENS,
-             temperature=TEMPERATURE,
-             top_p=TOP_P,
-             pad_token_id=tokenizer.eos_token_id,
-             do_sample=True,
-             early_stopping=True
-         )
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)

- @spaces.GPU(duration=GPU_DURATION)
- def generate_speech(text: str, voice_name: str = 'af') -> Tuple[int, np.ndarray] | None:
-     """Generate speech from text using Kokoro TTS model with improved error handling and caching."""
-     global VOICEPACKS, TTS_MODEL, TTS_ENABLED
-
      if not TTS_ENABLED or TTS_MODEL is None:
          return None

      try:
-         from kokoro import generate as generate_tts
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-         # Load voicepack if needed
-         if voice_name not in VOICEPACKS:
-             voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
-
-             if not os.path.exists(voice_file):
-                 print(f"Voicepack {voice_name}.pt not found. Falling back to default 'af'.")
-                 voice_name = 'af'
-
-                 # Check if default is already loaded
-                 if voice_name not in VOICEPACKS:
-                     voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
-                     if os.path.exists(voice_file):
-                         VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
-                     else:
-                         print("Default voicepack 'af.pt' not found. Cannot generate audio.")
-                         return None
              else:
-                 VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
-
-         # Clean the text
-         clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
-         clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
-
-         # Split long text into chunks
-         max_chars = MAX_TTS_CHARS
-         chunks = []
-         if len(clean_text) > max_chars:
-             sentences = clean_text.split('.')
-             current_chunk = ""
-             for sentence in sentences:
-                 if len(current_chunk) + len(sentence) + 1 < max_chars:
-                     current_chunk += sentence + "."
-                 else:
-                     chunks.append(current_chunk.strip())
-                     current_chunk = sentence + "."
-             if current_chunk:
-                 chunks.append(current_chunk.strip())
          else:
-             chunks = [clean_text]
-
-         # Generate audio for each chunk
-         audio_chunks = []
-         for chunk in chunks:
-             if chunk.strip():
-                 chunk_audio, _ = generate_tts(TTS_MODEL, chunk, VOICEPACKS[voice_name], lang='a')
-                 if isinstance(chunk_audio, torch.Tensor):
-                     chunk_audio = chunk_audio.cpu().numpy()
-                 audio_chunks.append(chunk_audio)
-
-         # Concatenate chunks
-         if audio_chunks:
-             final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
-             return (TTS_SAMPLE_RATE, final_audio)
-
-         return None

      except Exception as e:
          print(f"Error generating speech: {str(e)}")
          return None

- # --- Asynchronous Processing ---
- async def async_web_search(query: str) -> List[Dict[str, str]]:
-     """Run web search in a non-blocking way"""
-     loop = asyncio.get_event_loop()
-     return await loop.run_in_executor(None, get_web_results, query)
-
- async def async_answer_generation(prompt: str) -> str:
-     """Run answer generation in a non-blocking way"""
-     loop = asyncio.get_event_loop()
-     return await loop.run_in_executor(None, generate_answer, prompt)
-
- async def async_speech_generation(text: str, voice_name: str) -> Tuple[int, np.ndarray] | None:
-     """Run speech generation in a non-blocking way"""
-     loop = asyncio.get_event_loop()
-     return await loop.run_in_executor(None, generate_speech, text, voice_name)
-
- def process_query(query: str, history: List[List[str]], selected_voice: str = 'af'):
-     """Process user query with streaming effect and non-blocking operations"""
-     try:
-         if history is None:
-             history = []
-
-         # Start the search task
-         current_history = history + [[query, "*Searching...*"]]
-
-         # Yield initial searching state
-         yield (
-             "*Searching & Thinking...*",  # answer_output (Markdown)
-             "<div class='searching'>Searching for results...</div>",  # sources_output (HTML)
-             "Searching...",  # search_btn (Button)
-             current_history,  # chat_history_display (Chatbot)
-             None  # audio_output (Audio)
-         )

-         # Get web results
-         web_results = get_web_results(query)
-         sources_html = format_sources(web_results)
-
-         # Update with the search results obtained
          yield (
-             "*Analyzing search results...*",  # answer_output
-             sources_html,  # sources_output
-             "Generating answer...",  # search_btn
-             current_history,  # chat_history_display
-             None  # audio_output
          )

-         # Generate answer
-         prompt = format_prompt(query, web_results)
-         answer = generate_answer(prompt)
-         final_answer = answer.split("Answer:")[-1].strip()
-
-         # Update history before TTS
-         updated_history = history + [[query, final_answer]]
-
-         # Update with the answer before generating speech
-         yield (
-             final_answer,  # answer_output
-             sources_html,  # sources_output
-             "Generating audio...",  # search_btn
-             updated_history,  # chat_history_display
-             None  # audio_output
-         )

-         # Generate speech (but don't block if TTS is still initializing)
-         audio = None
-         if TTS_ENABLED and TTS_MODEL is not None:
-             try:
-                 audio = generate_speech(final_answer, selected_voice)
-                 if audio is None:
-                     final_answer += "\n\n*Audio generation failed. The voicepack may be missing or incompatible.*"
-             except Exception as e:
-                 final_answer += f"\n\n*Error generating audio: {str(e)}*"
-         else:
-             final_answer += "\n\n*TTS is still initializing or is disabled. Try again in a moment.*"

-         # Yield final result
-         yield (
-             final_answer,  # answer_output
-             sources_html,  # sources_output
-             "Search",  # search_btn
-             updated_history,  # chat_history_display
-             audio  # audio_output
-         )

-     except Exception as e:
-         error_message = str(e)
-         if "GPU quota" in error_message:
-             error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
-         yield (
-             f"Error: {error_message}",  # answer_output
-             "<div class='error'>An error occurred during search</div>",  # sources_output
-             "Search",  # search_btn
-             history + [[query, f"*Error: {error_message}*"]],  # chat_history_display
-             None  # audio_output
-         )

- # --- Improved UI ---
  css = """
- .gradio-container {
-     max-width: 1200px !important;
-     background-color: #f7f7f8 !important;
- }
- #header {
-     text-align: center;
-     margin-bottom: 2rem;
-     padding: 2rem 0;
-     background: linear-gradient(135deg, #1a1b1e, #2d2e32);
-     border-radius: 12px;
-     color: white;
-     box-shadow: 0 8px 32px rgba(0,0,0,0.2);
- }
- #header h1 {
-     color: white;
-     font-size: 2.5rem;
-     margin-bottom: 0.5rem;
-     text-shadow: 0 2px 4px rgba(0,0,0,0.3);
- }
- #header h3 {
-     color: #a8a9ab;
- }
- .search-container {
-     background: linear-gradient(135deg, #1a1b1e, #2d2e32);
-     border-radius: 12px;
-     box-shadow: 0 4px 16px rgba(0,0,0,0.15);
-     padding: 1.5rem;
-     margin-bottom: 1.5rem;
- }
- .search-box {
-     padding: 1rem;
-     background: #2c2d30;
-     border-radius: 10px;
-     margin-bottom: 1rem;
-     box-shadow: inset 0 2px 4px rgba(0,0,0,0.1);
- }
- .search-box input[type="text"] {
-     background: #3a3b3e !important;
-     border: 1px solid #4a4b4e !important;
-     color: white !important;
-     border-radius: 8px !important;
-     transition: all 0.3s ease;
- }
- .search-box input[type="text"]:focus {
-     border-color: #60a5fa !important;
-     box-shadow: 0 0 0 2px rgba(96, 165, 250, 0.3) !important;
- }
- .search-box input[type="text"]::placeholder {
-     color: #a8a9ab !important;
- }
- .search-box button {
-     background: #2563eb !important;
-     border: none !important;
-     box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
-     transition: all 0.3s ease !important;
- }
- .search-box button:hover {
-     background: #1d4ed8 !important;
-     transform: translateY(-1px) !important;
- }
- .search-box button:active {
-     transform: translateY(1px) !important;
- }
- .results-container {
-     background: #2c2d30;
-     border-radius: 10px;
-     padding: 1.5rem;
-     margin-top: 1.5rem;
-     box-shadow: 0 4px 12px rgba(0,0,0,0.1);
- }
- .answer-box {
-     background: #3a3b3e;
-     border-radius: 10px;
-     padding: 1.5rem;
-     color: white;
-     margin-bottom: 1.5rem;
-     box-shadow: 0 2px 8px rgba(0,0,0,0.15);
-     transition: all 0.3s ease;
- }
- .answer-box:hover {
-     box-shadow: 0 4px 16px rgba(0,0,0,0.2);
- }
- .answer-box p {
-     color: #e5e7eb;
-     line-height: 1.7;
- }
- .answer-box code {
-     background: #2c2d30;
-     border-radius: 4px;
-     padding: 2px 4px;
- }
- .sources-container {
-     margin-top: 1rem;
-     background: #2c2d30;
-     border-radius: 8px;
-     padding: 1rem;
- }
- .source-item {
-     display: flex;
-     padding: 12px;
-     margin: 12px 0;
-     background: #3a3b3e;
-     border-radius: 8px;
-     transition: all 0.2s;
-     box-shadow: 0 2px 4px rgba(0,0,0,0.1);
- }
- .source-item:hover {
-     background: #4a4b4e;
-     transform: translateY(-2px);
-     box-shadow: 0 4px 8px rgba(0,0,0,0.15);
- }
- .source-number {
-     font-weight: bold;
-     margin-right: 12px;
-     color: #60a5fa;
- }
- .source-content {
-     flex: 1;
- }
- .source-title {
-     color: #60a5fa;
-     font-weight: 500;
-     text-decoration: none;
-     display: block;
-     margin-bottom: 6px;
-     transition: all 0.2s;
- }
- .source-title:hover {
-     color: #93c5fd;
-     text-decoration: underline;
- }
- .source-date {
-     color: #a8a9ab;
-     font-size: 0.9em;
-     margin-left: 8px;
- }
- .source-snippet {
-     color: #e5e7eb;
-     font-size: 0.9em;
-     line-height: 1.5;
- }
- .chat-history {
-     max-height: 400px;
-     overflow-y: auto;
-     padding: 1rem;
-     background: #2c2d30;
-     border-radius: 8px;
-     margin-top: 1rem;
-     scrollbar-width: thin;
-     scrollbar-color: #4a4b4e #2c2d30;
- }
- .chat-history::-webkit-scrollbar {
-     width: 8px;
- }
- .chat-history::-webkit-scrollbar-track {
-     background: #2c2d30;
- }
- .chat-history::-webkit-scrollbar-thumb {
-     background-color: #4a4b4e;
-     border-radius: 20px;
- }
- .examples-container {
-     background: #2c2d30;
-     border-radius: 8px;
-     padding: 1rem;
-     margin-top: 1rem;
- }
- .examples-container button {
-     background: #3a3b3e !important;
-     border: 1px solid #4a4b4e !important;
-     color: #e5e7eb !important;
-     transition: all 0.2s;
-     margin: 4px !important;
- }
- .examples-container button:hover {
-     background: #4a4b4e !important;
-     transform: translateY(-1px);
- }
- .markdown-content {
-     color: #e5e7eb !important;
- }
- .markdown-content h1, .markdown-content h2, .markdown-content h3 {
-     color: white !important;
-     margin-top: 1.2em !important;
-     margin-bottom: 0.8em !important;
- }
- .markdown-content h1 {
-     font-size: 1.7em !important;
- }
- .markdown-content h2 {
-     font-size: 1.5em !important;
- }
- .markdown-content h3 {
-     font-size: 1.3em !important;
- }
- .markdown-content a {
-     color: #60a5fa !important;
-     text-decoration: none !important;
-     transition: all 0.2s;
- }
- .markdown-content a:hover {
-     color: #93c5fd !important;
-     text-decoration: underline !important;
- }
- .markdown-content code {
-     background: #2c2d30 !important;
-     padding: 2px 6px !important;
-     border-radius: 4px !important;
-     font-family: monospace !important;
- }
- .markdown-content pre {
-     background: #2c2d30 !important;
-     padding: 12px !important;
-     border-radius: 8px !important;
-     overflow-x: auto !important;
- }
- .markdown-content blockquote {
-     border-left: 4px solid #60a5fa !important;
-     padding-left: 1em !important;
-     margin-left: 0 !important;
-     color: #a8a9ab !important;
- }
- .markdown-content table {
-     border-collapse: collapse !important;
-     width: 100% !important;
- }
- .markdown-content th, .markdown-content td {
-     padding: 8px 12px !important;
-     border: 1px solid #4a4b4e !important;
- }
- .markdown-content th {
-     background: #2c2d30 !important;
- }
- .accordion {
-     background: #2c2d30 !important;
-     border-radius: 8px !important;
-     margin-top: 1rem !important;
-     box-shadow: 0 2px 8px rgba(0,0,0,0.1) !important;
- }
- .voice-selector {
-     margin-top: 1rem;
-     background: #2c2d30;
-     border-radius: 8px;
-     padding: 0.5rem;
- }
- .voice-selector select {
-     background: #3a3b3e !important;
-     color: white !important;
-     border: 1px solid #4a4b4e !important;
-     border-radius: 4px !important;
-     padding: 8px !important;
-     transition: all 0.2s;
- }
- .voice-selector select:focus {
-     border-color: #60a5fa !important;
- }
- .audio-player {
-     margin-top: 1rem;
-     background: #2c2d30 !important;
-     border-radius: 8px !important;
-     padding: 0.5rem !important;
- }
- .audio-player audio {
-     width: 100% !important;
- }
- .searching, .error {
-     padding: 1rem;
-     border-radius: 8px;
-     text-align: center;
-     margin: 1rem 0;
- }
- .searching {
-     background: rgba(96, 165, 250, 0.1);
-     color: #60a5fa;
- }
- .error {
-     background: rgba(239, 68, 68, 0.1);
-     color: #ef4444;
- }
- .no-sources {
-     padding: 1rem;
-     text-align: center;
-     color: #a8a9ab;
-     background: #2c2d30;
-     border-radius: 8px;
- }
- @keyframes pulse {
-     0% { opacity: 0.6; }
-     50% { opacity: 1; }
-     100% { opacity: 0.6; }
- }
- .searching {
-     animation: pulse 1.5s infinite;
- }
  """

- # --- Gradio Interface ---
- with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
      chat_history = gr.State([])
-
-     with gr.Column(elem_id="header"):
-         gr.Markdown("# 🔍 AI Search Assistant")
-         gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
-
-     with gr.Column(elem_classes="search-container"):
-         with gr.Row(elem_classes="search-box"):
-             search_input = gr.Textbox(
-                 label="",
-                 placeholder="Ask anything...",
-                 scale=5,
-                 container=False
-             )
-             voice_select = gr.Dropdown(
-                 choices=list(VOICE_CHOICES.keys()),
-                 value=list(VOICE_CHOICES.keys())[0],
-                 label="Voice",
-                 elem_classes="voice-selector",
-                 scale=1
-             )
-             search_btn = gr.Button("Search", variant="primary", scale=1)
-
-     with gr.Row(elem_classes="results-container"):
-         with gr.Column(scale=2):
-             with gr.Column(elem_classes="answer-box"):
-                 answer_output = gr.Markdown(elem_classes="markdown-content")
-             with gr.Row():
-                 audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
-             with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
-                 chat_history_display = gr.Chatbot(elem_classes="chat-history")
-         with gr.Column(scale=1):
-             with gr.Column(elem_classes="sources-box"):
-                 gr.Markdown("### Sources")
-                 sources_output = gr.HTML()
-
-     with gr.Row(elem_classes="examples-container"):
-         gr.Examples(
-             examples=[
-                 "Latest news about artificial intelligence advances",
-                 "How does blockchain technology work?",
-                 "What are the best practices for sustainable living?",
-                 "Compare electric vehicles and traditional cars"
-             ],
-             inputs=search_input,
-             label="Try these examples"
          )

-     # Handle voice selection mapping
-     def get_voice_id(voice_name):
-         return VOICE_CHOICES.get(voice_name, 'af')

-     # Handle interactions
      search_btn.click(
-         fn=process_query,
-         inputs=[search_input, chat_history, lambda x: get_voice_id(x), voice_select],
          outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
      )
-
-     # Also trigger search on Enter key
      search_input.submit(
-         fn=process_query,
-         inputs=[search_input, chat_history, lambda x: get_voice_id(x), voice_select],
          outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
      )

  if __name__ == "__main__":
-     # Start the app with optimized settings
-     demo.queue(concurrency_count=5, max_size=20).launch(share=True)
 
app.py, new version (right pane of the diff; added lines are prefixed with +):

@@ -13,13 +13,17 @@ from functools import lru_cache
  import asyncio
  import threading
  from concurrent.futures import ThreadPoolExecutor
+ import warnings
+
+ # Suppress specific warnings if needed (optional)
+ warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")

  # --- Configuration ---
  MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
  MAX_SEARCH_RESULTS = 5
  TTS_SAMPLE_RATE = 24000
+ MAX_TTS_CHARS = 1000  # Reduced for faster testing, adjust as needed
+ GPU_DURATION = 60  # Increased duration for longer tasks like TTS
  MAX_NEW_TOKENS = 256
  TEMPERATURE = 0.7
  TOP_P = 0.95

@@ -30,18 +34,25 @@ try:
      print("Loading tokenizer...")
      tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
      tokenizer.pad_token = tokenizer.eos_token
+
      print("Loading model...")
+     # Determine device map based on CUDA availability
+     device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
+     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # Use float32 on CPU
+
      model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,
+         device_map=device_map,
+         # offload_folder="offload",  # Only use offload if really needed and configured
          low_cpu_mem_usage=True,
+         torch_dtype=torch_dtype
      )
+     print(f"Model loaded on device map: {model.hf_device_map}")
      print("Model and tokenizer loaded successfully")
  except Exception as e:
      print(f"Error initializing model: {str(e)}")
+     # If running in Spaces, maybe try loading to CPU as fallback?
+     # For now, just raise the error.
      raise

  # --- TTS Setup ---

@@ -54,82 +65,149 @@ VOICE_CHOICES = {
  TTS_ENABLED = False
  TTS_MODEL = None
  VOICEPACKS = {}  # Cache voice packs
+ KOKORO_PATH = 'Kokoro-82M'

  # Initialize Kokoro TTS in a separate thread to avoid blocking startup
  def setup_tts():
      global TTS_ENABLED, TTS_MODEL, VOICEPACKS
+
      try:
+         # Check if Kokoro already exists
+         if not os.path.exists(KOKORO_PATH):
+             print("Cloning Kokoro-82M repository...")
+             # Install git-lfs if not present (might need sudo/apt)
+             try:
+                 subprocess.run(['git', 'lfs', 'install'], check=True, capture_output=True)
+             except (FileNotFoundError, subprocess.CalledProcessError) as lfs_err:
+                 print(f"Warning: git-lfs might not be installed or failed: {lfs_err}. Cloning might be slow or incomplete.")
+
+             clone_cmd = ['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M']
+             result = subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
+             print("Kokoro cloned successfully.")
+             print(result.stdout)
+             # Optionally pull LFS files if needed (sometimes clone doesn't get them all)
+             # subprocess.run(['git', 'lfs', 'pull'], cwd=KOKORO_PATH, check=True)
+
+         else:
+             print("Kokoro-82M directory already exists.")
+
+         # Install espeak (essential for phonemization)
+         print("Attempting to install espeak-ng or espeak...")
          try:
+             # Try installing espeak-ng first (often preferred)
+             subprocess.run(['sudo', 'apt-get', 'update'], check=True, capture_output=True)
+             subprocess.run(['sudo', 'apt-get', 'install', '-y', 'espeak-ng'], check=True, capture_output=True)
+             print("espeak-ng installed successfully.")
+         except (FileNotFoundError, subprocess.CalledProcessError):
+             print("espeak-ng installation failed, trying espeak...")
              try:
+                 # Fallback to espeak
+                 subprocess.run(['sudo', 'apt-get', 'install', '-y', 'espeak'], check=True, capture_output=True)
+                 print("espeak installed successfully.")
+             except (FileNotFoundError, subprocess.CalledProcessError) as espeak_err:
+                 print(f"Warning: Could not install espeak-ng or espeak: {espeak_err}. TTS functionality will be disabled.")
+                 return  # Cannot proceed without espeak
+
          # Set up Kokoro TTS
+         if os.path.exists(KOKORO_PATH):
              import sys
+             if KOKORO_PATH not in sys.path:
+                 sys.path.append(KOKORO_PATH)
+             try:
+                 from models import build_model
+                 from kokoro import generate as generate_tts_internal  # Avoid name clash
+
+                 # Make these functions accessible globally if needed, but better to keep scoped
+                 globals()['build_model'] = build_model
+                 globals()['generate_tts_internal'] = generate_tts_internal
+
+                 device = 'cuda' if torch.cuda.is_available() else 'cpu'
+                 print(f"Loading TTS model onto device: {device}")
+                 # Ensure model path is correct
+                 model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
+                 if not os.path.exists(model_file):
+                     print(f"Error: TTS model file not found at {model_file}")
+                     # Attempt to pull LFS files again
+                     try:
+                         print("Attempting git lfs pull...")
+                         subprocess.run(['git', 'lfs', 'pull'], cwd=KOKORO_PATH, check=True, capture_output=True)
+                         if not os.path.exists(model_file):
+                             print(f"Error: TTS model file STILL not found at {model_file} after lfs pull.")
+                             return
+                     except Exception as lfs_pull_err:
+                         print(f"Error during git lfs pull: {lfs_pull_err}")
+                         return
+
+                 TTS_MODEL = build_model(model_file, device)
+
+                 # Preload default voice
+                 default_voice_id = 'af'
+                 voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{default_voice_id}.pt')
+                 if os.path.exists(voice_file_path):
+                     print(f"Loading default voice: {default_voice_id}")
+                     VOICEPACKS[default_voice_id] = torch.load(voice_file_path,
+                                                               map_location=device)  # Removed weights_only=True
+                 else:
+                     print(f"Warning: Default voice file {voice_file_path} not found.")
+
+                 # Preload other common voices to reduce latency
+                 for voice_name, voice_id in VOICE_CHOICES.items():
+                     if voice_id != default_voice_id:  # Avoid reloading default
+                         voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
+                         if os.path.exists(voice_file_path):
+                             try:
+                                 print(f"Preloading voice: {voice_id}")
+                                 VOICEPACKS[voice_id] = torch.load(voice_file_path,
+                                                                   map_location=device)  # Removed weights_only=True
+                             except Exception as e:
+                                 print(f"Warning: Could not preload voice {voice_id}: {str(e)}")
+                         else:
+                             print(f"Info: Voice file {voice_file_path} for '{voice_name}' not found, will skip preloading.")
+
+                 TTS_ENABLED = True
+                 print("TTS setup completed successfully")
+             except ImportError as ie:
+                 print(f"Error importing Kokoro modules: {ie}. Check if Kokoro-82M is correctly cloned and in sys.path.")
+             except Exception as model_load_err:
+                 print(f"Error loading TTS model or voices: {model_load_err}")
+
          else:
+             print(f"Warning: {KOKORO_PATH} directory not found after clone attempt. TTS disabled.")
+     except subprocess.CalledProcessError as spe:
+         print(f"Warning: A subprocess command failed during TTS setup: {spe}")
+         print(f"Command: {' '.join(spe.cmd)}")
+         print(f"Stderr: {spe.stderr}")
+         print("TTS may be disabled.")
      except Exception as e:
+         print(f"Warning: An unexpected error occurred during TTS setup: {str(e)}")
          TTS_ENABLED = False

  # Start TTS setup in a separate thread
+ print("Starting TTS setup in background thread...")
+ tts_thread = threading.Thread(target=setup_tts, daemon=True)
+ tts_thread.start()

  # --- Search and Generation Functions ---
  @lru_cache(maxsize=128)
  def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
      """Get web search results using DuckDuckGo with caching for improved performance"""
+     print(f"Performing web search for: '{query}'")
      try:
          with DDGS() as ddgs:
+             # Using safe='off' potentially gives more results but use cautiously
+             results = list(ddgs.text(query, max_results=max_results, safesearch='moderate'))
+             print(f"Found {len(results)} results.")
+             formatted_results = []
+             for result in results:
+                 formatted_results.append({
+                     "title": result.get("title", "No Title"),
+                     "snippet": result.get("body", "No Snippet Available"),
+                     "url": result.get("href", "#"),
+                     # Attempt to extract date - DDGS doesn't reliably provide it
+                     # "date": result.get("published", "")  # Placeholder
+                 })
+             return formatted_results
      except Exception as e:
          print(f"Error in web search: {e}")
          return []

@@ -137,32 +215,38 @@ def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[D
  def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
      """Format the prompt with web context"""
      current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for i, res in enumerate(context)])  # No need for index here
+     prompt = f"""You are a helpful AI assistant. Your task is to answer the user's query based *only* on the provided web search context.
+ Do not add information not present in the context.
+ Cite the sources used in your answer using bracket notation, e.g., [Source Title]. Use the titles from the context.
+ If the context does not contain relevant information to answer the query, state that clearly.
  Current Time: {current_time}
+
  Web Context:
+ {context_lines if context else "No web context available."}
+
+ User Query: {query}
+
  Answer:"""
+     # print(f"Formatted Prompt:\n{prompt}")  # Debugging
+     return prompt

  def format_sources(web_results: List[Dict[str, str]]) -> str:
      """Format sources with more details"""
      if not web_results:
+         return "<div class='no-sources'>No sources found for the query.</div>"

      sources_html = "<div class='sources-container'>"
      for i, res in enumerate(web_results, 1):
+         title = res.get("title", "Source")
+         url = res.get("url", "#")
+         # date = f"<span class='source-date'>{res['date']}</span>" if res.get('date') else ""  # DDG date is unreliable
+         snippet = res.get("snippet", "")[:150] + ("..." if len(res.get("snippet", "")) > 150 else "")
          sources_html += f"""
          <div class='source-item'>
              <div class='source-number'>[{i}]</div>
              <div class='source-content'>
+                 <a href="{url}" target="_blank" class='source-title' title="{url}">{title}</a>
                  <div class='source-snippet'>{snippet}</div>
              </div>
          </div>

@@ -170,566 +254,450 @@ def format_sources(web_results: List[Dict[str, str]]) -> str:
      sources_html += "</div>"
      return sources_html

+ # Use a ThreadPoolExecutor for potentially blocking I/O or CPU-bound tasks
+ # Keep GPU tasks separate if possible, or ensure thread safety if sharing GPU resources
+ executor = ThreadPoolExecutor(max_workers=4)
+
+ @spaces.GPU(duration=GPU_DURATION, cancellable=True)
+ async def generate_answer(prompt: str) -> str:
+     """Generate answer using the DeepSeek model with optimized settings (Async Wrapper)"""
+     print("Generating answer...")
+     try:
+         inputs = tokenizer(
+             prompt,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=1024,  # Increased context length
+             return_attention_mask=True
+         ).to(model.device)
+
+         # Ensure generation runs on the correct device
+         with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available() and torch_dtype == torch.float16):
+             outputs = await asyncio.to_thread(  # Use asyncio.to_thread for potentially blocking calls
+                 model.generate,
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_new_tokens=MAX_NEW_TOKENS,
+                 temperature=TEMPERATURE,
+                 top_p=TOP_P,
+                 pad_token_id=tokenizer.eos_token_id,
+                 do_sample=True,
+                 early_stopping=True,
+                 num_return_sequences=1
+             )
+
+         # Decode output
+         full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # Extract only the generated part after "Answer:"
+         answer_part = full_output.split("Answer:")[-1].strip()
+         print(f"Generated Answer Raw Length: {len(outputs[0])}, Decoded Answer Part Length: {len(answer_part)}")
+         if not answer_part:  # Handle cases where split might fail or answer is empty
+             print("Warning: Could not extract answer after 'Answer:'. Returning full output.")
+             return full_output  # Fallback
+         return answer_part
+     except Exception as e:
+         print(f"Error during answer generation: {e}")
+         # You might want to return a specific error message here
+         return f"Error generating answer: {str(e)}"
+
+ # Ensure this function runs potentially long tasks in a thread using the executor
+ # @spaces.GPU(duration=GPU_DURATION, cancellable=True)  # Keep GPU decorator if TTS uses GPU heavily
+ async def generate_speech(text: str, voice_id: str = 'af') -> Tuple[int, np.ndarray] | None:
+     """Generate speech from text using Kokoro TTS model (Async Wrapper)."""
+     global TTS_MODEL, TTS_ENABLED, VOICEPACKS
+     print(f"Attempting to generate speech for text (length {len(text)}) with voice '{voice_id}'")
+
      if not TTS_ENABLED or TTS_MODEL is None:
+         print("TTS is not enabled or model not loaded.")
+         return None
+     if 'generate_tts_internal' not in globals():
+         print("TTS generation function 'generate_tts_internal' not found.")
          return None

      try:
+         device = TTS_MODEL.device  # Get device from the loaded TTS model
+
+         # Load voicepack if needed (handle potential errors)
+         if voice_id not in VOICEPACKS:
+             voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
+             if os.path.exists(voice_file_path):
+                 print(f"Loading voice '{voice_id}' on demand...")
+                 try:
+                     VOICEPACKS[voice_id] = await asyncio.to_thread(
+                         torch.load, voice_file_path, map_location=device  # Removed weights_only=True
+                     )
+                 except Exception as load_err:
+                     print(f"Error loading voicepack {voice_id}: {load_err}. Falling back to default 'af'.")
+                     voice_id = 'af'  # Fallback to default
+                     # Ensure default is loaded if fallback occurs
+                     if 'af' not in VOICEPACKS:
+                         default_voice_file = os.path.join(KOKORO_PATH, 'voices', 'af.pt')
+                         if os.path.exists(default_voice_file):
+                             VOICEPACKS['af'] = await asyncio.to_thread(
+                                 torch.load, default_voice_file, map_location=device
+                             )
+                         else:
+                             print("Default voice 'af' also not found. Cannot generate audio.")
+                             return None
              else:
+                 print(f"Voicepack {voice_id}.pt not found. Falling back to default 'af'.")
+                 voice_id = 'af'  # Fallback to default
+                 if 'af' not in VOICEPACKS:  # Check again if default is needed now
+                     default_voice_file = os.path.join(KOKORO_PATH, 'voices', 'af.pt')
+                     if os.path.exists(default_voice_file):
+                         VOICEPACKS['af'] = await asyncio.to_thread(
+                             torch.load, default_voice_file, map_location=device
+                         )
+                     else:
+                         print("Default voice 'af' also not found. Cannot generate audio.")
+                         return None
+
+         if voice_id not in VOICEPACKS:
+             print(f"Error: Voice '{voice_id}' could not be loaded.")
+             return None
+
+         # Clean the text (simple cleaning)
+         clean_text = ' '.join(text.split())  # Remove extra whitespace
+         clean_text = clean_text.replace('*', '').replace('[', '').replace(']', '')  # Remove markdown chars
+
+         # Ensure text isn't empty
+         if not clean_text.strip():
+             print("Warning: Empty text provided for TTS.")
+             return None
+
+         # Limit text length
+         if len(clean_text) > MAX_TTS_CHARS:
+             print(f"Warning: Text too long ({len(clean_text)} chars), truncating to {MAX_TTS_CHARS}.")
+             # Simple truncation, could be smarter (split by sentence)
+             clean_text = clean_text[:MAX_TTS_CHARS]
+             last_space = clean_text.rfind(' ')
+             if last_space != -1:
+                 clean_text = clean_text[:last_space] + "..."  # Truncate at last space
+
+         # Run the potentially blocking TTS generation in a thread
+         print(f"Generating audio for: '{clean_text[:100]}...'")
+         gen_func = globals()['generate_tts_internal']
+         loop = asyncio.get_event_loop()
+         audio_data, _ = await loop.run_in_executor(
+             executor,  # Use the thread pool executor
+             gen_func,
+             TTS_MODEL,
+             clean_text,
+             VOICEPACKS[voice_id],
+             'a'  # Language code (assuming 'a' is appropriate)
+         )
+
+         if isinstance(audio_data, torch.Tensor):
+             # Move tensor to CPU before converting to numpy if it's not already
+             audio_np = audio_data.cpu().numpy()
+         elif isinstance(audio_data, np.ndarray):
+             audio_np = audio_data
          else:
+             print("Warning: Unexpected audio data type from TTS.")
+             return None
+
+         print(f"Audio generated successfully, shape: {audio_np.shape}")
+         return (TTS_SAMPLE_RATE, audio_np)

      except Exception as e:
+         import traceback
          print(f"Error generating speech: {str(e)}")
+         print(traceback.format_exc())  # Print full traceback for debugging
          return None

+ # Helper to get voice ID from display name
+ def get_voice_id(voice_display_name: str) -> str:
+     """Maps the user-friendly voice name to the internal voice ID."""
+     return VOICE_CHOICES.get(voice_display_name, 'af')  # Default to 'af' if not found

+ # --- Main Processing Logic (Async) ---
+ async def process_query_async(query: str, history: List[List[str]], selected_voice_display_name: str):
+     """Asynchronously process user query: search -> generate answer -> generate speech"""
+     if not query:
          yield (
+             "Please enter a query.", "", "Search", history, None
          )
+         return
+
+     if history is None: history = []
+     current_history = history + [[query, "*Searching...*"]]
+
+     # 1. Initial state: Searching
+     yield (
+         "*Searching & Thinking...*",
+         "<div class='searching'>Searching the web...</div>",
+         gr.Button(value="Searching...", interactive=False),  # Disable button
+         current_history,
+         None
+     )

+     # 2. Perform Web Search (non-blocking)
+     loop = asyncio.get_event_loop()
+     web_results = await loop.run_in_executor(executor, get_web_results, query)
+     sources_html = format_sources(web_results)
+
+     # Update state: Analyzing results
+     current_history[-1][1] = "*Analyzing search results...*"
+     yield (
+         "*Analyzing search results...*",
+         sources_html,
+         gr.Button(value="Generating...", interactive=False),
+         current_history,
+         None
+     )

+     # 3. Generate Answer (non-blocking, potentially on GPU)
+     prompt = format_prompt(query, web_results)
+     final_answer = await generate_answer(prompt)  # Already async
+
+     # Update state: Answer generated
+     current_history[-1][1] = final_answer
+     yield (
+         final_answer,
+         sources_html,
+         gr.Button(value="Audio...", interactive=False),
+         current_history,
+         None
+     )

+     # 4. Generate Speech (non-blocking, potentially on GPU)
+     audio = None
+     tts_message = ""
+     if not tts_thread.is_alive() and not TTS_ENABLED:
+         tts_message = "\n\n*(TTS setup failed or is disabled)*"
+     elif tts_thread.is_alive():
+         tts_message = "\n\n*(TTS is still initializing, audio may be delayed)*"
+     elif TTS_ENABLED:
+         voice_id = get_voice_id(selected_voice_display_name)
+         audio = await generate_speech(final_answer, voice_id)  # Already async
+         if audio is None:
+             tts_message = f"\n\n*(Audio generation failed for voice '{voice_id}')*"
+
+     # 5. Final state: Show everything
+     yield (
+         final_answer + tts_message,
+         sources_html,
+         gr.Button(value="Search", interactive=True),  # Re-enable button
+         current_history,
+         audio
+     )

+ # --- Gradio Interface ---
  css = """
+ /* ... [Your existing CSS remains unchanged] ... */
+ .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
+ #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
+ #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
+ #header h3 { color: #a8a9ab; }
+ .search-container { background: #ffffff; border: 1px solid #e0e0e0; border-radius: 12px; box-shadow: 0 4px 16px rgba(0,0,0,0.05); padding: 1.5rem; margin-bottom: 1.5rem; }
+ .search-box { padding: 0; margin-bottom: 1rem; }
+ .search-box .gradio-textbox { border-radius: 8px 0 0 8px !important; } /* Style textbox specifically */
+ .search-box .gradio-dropdown { border-radius: 0 !important; margin-left: -1px; margin-right: -1px;} /* Style dropdown */
+ .search-box .gradio-button { border-radius: 0 8px 8px 0 !important; } /* Style button */
+ .search-box input[type="text"] { background: #f7f7f8 !important; border: 1px solid #d1d5db !important; color: #1f2937 !important; transition: all 0.3s ease; height: 42px !important; }
+ .search-box input[type="text"]:focus { border-color: #2563eb !important; box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2) !important; background: white !important; }
+ .search-box input[type="text"]::placeholder { color: #9ca3af !important; }
+ .search-box button { background: #2563eb !important; border: none !important; color: white !important; box-shadow: 0 1px 2px rgba(0,0,0,0.05) !important; transition: all 0.3s ease !important; height: 44px !important; }
+ .search-box button:hover { background: #1d4ed8 !important; }
+ .search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
+ .results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
+ .answer-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; color: #1f2937; margin-bottom: 1.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.05); }
+ .answer-box p { color: #374151; line-height: 1.7; }
+ .answer-box code { background: #f3f4f6; border-radius: 4px; padding: 2px 4px; color: #4b5563; font-size: 0.9em; }
+ .sources-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; }
+ .sources-box h3 { margin-top: 0; margin-bottom: 1rem; color: #111827; font-size: 1.2rem; }
+ .sources-container { margin-top: 0; }
+ .source-item { display: flex; padding: 10px 0; margin: 0; border-bottom: 1px solid #f3f4f6; transition: background-color 0.2s; }
+ .source-item:last-child { border-bottom: none; }
+ /* .source-item:hover { background-color: #f9fafb; } */
+ .source-number { font-weight: bold; margin-right: 12px; color: #6b7280; width: 20px; text-align: right; flex-shrink: 0;}
+ .source-content { flex: 1; }
+ .source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; }
+ .source-title:hover { color: #1d4ed8; text-decoration: underline; }
+ .source-date { color: #6b7280; font-size: 0.8em; margin-left: 8px; }
+ .source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
+ .chat-history { max-height: 400px; overflow-y: auto; padding: 1rem; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; margin-top: 1rem; scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
+ .chat-history::-webkit-scrollbar { width: 6px; }
+ .chat-history::-webkit-scrollbar-track { background: #f9fafb; }
+ .chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
+ .examples-container { background: #f9fafb; border-radius: 8px; padding: 1rem; margin-top: 1rem; border: 1px solid #e5e7eb; }
+ .examples-container .gradio-examples { gap: 8px !important; } /* Target examples component */
+ .examples-container button { background: white !important; border: 1px solid #d1d5db !important; color: #374151 !important; transition: all 0.2s; margin: 0 !important; font-size: 0.9em !important; padding: 6px 12px !important; }
+ .examples-container button:hover { background: #f3f4f6 !important; border-color: #adb5bd !important; }
+ .markdown-content { color: #374151 !important; font-size: 1rem; line-height: 1.7; }
+ .markdown-content h1, .markdown-content h2, .markdown-content h3 { color: #111827 !important; margin-top: 1.2em !important; margin-bottom: 0.6em !important; font-weight: 600; }
+ .markdown-content h1 { font-size: 1.6em !important; border-bottom: 1px solid #e5e7eb; padding-bottom: 0.3em; }
+ .markdown-content h2 { font-size: 1.4em !important; border-bottom: 1px solid #e5e7eb; padding-bottom: 0.3em;}
+ .markdown-content h3 { font-size: 1.2em !important; }
+ .markdown-content a { color: #2563eb !important; text-decoration: none !important; transition: all 0.2s; }
+ .markdown-content a:hover { color: #1d4ed8 !important; text-decoration: underline !important; }
+ .markdown-content code { background: #f3f4f6 !important; padding: 2px 6px !important; border-radius: 4px !important; font-family: monospace !important; color: #4b5563; font-size: 0.9em; }
+ .markdown-content pre { background: #f3f4f6 !important; padding: 12px !important; border-radius: 8px !important; overflow-x: auto !important; border: 1px solid #e5e7eb;}
+ .markdown-content pre code { background: transparent !important; padding: 0 !important; border: none !important; font-size: 0.9em;}
+ .markdown-content blockquote { border-left: 4px solid #d1d5db !important; padding-left: 1em !important; margin-left: 0 !important; color: #6b7280 !important; }
+ .markdown-content table { border-collapse: collapse !important; width: 100% !important; margin: 1em 0; }
+ .markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
+ .markdown-content th { background: #f9fafb !important; font-weight: 600; }
+ .accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
+ .accordion > .label-wrap { padding: 10px 15px !important; } /* Style accordion header */
+ .voice-selector { margin: 0; padding: 0; }
+ .voice-selector div[data-testid="dropdown"] { /* Target the specific dropdown container */ height: 44px !important; }
+ .voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
+ .voice-selector select:focus { border-color: #2563eb !important; box-shadow: none !important; }
+ .audio-player { margin-top: 1rem; background: #f9fafb !important; border-radius: 8px !important; padding: 0.5rem !important; border: 1px solid #e5e7eb;}
+ .audio-player audio { width: 100% !important; }
+ .searching, .error { padding: 1rem; border-radius: 8px; text-align: center; margin: 1rem 0; border: 1px dashed; }
+ .searching { background: #eff6ff; color: #3b82f6; border-color: #bfdbfe; }
+ .error { background: #fef2f2; color: #ef4444; border-color: #fecaca; }
+ .no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
+ @keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
+ .searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; } /* Add span for animation */
+ .dark .gradio-container { background-color: #111827 !important; }
+ .dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
+ .dark #header h3 { color: #9ca3af; }
+ .dark .search-container { background: #1f2937; border-color: #374151; }
+ .dark .search-box input[type="text"] { background: #374151 !important; border-color: #4b5563 !important; color: #e5e7eb !important; }
+ .dark .search-box input[type="text"]:focus { border-color: #3b82f6 !important; background: #4b5563 !important; box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.3) !important; }
+ .dark .search-box input[type="text"]::placeholder { color: #9ca3af !important; }
+ .dark .search-box button { background: #3b82f6 !important; }
+ .dark .search-box button:hover { background: #2563eb !important; }
+ .dark .search-box button:disabled { background: #4b5563 !important; }
+ .dark .answer-box { background: #1f2937; border-color: #374151; color: #e5e7eb; }
+ .dark .answer-box p { color: #d1d5db; }
+ .dark .answer-box code { background: #374151; color: #9ca3af; }
+ .dark .sources-box { background: #1f2937; border-color: #374151; }
+ .dark .sources-box h3 { color: #f9fafb; }
+ .dark .source-item { border-bottom-color: #374151; }
+ .dark .source-item:hover { background-color: #374151; }
+ .dark .source-number { color: #9ca3af; }
+ .dark .source-title { color: #60a5fa; }
+ .dark .source-title:hover { color: #93c5fd; }
+ .dark .source-snippet { color: #d1d5db; }
+ .dark .chat-history { background: #374151; border-color: #4b5563; scrollbar-color: #4b5563 #374151; }
+ .dark .chat-history::-webkit-scrollbar-track { background: #374151; }
+ .dark .chat-history::-webkit-scrollbar-thumb { background-color: #4b5563; }
+ .dark .examples-container { background: #374151; border-color: #4b5563; }
+ .dark .examples-container button { background: #1f2937 !important; border-color: #4b5563 !important; color: #d1d5db !important; }
+ .dark .examples-container button:hover { background: #4b5563 !important; border-color: #6b7280 !important; }
+ .dark .markdown-content { color: #d1d5db !important; }
+ .dark .markdown-content h1, .dark .markdown-content h2, .dark .markdown-content h3 { color: #f9fafb !important; border-bottom-color: #4b5563; }
+ .dark .markdown-content a { color: #60a5fa !important; }
+ .dark .markdown-content a:hover { color: #93c5fd !important; }
+ .dark .markdown-content code { background: #374151 !important; color: #9ca3af; }
+ .dark .markdown-content pre { background: #374151 !important; border-color: #4b5563;}
+ .dark .markdown-content pre code { background: transparent !important; }
+ .dark .markdown-content blockquote { border-left-color: #4b5563 !important; color: #9ca3af !important; }
+ .dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
+ .dark .markdown-content th { background: #374151 !important; }
+ .dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
+ .dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
+ .dark .voice-selector select:focus { border-color: #3b82f6 !important; }
+ .dark .audio-player { background: #374151 !important; border-color: #4b5563;}
+ .dark .searching { background: #1e3a8a; color: #93c5fd; border-color: #3b82f6; }
+ .dark .error { background: #7f1d1d; color: #fca5a5; border-color: #ef4444; }
+ .dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}
+
  """

+ with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
      chat_history = gr.State([])
+
+     with gr.Column():  # Main container
+         with gr.Column(elem_id="header"):
+             gr.Markdown("# 🔍 AI Search Assistant")
+             gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
+
+         with gr.Column(elem_classes="search-container"):
+             with gr.Row(elem_classes="search-box", equal_height=True):
+                 search_input = gr.Textbox(
+                     label="",
+                     placeholder="Ask anything...",
+                     scale=5,
+                     container=False,  # Important for direct styling
+                     elem_classes="gradio-textbox"
+                 )
+                 voice_select = gr.Dropdown(
+                     choices=list(VOICE_CHOICES.keys()),
+                     value=list(VOICE_CHOICES.keys())[0],
+                     label="",  # No label needed here
+                     scale=2,
+                     container=False,  # Important
+                     elem_classes="voice-selector gradio-dropdown"
+                 )
+                 search_btn = gr.Button(
+                     "Search",
+                     variant="primary",
+                     scale=1,
+                     elem_classes="gradio-button"
+                 )
+
+         with gr.Row(elem_classes="results-container", equal_height=False):
+             with gr.Column(scale=3):  # Wider column for answer + history
+                 with gr.Column(elem_classes="answer-box"):
+                     answer_output = gr.Markdown(elem_classes="markdown-content", value="*Your answer will appear here...*")
+                     # Audio player below the answer
+                     audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player", type="numpy")  # Expect numpy array
+
+                 with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
+                     chat_history_display = gr.Chatbot(elem_classes="chat-history", label="History", height=300)
+
+             with gr.Column(scale=2):  # Narrower column for sources
+                 with gr.Column(elem_classes="sources-box"):
+                     gr.Markdown("### Sources")
+                     sources_output = gr.HTML(value="<div class='no-sources'>Sources will appear here after searching.</div>")
+
+         with gr.Row(elem_classes="examples-container"):
+             gr.Examples(
+                 examples=[
+                     "Latest news about renewable energy",
+                     "Explain the concept of Large Language Models (LLMs)",
+                     "What are the symptoms and prevention tips for the flu?",
+                     "Compare Python and JavaScript for web development"
+                 ],
+                 inputs=search_input,
+                 label="Try these examples:",
+                 elem_classes="gradio-examples"  # Add class for potential styling
+             )
+
+     # --- Event Handling ---
+     # Use the async function for processing
+     async def handle_interaction(query, history, voice_display_name):
+         """Wrapper to handle the async generator from process_query_async"""
+         try:
+             async for update in process_query_async(query, history, voice_display_name):
+                 # Ensure the button state is updated correctly
+                 ans_out, src_out, btn_state, hist_display, aud_out = update
+                 yield ans_out, src_out, btn_state, hist_display, aud_out
+         except Exception as e:
+             print(f"Error in handle_interaction: {e}")
+             import traceback
+             traceback.print_exc()
+             error_message = f"An unexpected error occurred: {e}"
+             # Provide a final error state update
+             yield (
+                 error_message,
+                 "<div class='error'>Error processing request.</div>",
+                 gr.Button(value="Search", interactive=True),  # Re-enable button on error
+                 history + [[query, f"*Error: {error_message}*"]],
+                 None
              )

+     # Corrected event listeners: Pass the voice_select component directly
      search_btn.click(
+         fn=handle_interaction,
+         inputs=[search_input, chat_history, voice_select],  # Pass voice_select component
          outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
      )
+
      search_input.submit(
+         fn=handle_interaction,
+         inputs=[search_input, chat_history, voice_select],  # Pass voice_select component
          outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
      )

  if __name__ == "__main__":
+     # Launch the app
+     demo.queue(max_size=20).launch(debug=True, share=True)  # Enable debug for more logs