import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
from duckduckgo_search import DDGS
import time
import torch
from datetime import datetime
import html
import os
import subprocess
import numpy as np
from typing import List, Dict, Optional, Tuple
from functools import lru_cache
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
import warnings

# Suppress specific warnings if needed (optional)
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")

# --- Configuration ---
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEARCH_RESULTS = 5
TTS_SAMPLE_RATE = 24000
MAX_TTS_CHARS = 1000 # Reduced for faster testing, adjust as needed
GPU_DURATION = 60 # Increased duration for longer tasks like TTS
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95
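# Note: TEMPERATURE and TOP_P only take effect because generation below passes
# do_sample=True; MAX_NEW_TOKENS bounds the generated continuation, not the prompt.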

# --- Initialization ---
# Initialize model and tokenizer with better error handling
try:
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model...")
    # Determine device map based on CUDA availability
    device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Use float32 on CPU

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map=device_map,
        # offload_folder="offload", # Only use offload if really needed and configured
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype
    )
    print(f"Model loaded on device map: {model.hf_device_map}")
    print("Model and tokenizer loaded successfully")
except Exception as e:
    print(f"Error initializing model: {str(e)}")
    # A CPU fallback could be attempted here instead of failing hard; a sketch
    # follows below. For now, re-raise the error.
    raise
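
# If a hard failure is undesirable (e.g. on a GPU-less Space), a CPU fallback
# could be attempted instead of re-raising. A minimal sketch, commented out and
# assuming the host has enough RAM for the full-precision weights:
#
#     model = AutoModelForCausalLM.from_pretrained(
#         MODEL_NAME,
#         device_map={"": "cpu"},
#         low_cpu_mem_usage=True,
#         torch_dtype=torch.float32,
#     )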

# --- TTS Setup ---
VOICE_CHOICES = {
    '🇺🇸 Female (Default)': 'af',
    '🇺🇸 Bella': 'af_bella',
    '🇺🇸 Sarah': 'af_sarah',
    '🇺🇸 Nicole': 'af_nicole'
}
TTS_ENABLED = False
TTS_MODEL = None
VOICEPACKS = {}  # Cache voice packs
KOKORO_PATH = 'Kokoro-82M'

# Initialize Kokoro TTS in a separate thread to avoid blocking startup
def setup_tts():
    global TTS_ENABLED, TTS_MODEL, VOICEPACKS

    try:
        # Check if Kokoro already exists
        if not os.path.exists(KOKORO_PATH):
            print("Cloning Kokoro-82M repository...")
            # Install git-lfs if not present (might need sudo/apt)
            try:
                subprocess.run(['git', 'lfs', 'install'], check=True, capture_output=True)
            except (FileNotFoundError, subprocess.CalledProcessError) as lfs_err:
                print(f"Warning: git-lfs might not be installed or failed: {lfs_err}. Cloning might be slow or incomplete.")

            clone_cmd = ['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M']
            result = subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
            print("Kokoro cloned successfully.")
            print(result.stdout)
            # Optionally pull LFS files if needed (sometimes clone doesn't get them all)
            # subprocess.run(['git', 'lfs', 'pull'], cwd=KOKORO_PATH, check=True)

        else:
            print("Kokoro-82M directory already exists.")

        # Install espeak (essential for phonemization)
        print("Attempting to install espeak-ng or espeak...")
        try:
            # Try installing espeak-ng first (often preferred)
            subprocess.run(['sudo', 'apt-get', 'update'], check=True, capture_output=True)
            subprocess.run(['sudo', 'apt-get', 'install', '-y', 'espeak-ng'], check=True, capture_output=True)
            print("espeak-ng installed successfully.")
        except (FileNotFoundError, subprocess.CalledProcessError):
            print("espeak-ng installation failed, trying espeak...")
            try:
                # Fallback to espeak
                subprocess.run(['sudo', 'apt-get', 'install', '-y', 'espeak'], check=True, capture_output=True)
                print("espeak installed successfully.")
            except (FileNotFoundError, subprocess.CalledProcessError) as espeak_err:
                print(f"Warning: Could not install espeak-ng or espeak: {espeak_err}. TTS functionality will be disabled.")
                return # Cannot proceed without espeak
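
        # Note: on Hugging Face Spaces the container user generally cannot run
        # sudo; listing espeak-ng in a packages.txt file at the repo root is the
        # usual way to install it at build time.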

        # Set up Kokoro TTS
        if os.path.exists(KOKORO_PATH):
            import sys
            if KOKORO_PATH not in sys.path:
                sys.path.append(KOKORO_PATH)
            try:
                from models import build_model
                from kokoro import generate as generate_tts_internal # Avoid name clash

                # Make these functions accessible globally if needed, but better to keep scoped
                globals()['build_model'] = build_model
                globals()['generate_tts_internal'] = generate_tts_internal

                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                print(f"Loading TTS model onto device: {device}")
                # Ensure model path is correct
                model_file = os.path.join(KOKORO_PATH, 'kokoro-v0_19.pth')
                if not os.path.exists(model_file):
                     print(f"Error: TTS model file not found at {model_file}")
                     # Attempt to pull LFS files again
                     try:
                         print("Attempting git lfs pull...")
                         subprocess.run(['git', 'lfs', 'pull'], cwd=KOKORO_PATH, check=True, capture_output=True)
                         if not os.path.exists(model_file):
                            print(f"Error: TTS model file STILL not found at {model_file} after lfs pull.")
                            return
                     except Exception as lfs_pull_err:
                         print(f"Error during git lfs pull: {lfs_pull_err}")
                         return

                TTS_MODEL = build_model(model_file, device)

                # Preload default voice
                default_voice_id = 'af'
                voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{default_voice_id}.pt')
                if os.path.exists(voice_file_path):
                    print(f"Loading default voice: {default_voice_id}")
                    VOICEPACKS[default_voice_id] = torch.load(voice_file_path,
                                                       map_location=device) # Removed weights_only=True
                else:
                    print(f"Warning: Default voice file {voice_file_path} not found.")


                # Preload other common voices to reduce latency
                for voice_name, voice_id in VOICE_CHOICES.items():
                    if voice_id != default_voice_id: # Avoid reloading default
                        voice_file_path = os.path.join(KOKORO_PATH, 'voices', f'{voice_id}.pt')
                        if os.path.exists(voice_file_path):
                            try:
                                print(f"Preloading voice: {voice_id}")
                                VOICEPACKS[voice_id] = torch.load(voice_file_path,
                                                               map_location=device) # Removed weights_only=True
                            except Exception as e:
                                print(f"Warning: Could not preload voice {voice_id}: {str(e)}")
                        else:
                            print(f"Info: Voice file {voice_file_path} for '{voice_name}' not found, will skip preloading.")

                TTS_ENABLED = True
                print("TTS setup completed successfully")
            except ImportError as ie:
                print(f"Error importing Kokoro modules: {ie}. Check if Kokoro-82M is correctly cloned and in sys.path.")
            except Exception as model_load_err:
                print(f"Error loading TTS model or voices: {model_load_err}")

        else:
            print(f"Warning: {KOKORO_PATH} directory not found after clone attempt. TTS disabled.")
    except subprocess.CalledProcessError as spe:
        print(f"Warning: A subprocess command failed during TTS setup: {spe}")
        print(f"Command: {' '.join(spe.cmd)}")
        print(f"Stderr: {spe.stderr}")
        print("TTS may be disabled.")
    except Exception as e:
        print(f"Warning: An unexpected error occurred during TTS setup: {str(e)}")
        TTS_ENABLED = False

# Start TTS setup in a separate thread
print("Starting TTS setup in background thread...")
tts_thread = threading.Thread(target=setup_tts, daemon=True)
tts_thread.start()
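# The daemon flag means this thread will not block interpreter shutdown, and
# TTS_ENABLED only flips to True once setup_tts() completes, so requests made
# before then simply skip audio generation.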

# --- Search and Generation Functions ---
@lru_cache(maxsize=128)
def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
    """Get web search results using DuckDuckGo with caching for improved performance"""
    print(f"Performing web search for: '{query}'")
    try:
        with DDGS() as ddgs:
            # safesearch='off' may surface more results, but 'moderate' is the safer default
            results = list(ddgs.text(query, max_results=max_results, safesearch='moderate'))
            print(f"Found {len(results)} results.")
            formatted_results = []
            for result in results:
                formatted_results.append({
                    "title": result.get("title", "No Title"),
                    "snippet": result.get("body", "No Snippet Available"),
                    "url": result.get("href", "#"),
                    # Attempt to extract date - DDGS doesn't reliably provide it
                    # "date": result.get("published", "") # Placeholder
                })
            return formatted_results
    except Exception as e:
        print(f"Error in web search: {e}")
        return []
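
# Caveat: lru_cache keys on (query, max_results) and never expires entries, so
# a repeated query returns the same cached results for the life of the process.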

def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
    """Format the prompt with web context"""
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
    prompt = f"""You are a helpful AI assistant. Your task is to answer the user's query based *only* on the provided web search context.
Do not add information not present in the context.
Cite the sources used in your answer using bracket notation, e.g., [Source Title]. Use the titles from the context.
If the context does not contain relevant information to answer the query, state that clearly.
Current Time: {current_time}

Web Context:
{context_lines if context else "No web context available."}

User Query: {query}

Answer:"""
    # print(f"Formatted Prompt:\n{prompt}") # Debugging
    return prompt
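
# Illustrative shape of the rendered prompt (values are hypothetical):
#   You are a helpful AI assistant. ...
#   Web Context:
#   - [Example Page Title]: example snippet text from the search result...
#   User Query: <the user's question>
#   Answer: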

def format_sources(web_results: List[Dict[str, str]]) -> str:
    """Format sources with more details"""
    if not web_results:
        return "<div class='no-sources'>No sources found for the query.</div>"

    sources_html = "<div class='sources-container'>"
    for i, res in enumerate(web_results, 1):
        # Escape fields before interpolating into HTML so stray markup in
        # titles or snippets cannot break the page layout
        title = html.escape(res.get("title", "Source"))
        url = html.escape(res.get("url", "#"), quote=True)
        # (DDG rarely returns a usable publication date, so none is shown)
        raw_snippet = res.get("snippet", "")
        snippet = html.escape(raw_snippet[:150] + ("..." if len(raw_snippet) > 150 else ""))
        sources_html += f"""
        <div class='source-item'>
            <div class='source-number'>[{i}]</div>
            <div class='source-content'>
                <a href="{url}" target="_blank" class='source-title' title="{url}">{title}</a>
                <div class='source-snippet'>{snippet}</div>
            </div>
        </div>
        """
    sources_html += "</div>"
    return sources_html

# Use a ThreadPoolExecutor for potentially blocking I/O or CPU-bound tasks
# Keep GPU tasks separate if possible, or ensure thread safety if sharing GPU resources
executor = ThreadPoolExecutor(max_workers=4)

@spaces.GPU(duration=GPU_DURATION)
async def generate_answer(prompt: str) -> str:
    """Generate answer using the DeepSeek model with optimized settings (Async Wrapper)"""
    print("Generating answer...")
    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024, # Increased context length
            return_attention_mask=True
        ).to(model.device)

        # model.generate blocks; run it in a worker thread so the event loop
        # stays responsive, with autocast enabled only for fp16 GPU inference
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available() and torch_dtype == torch.float16):
            outputs = await asyncio.to_thread(
                model.generate,
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                num_return_sequences=1
            )

        # Decode only the newly generated tokens (everything after the prompt),
        # which is more robust than splitting the full output on "Answer:"
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        answer_part = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print(f"Generated {len(generated_ids)} new tokens, decoded answer length: {len(answer_part)}")
        if not answer_part:
            print("Warning: model produced an empty answer.")
            return "*No answer was generated. Please try rephrasing your query.*"
        return answer_part
    except Exception as e:
        print(f"Error during answer generation: {e}")
        # You might want to return a specific error message here
        return f"Error generating answer: {str(e)}"

# Ensure this function runs potentially long tasks in a thread using the executor
# @spaces.GPU(duration=GPU_DURATION) # Re-enable if TTS inference is GPU-heavy
async def generate_speech(text: str, voice_id: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
    """Generate speech from text using Kokoro TTS model (Async Wrapper)."""
    global TTS_MODEL, TTS_ENABLED, VOICEPACKS
    print(f"Attempting to generate speech for text (length {len(text)}) with voice '{voice_id}'")

    if not TTS_ENABLED or TTS_MODEL is None:
        print("TTS is not enabled or model not loaded.")
        return None
    if 'generate_tts_internal' not in globals():
        print("TTS generation function 'generate_tts_internal' not found.")
        return None

    try:
        # Choose the device the same way setup_tts did; the Kokoro model object
        # is not guaranteed to expose a .device attribute
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load the requested voicepack on demand, falling back to the default 'af'
        async def _load_voicepack(vid: str) -> bool:
            """Load a voicepack into the cache; return True on success."""
            path = os.path.join(KOKORO_PATH, 'voices', f'{vid}.pt')
            if not os.path.exists(path):
                print(f"Voicepack file not found: {path}")
                return False
            try:
                print(f"Loading voice '{vid}'...")
                VOICEPACKS[vid] = await asyncio.to_thread(torch.load, path, map_location=device)
                return True
            except Exception as load_err:
                print(f"Error loading voicepack '{vid}': {load_err}")
                return False

        if voice_id not in VOICEPACKS and not await _load_voicepack(voice_id):
            print(f"Voice '{voice_id}' unavailable, falling back to default 'af'.")
            voice_id = 'af'
            if voice_id not in VOICEPACKS and not await _load_voicepack(voice_id):
                print("Default voice 'af' could not be loaded. Cannot generate audio.")
                return None

        # Clean the text (simple cleaning)
        clean_text = ' '.join(text.split()) # Remove extra whitespace
        clean_text = clean_text.replace('*', '').replace('[', '').replace(']', '') # Remove markdown chars

        # Ensure text isn't empty
        if not clean_text.strip():
            print("Warning: Empty text provided for TTS.")
            return None

        # Limit text length
        if len(clean_text) > MAX_TTS_CHARS:
            print(f"Warning: Text too long ({len(clean_text)} chars), truncating to {MAX_TTS_CHARS}.")
            # Simple truncation at the last space; see the truncate_at_sentence
            # sketch below for a sentence-aware variant
            clean_text = clean_text[:MAX_TTS_CHARS]
            last_space = clean_text.rfind(' ')
            if last_space != -1:
                clean_text = clean_text[:last_space] + "..." # Truncate at last space

        # Run the potentially blocking TTS generation in a thread
        print(f"Generating audio for: '{clean_text[:100]}...'")
        gen_func = globals()['generate_tts_internal']
        loop = asyncio.get_running_loop()
        audio_data, _ = await loop.run_in_executor(
            executor, # Use the thread pool executor
            gen_func,
            TTS_MODEL,
            clean_text,
            VOICEPACKS[voice_id],
            'a' # Kokoro language code: 'a' selects American English
        )

        if isinstance(audio_data, torch.Tensor):
            # Move tensor to CPU before converting to numpy if it's not already
            audio_np = audio_data.cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            print("Warning: Unexpected audio data type from TTS.")
            return None

        print(f"Audio generated successfully, shape: {audio_np.shape}")
        return (TTS_SAMPLE_RATE, audio_np)

    except Exception as e:
        import traceback
        print(f"Error generating speech: {str(e)}")
        print(traceback.format_exc()) # Print full traceback for debugging
        return None

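# The truncation in generate_speech cuts at the last space; below is a
# sentence-aware variant, a hypothetical helper defined for reference but not
# wired into generate_speech:
import re

def truncate_at_sentence(text: str, max_chars: int = MAX_TTS_CHARS) -> str:
    """Truncate to at most max_chars, preferring the last full sentence boundary."""
    if len(text) <= max_chars:
        return text
    head = text[:max_chars]
    last = None
    for last in re.finditer(r'[.!?](?:\s|$)', head):
        pass  # iterate to the end; 'last' holds the final sentence boundary
    if last:
        return head[:last.end()].rstrip()
    # No sentence boundary found; fall back to the last whole word
    return head.rsplit(' ', 1)[0] + "..."
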
# Helper to get voice ID from display name
def get_voice_id(voice_display_name: str) -> str:
    """Maps the user-friendly voice name to the internal voice ID."""
    return VOICE_CHOICES.get(voice_display_name, 'af') # Default to 'af' if not found
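# e.g. get_voice_id('🇺🇸 Bella') -> 'af_bella'; unknown display names fall back to 'af'.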

# --- Main Processing Logic (Async) ---
async def process_query_async(query: str, history: List[List[str]], selected_voice_display_name: str):
    """Asynchronously process user query: search -> generate answer -> generate speech"""
    if not query:
        yield (
            "Please enter a query.", "", "Search", history, None
        )
        return

    if history is None: history = []
    current_history = history + [[query, "*Searching...*"]]

    # 1. Initial state: Searching
    yield (
        "*Searching & Thinking...*",
        "<div class='searching'>Searching the web...</div>",
        gr.Button(value="Searching...", interactive=False), # Disable button
        current_history,
        None
    )

    # 2. Perform Web Search (non-blocking)
    loop = asyncio.get_running_loop()
    web_results = await loop.run_in_executor(executor, get_web_results, query)
    sources_html = format_sources(web_results)

    # Update state: Analyzing results
    current_history[-1][1] = "*Analyzing search results...*"
    yield (
        "*Analyzing search results...*",
        sources_html,
        gr.Button(value="Generating...", interactive=False),
        current_history,
        None
    )

    # 3. Generate Answer (non-blocking, potentially on GPU)
    prompt = format_prompt(query, web_results)
    final_answer = await generate_answer(prompt) # Already async

    # Update state: Answer generated
    current_history[-1][1] = final_answer
    yield (
        final_answer,
        sources_html,
        gr.Button(value="Audio...", interactive=False),
        current_history,
        None
    )

    # 4. Generate Speech (non-blocking, potentially on GPU)
    audio = None
    tts_message = ""
    if not tts_thread.is_alive() and not TTS_ENABLED:
        tts_message = "\n\n*(TTS setup failed or is disabled)*"
    elif tts_thread.is_alive():
        tts_message = "\n\n*(TTS is still initializing, audio may be delayed)*"
    elif TTS_ENABLED:
        voice_id = get_voice_id(selected_voice_display_name)
        audio = await generate_speech(final_answer, voice_id) # Already async
        if audio is None:
            tts_message = f"\n\n*(Audio generation failed for voice '{voice_id}')*"

    # 5. Final state: Show everything
    yield (
        final_answer + tts_message,
        sources_html,
        gr.Button(value="Search", interactive=True), # Re-enable button
        current_history,
        audio
    )
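
    # Each yield above pushes a complete UI state (answer, sources, button,
    # history, audio), so the interface updates progressively as the pipeline runs.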


# --- Gradio Interface ---
css = """
.gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
#header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
#header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
#header h3 { color: #a8a9ab; }
.search-container { background: #ffffff; border: 1px solid #e0e0e0; border-radius: 12px; box-shadow: 0 4px 16px rgba(0,0,0,0.05); padding: 1.5rem; margin-bottom: 1.5rem; }
.search-box { padding: 0; margin-bottom: 1rem; }
.search-box .gradio-textbox { border-radius: 8px 0 0 8px !important; } /* Style textbox specifically */
.search-box .gradio-dropdown { border-radius: 0 !important; margin-left: -1px; margin-right: -1px;} /* Style dropdown */
.search-box .gradio-button { border-radius: 0 8px 8px 0 !important; } /* Style button */
.search-box input[type="text"] { background: #f7f7f8 !important; border: 1px solid #d1d5db !important; color: #1f2937 !important; transition: all 0.3s ease; height: 42px !important; }
.search-box input[type="text"]:focus { border-color: #2563eb !important; box-shadow: 0 0 0 2px rgba(37, 99, 235, 0.2) !important; background: white !important; }
.search-box input[type="text"]::placeholder { color: #9ca3af !important; }
.search-box button { background: #2563eb !important; border: none !important; color: white !important; box-shadow: 0 1px 2px rgba(0,0,0,0.05) !important; transition: all 0.3s ease !important; height: 44px !important; }
.search-box button:hover { background: #1d4ed8 !important; }
.search-box button:disabled { background: #9ca3af !important; cursor: not-allowed; }
.results-container { background: transparent; padding: 0; margin-top: 1.5rem; }
.answer-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; color: #1f2937; margin-bottom: 1.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.05); }
.answer-box p { color: #374151; line-height: 1.7; }
.answer-box code { background: #f3f4f6; border-radius: 4px; padding: 2px 4px; color: #4b5563; font-size: 0.9em; }
.sources-box { background: white; border: 1px solid #e0e0e0; border-radius: 10px; padding: 1.5rem; }
.sources-box h3 { margin-top: 0; margin-bottom: 1rem; color: #111827; font-size: 1.2rem; }
.sources-container { margin-top: 0; }
.source-item { display: flex; padding: 10px 0; margin: 0; border-bottom: 1px solid #f3f4f6; transition: background-color 0.2s; }
.source-item:last-child { border-bottom: none; }
/* .source-item:hover { background-color: #f9fafb; } */
.source-number { font-weight: bold; margin-right: 12px; color: #6b7280; width: 20px; text-align: right; flex-shrink: 0;}
.source-content { flex: 1; }
.source-title { color: #2563eb; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; transition: all 0.2s; font-size: 0.95em; }
.source-title:hover { color: #1d4ed8; text-decoration: underline; }
.source-date { color: #6b7280; font-size: 0.8em; margin-left: 8px; }
.source-snippet { color: #4b5563; font-size: 0.9em; line-height: 1.5; }
.chat-history { max-height: 400px; overflow-y: auto; padding: 1rem; background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; margin-top: 1rem; scrollbar-width: thin; scrollbar-color: #d1d5db #f9fafb; }
.chat-history::-webkit-scrollbar { width: 6px; }
.chat-history::-webkit-scrollbar-track { background: #f9fafb; }
.chat-history::-webkit-scrollbar-thumb { background-color: #d1d5db; border-radius: 20px; }
.examples-container { background: #f9fafb; border-radius: 8px; padding: 1rem; margin-top: 1rem; border: 1px solid #e5e7eb; }
.examples-container .gradio-examples { gap: 8px !important; } /* Target examples component */
.examples-container button { background: white !important; border: 1px solid #d1d5db !important; color: #374151 !important; transition: all 0.2s; margin: 0 !important; font-size: 0.9em !important; padding: 6px 12px !important; }
.examples-container button:hover { background: #f3f4f6 !important; border-color: #adb5bd !important; }
.markdown-content { color: #374151 !important; font-size: 1rem; line-height: 1.7; }
.markdown-content h1, .markdown-content h2, .markdown-content h3 { color: #111827 !important; margin-top: 1.2em !important; margin-bottom: 0.6em !important; font-weight: 600; }
.markdown-content h1 { font-size: 1.6em !important; border-bottom: 1px solid #e5e7eb; padding-bottom: 0.3em; }
.markdown-content h2 { font-size: 1.4em !important; border-bottom: 1px solid #e5e7eb; padding-bottom: 0.3em;}
.markdown-content h3 { font-size: 1.2em !important; }
.markdown-content a { color: #2563eb !important; text-decoration: none !important; transition: all 0.2s; }
.markdown-content a:hover { color: #1d4ed8 !important; text-decoration: underline !important; }
.markdown-content code { background: #f3f4f6 !important; padding: 2px 6px !important; border-radius: 4px !important; font-family: monospace !important; color: #4b5563; font-size: 0.9em; }
.markdown-content pre { background: #f3f4f6 !important; padding: 12px !important; border-radius: 8px !important; overflow-x: auto !important; border: 1px solid #e5e7eb;}
.markdown-content pre code { background: transparent !important; padding: 0 !important; border: none !important; font-size: 0.9em;}
.markdown-content blockquote { border-left: 4px solid #d1d5db !important; padding-left: 1em !important; margin-left: 0 !important; color: #6b7280 !important; }
.markdown-content table { border-collapse: collapse !important; width: 100% !important; margin: 1em 0; }
.markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #d1d5db !important; text-align: left;}
.markdown-content th { background: #f9fafb !important; font-weight: 600; }
.accordion { background: #f9fafb !important; border: 1px solid #e5e7eb !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: none !important; }
.accordion > .label-wrap { padding: 10px 15px !important; } /* Style accordion header */
.voice-selector { margin: 0; padding: 0; }
.voice-selector div[data-testid="dropdown"] { /* Target the specific dropdown container */ height: 44px !important; }
.voice-selector select { background: white !important; color: #374151 !important; border: 1px solid #d1d5db !important; border-left: none !important; border-right: none !important; border-radius: 0 !important; height: 100% !important; padding: 0 10px !important; transition: all 0.2s; appearance: none !important; -webkit-appearance: none !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%236b7280' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important; background-position: right 0.5rem center !important; background-repeat: no-repeat !important; background-size: 1.5em 1.5em !important; padding-right: 2.5rem !important; }
.voice-selector select:focus { border-color: #2563eb !important; box-shadow: none !important; }
.audio-player { margin-top: 1rem; background: #f9fafb !important; border-radius: 8px !important; padding: 0.5rem !important; border: 1px solid #e5e7eb;}
.audio-player audio { width: 100% !important; }
.searching, .error { padding: 1rem; border-radius: 8px; text-align: center; margin: 1rem 0; border: 1px dashed; }
.searching { background: #eff6ff; color: #3b82f6; border-color: #bfdbfe; }
.error { background: #fef2f2; color: #ef4444; border-color: #fecaca; }
.no-sources { padding: 1rem; text-align: center; color: #6b7280; background: #f9fafb; border-radius: 8px; border: 1px solid #e5e7eb;}
@keyframes pulse { 0% { opacity: 0.7; } 50% { opacity: 1; } 100% { opacity: 0.7; } }
.searching span { animation: pulse 1.5s infinite ease-in-out; display: inline-block; } /* Add span for animation */
.dark .gradio-container { background-color: #111827 !important; }
.dark #header { background: linear-gradient(135deg, #1f2937, #374151); }
.dark #header h3 { color: #9ca3af; }
.dark .search-container { background: #1f2937; border-color: #374151; }
.dark .search-box input[type="text"] { background: #374151 !important; border-color: #4b5563 !important; color: #e5e7eb !important; }
.dark .search-box input[type="text"]:focus { border-color: #3b82f6 !important; background: #4b5563 !important; box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.3) !important; }
.dark .search-box input[type="text"]::placeholder { color: #9ca3af !important; }
.dark .search-box button { background: #3b82f6 !important; }
.dark .search-box button:hover { background: #2563eb !important; }
.dark .search-box button:disabled { background: #4b5563 !important; }
.dark .answer-box { background: #1f2937; border-color: #374151; color: #e5e7eb; }
.dark .answer-box p { color: #d1d5db; }
.dark .answer-box code { background: #374151; color: #9ca3af; }
.dark .sources-box { background: #1f2937; border-color: #374151; }
.dark .sources-box h3 { color: #f9fafb; }
.dark .source-item { border-bottom-color: #374151; }
.dark .source-item:hover { background-color: #374151; }
.dark .source-number { color: #9ca3af; }
.dark .source-title { color: #60a5fa; }
.dark .source-title:hover { color: #93c5fd; }
.dark .source-snippet { color: #d1d5db; }
.dark .chat-history { background: #374151; border-color: #4b5563; scrollbar-color: #4b5563 #374151; }
.dark .chat-history::-webkit-scrollbar-track { background: #374151; }
.dark .chat-history::-webkit-scrollbar-thumb { background-color: #4b5563; }
.dark .examples-container { background: #374151; border-color: #4b5563; }
.dark .examples-container button { background: #1f2937 !important; border-color: #4b5563 !important; color: #d1d5db !important; }
.dark .examples-container button:hover { background: #4b5563 !important; border-color: #6b7280 !important; }
.dark .markdown-content { color: #d1d5db !important; }
.dark .markdown-content h1, .dark .markdown-content h2, .dark .markdown-content h3 { color: #f9fafb !important; border-bottom-color: #4b5563; }
.dark .markdown-content a { color: #60a5fa !important; }
.dark .markdown-content a:hover { color: #93c5fd !important; }
.dark .markdown-content code { background: #374151 !important; color: #9ca3af; }
.dark .markdown-content pre { background: #374151 !important; border-color: #4b5563;}
.dark .markdown-content pre code { background: transparent !important; }
.dark .markdown-content blockquote { border-left-color: #4b5563 !important; color: #9ca3af !important; }
.dark .markdown-content th, .dark .markdown-content td { border-color: #4b5563 !important; }
.dark .markdown-content th { background: #374151 !important; }
.dark .accordion { background: #374151 !important; border-color: #4b5563 !important; }
.dark .voice-selector select { background: #1f2937 !important; color: #d1d5db !important; border-color: #4b5563 !important; background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e") !important;}
.dark .voice-selector select:focus { border-color: #3b82f6 !important; }
.dark .audio-player { background: #374151 !important; border-color: #4b5563;}
.dark .searching { background: #1e3a8a; color: #93c5fd; border-color: #3b82f6; }
.dark .error { background: #7f1d1d; color: #fca5a5; border-color: #ef4444; }
.dark .no-sources { background: #374151; color: #9ca3af; border-color: #4b5563;}

"""

with gr.Blocks(title="AI Search Assistant", css=css, theme=gr.themes.Default(primary_hue="blue")) as demo:
    chat_history = gr.State([])

    with gr.Column(): # Main container
        with gr.Column(elem_id="header"):
            gr.Markdown("# πŸ” AI Search Assistant")
            gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")

        with gr.Column(elem_classes="search-container"):
            with gr.Row(elem_classes="search-box", equal_height=True):
                search_input = gr.Textbox(
                    label="",
                    placeholder="Ask anything...",
                    scale=5,
                    container=False, # Important for direct styling
                    elem_classes="gradio-textbox"
                )
                voice_select = gr.Dropdown(
                    choices=list(VOICE_CHOICES.keys()),
                    value=list(VOICE_CHOICES.keys())[0],
                    label="", # No label needed here
                    scale=2,
                    container=False, # Important
                    elem_classes="voice-selector gradio-dropdown"
                )
                search_btn = gr.Button(
                    "Search",
                    variant="primary",
                    scale=1,
                    elem_classes="gradio-button"
                )

            with gr.Row(elem_classes="results-container", equal_height=False):
                with gr.Column(scale=3): # Wider column for answer + history
                    with gr.Column(elem_classes="answer-box"):
                        answer_output = gr.Markdown(elem_classes="markdown-content", value="*Your answer will appear here...*")
                        # Audio player below the answer
                        audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player", type="numpy") # Expect numpy array

                    with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                        chat_history_display = gr.Chatbot(elem_classes="chat-history", label="History", height=300)

                with gr.Column(scale=2): # Narrower column for sources
                     with gr.Column(elem_classes="sources-box"):
                        gr.Markdown("### Sources")
                        sources_output = gr.HTML(value="<div class='no-sources'>Sources will appear here after searching.</div>")

            with gr.Row(elem_classes="examples-container"):
                 gr.Examples(
                    examples=[
                        "Latest news about renewable energy",
                        "Explain the concept of Large Language Models (LLMs)",
                        "What are the symptoms and prevention tips for the flu?",
                        "Compare Python and JavaScript for web development"
                    ],
                    inputs=search_input,
                    label="Try these examples:",
                    elem_classes="gradio-examples" # Add class for potential styling
                )

    # --- Event Handling ---
    # Use the async function for processing
    async def handle_interaction(query, history, voice_display_name):
        """Wrapper to handle the async generator from process_query_async"""
        try:
            async for update in process_query_async(query, history, voice_display_name):
                # Ensure the button state is updated correctly
                ans_out, src_out, btn_state, hist_display, aud_out = update
                yield ans_out, src_out, btn_state, hist_display, aud_out
        except Exception as e:
            print(f"Error in handle_interaction: {e}")
            import traceback
            traceback.print_exc()
            error_message = f"An unexpected error occurred: {e}"
            # Provide a final error state update
            yield (
                error_message,
                "<div class='error'>Error processing request.</div>",
                gr.Button(value="Search", interactive=True), # Re-enable button on error
                history + [[query, f"*Error: {error_message}*"]],
                None
            )


    # Corrected event listeners: Pass the voice_select component directly
    search_btn.click(
        fn=handle_interaction,
        inputs=[search_input, chat_history, voice_select], # Pass voice_select component
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
    )

    search_input.submit(
        fn=handle_interaction,
        inputs=[search_input, chat_history, voice_select], # Pass voice_select component
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
    )
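
    # Both the button click and pressing Enter in the textbox run the same async
    # handler, which streams progressive updates to the five output components.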

if __name__ == "__main__":
    # Launch the app
    demo.queue(max_size=20).launch(debug=True, share=True) # Enable debug for more logs