"""AI search assistant: DuckDuckGo web results + DeepSeek answer + optional Kokoro TTS, served via Gradio."""

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
from duckduckgo_search import DDGS
import time
import torch
from datetime import datetime
import os
import re
import html
import subprocess
import numpy as np
from typing import List, Dict, Tuple, Any
from functools import lru_cache
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

# --- Configuration ---
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEARCH_RESULTS = 5
TTS_SAMPLE_RATE = 24000
MAX_TTS_CHARS = 1000
GPU_DURATION = 30  # for spaces.GPU decorator
MAX_NEW_TOKENS = 256
# Prompt budget; the prompt ends with the query and "Answer:", so it must not be
# truncated too aggressively (the previous 512-token cap cut those off).
MAX_INPUT_TOKENS = 2048
TEMPERATURE = 0.7
TOP_P = 0.95
KOKORO_REPO_URL = "https://huggingface.co/hexgrad/Kokoro-82M"

# --- Initialization ---
# Initialize model and tokenizer with better error handling
try:
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        offload_folder="offload",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    print("Model and tokenizer loaded successfully")
except Exception as e:
    print(f"Error initializing model: {str(e)}")
    raise

# --- TTS Setup ---
# Display label -> Kokoro voicepack id (file Kokoro-82M/voices/<id>.pt).
VOICE_CHOICES = {
    '🇺🇸 Female (Default)': 'af',
    '🇺🇸 Bella': 'af_bella',
    '🇺🇸 Sarah': 'af_sarah',
    '🇺🇸 Nicole': 'af_nicole',
}
TTS_ENABLED = False
TTS_MODEL = None
VOICEPACKS = {}  # Cache voice packs


def setup_tts():
    """Clone Kokoro-82M, install espeak, build the TTS model and preload voicepacks.

    Runs in a background thread so app startup is not blocked. On any failure
    TTS stays disabled (TTS_ENABLED = False) and the app works without audio.
    """
    global TTS_ENABLED, TTS_MODEL, VOICEPACKS
    try:
        # Install dependencies first
        subprocess.run(['git', 'lfs', 'install'], check=True)
        if not os.path.exists('Kokoro-82M'):
            subprocess.run(['git', 'clone', KOKORO_REPO_URL], check=True)

        # Install espeak (phonemizer backend); fall back to espeak-ng.
        try:
            subprocess.run(['apt-get', 'update'], check=True)
            subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
        except subprocess.CalledProcessError:
            try:
                subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
            except subprocess.CalledProcessError:
                print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")

        # Set up Kokoro TTS
        if not os.path.exists('Kokoro-82M'):
            print("Warning: Kokoro-82M directory not found. TTS disabled.")
            return

        import sys
        sys.path.append('Kokoro-82M')  # makes the repo's `models` / `kokoro` modules importable
        from models import build_model

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)

        # Preload default voice
        default_voice = 'af'
        VOICEPACKS[default_voice] = torch.load(
            f'Kokoro-82M/voices/{default_voice}.pt', map_location=device, weights_only=True
        )

        # Preload other common voices to reduce latency
        for voice_name in ['af_bella', 'af_sarah', 'af_nicole']:
            try:
                voice_path = f'Kokoro-82M/voices/{voice_name}.pt'
                if os.path.exists(voice_path):
                    VOICEPACKS[voice_name] = torch.load(voice_path, map_location=device, weights_only=True)
            except Exception as e:
                print(f"Warning: Could not preload voice {voice_name}: {str(e)}")

        TTS_ENABLED = True
        print("TTS setup completed successfully")
    except Exception as e:
        print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
        TTS_ENABLED = False


# Start TTS setup in a separate thread
threading.Thread(target=setup_tts, daemon=True).start()


# --- Search and Generation Functions ---

@lru_cache(maxsize=128)
def _cached_web_search(query: str, max_results: int) -> Tuple[Dict[str, str], ...]:
    """Run a DuckDuckGo text search and memoize successful results.

    Exceptions propagate on purpose: lru_cache does not cache raised calls,
    so a transient failure is retried on the next request instead of being
    remembered as "no results".
    """
    with DDGS() as ddgs:
        results = list(ddgs.text(query, max_results=max_results) or [])
    return tuple(
        {
            "title": result.get("title", ""),
            "snippet": result.get("body", ""),
            "url": result.get("href", ""),
            "date": result.get("published", ""),
        }
        for result in results
    )


def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
    """Get web search results using DuckDuckGo, cached per (query, max_results).

    Returns a fresh list of fresh dicts (callers cannot corrupt the cache);
    returns [] on search failure.
    """
    try:
        return [dict(result) for result in _cached_web_search(query, max_results)]
    except Exception as e:
        print(f"Error in web search: {e}")
        return []


def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
    """Build the LLM prompt from the query and numbered web context.

    Context entries are numbered [1], [2], ... in the same order as the
    Sources panel so the model's citations map to the displayed sources.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if context:
        context_lines = '\n'.join(
            f'[{i}] {res["title"]}: {res["snippet"]}' for i, res in enumerate(context, 1)
        )
    else:
        context_lines = "(no web results were found)"
    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
Current Time: {current_time}

Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.

Query: {query}

Web Context:
{context_lines}

Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.

Answer:"""


def format_sources(web_results: List[Dict[str, str]]) -> str:
    """Render search results as HTML for the Sources panel.

    All fields come from third-party pages, so they are HTML-escaped, and
    only http(s) URLs are emitted as link targets.
    """
    if not web_results:
        return '<div class="no-sources">No sources available</div>'

    items = []
    for i, res in enumerate(web_results, 1):
        title = html.escape(res.get("title") or "Source")
        url = res.get("url") or ""
        href = html.escape(url, quote=True) if url.startswith(("http://", "https://")) else "#"
        date = f'<span class="source-date">{html.escape(str(res["date"]))}</span>' if res.get("date") else ""
        raw_snippet = res.get("snippet") or ""
        snippet = html.escape(raw_snippet[:150] + "...") if raw_snippet else ""
        items.append(f"""
        <div class="source-item">
            <div class="source-number">[{i}]</div>
            <div class="source-content">
                <a href="{href}" target="_blank" rel="noopener noreferrer" class="source-title">{title}</a>
                {date}
                <div class="source-snippet">{snippet}</div>
            </div>
        </div>""")
    return '<div class="sources-container">' + "".join(items) + "</div>"


@spaces.GPU(duration=GPU_DURATION)
def generate_answer(prompt: str) -> str:
    """Generate a completion for *prompt* with the DeepSeek model.

    Returns only the newly generated text (the prompt is not echoed back).
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_attention_mask=True,
    ).to(model.device)

    with torch.no_grad():  # Disable gradient calculation for inference
        # `early_stopping` was dropped: it only applies to beam search, not sampling.
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def _extract_answer(generated: str) -> str:
    """Drop any <think>...</think> reasoning preamble and a leading "Answer:" label."""
    text = generated
    if "</think>" in text:
        # R1-style distills emit their reasoning before a closing </think> tag.
        text = text.split("</think>")[-1]
    text = text.strip()
    if text.startswith("Answer:"):
        text = text[len("Answer:"):]
    return text.strip()


def _load_voicepack(voice_name: str, device: str):
    """Return the cached voicepack for *voice_name*, loading it from disk if needed; None if missing."""
    if voice_name in VOICEPACKS:
        return VOICEPACKS[voice_name]
    voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
    if not os.path.exists(voice_file):
        return None
    VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
    return VOICEPACKS[voice_name]


def _clean_text_for_tts(text: str) -> str:
    """Remove markdown heading lines and the [ ] * characters that TTS would read aloud."""
    kept = ' '.join(line for line in text.split('\n') if not line.startswith('#'))
    return kept.translate(str.maketrans('', '', '[]*'))


_SENTENCE_BREAK = re.compile(r'(?<=[.!?])\s+')


def _split_into_chunks(text: str, max_chars: int) -> List[str]:
    """Split *text* into non-empty chunks of at most *max_chars*, preferring sentence ends.

    A single sentence longer than the limit is hard-wrapped at the last space
    before the limit (or at the limit when there is no space).
    """
    max_chars = max(1, max_chars)
    chunks: List[str] = []
    current = ""
    for sentence in _SENTENCE_BREAK.split(text.strip()):
        while len(sentence) > max_chars:
            cut = sentence.rfind(' ', 0, max_chars)
            if cut <= 0:
                cut = max_chars
            if current:
                chunks.append(current)
                current = ""
            chunks.append(sentence[:cut].strip())
            sentence = sentence[cut:].strip()
        if not sentence:
            continue
        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) <= max_chars:
            current = candidate
        else:
            chunks.append(current)
            current = sentence
    if current:
        chunks.append(current)
    return [chunk for chunk in chunks if chunk.strip()]


@spaces.GPU(duration=GPU_DURATION)
def generate_speech(text: str, voice_name: str = 'af') -> Tuple[int, np.ndarray] | None:
    """Synthesize *text* with Kokoro TTS.

    Falls back to the default 'af' voice when *voice_name* has no voicepack.
    Returns (sample_rate, audio) or None when TTS is unavailable or fails.
    """
    if not TTS_ENABLED or TTS_MODEL is None:
        return None

    try:
        # Importable once setup_tts() has added Kokoro-82M to sys.path.
        from kokoro import generate as generate_tts
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        voicepack = _load_voicepack(voice_name, device)
        if voicepack is None:
            print(f"Voicepack {voice_name}.pt not found. Falling back to default 'af'.")
            voice_name = 'af'
            voicepack = _load_voicepack(voice_name, device)
            if voicepack is None:
                print("Default voicepack 'af.pt' not found. Cannot generate audio.")
                return None

        audio_chunks = []
        for chunk in _split_into_chunks(_clean_text_for_tts(text), MAX_TTS_CHARS):
            chunk_audio, _ = generate_tts(TTS_MODEL, chunk, voicepack, lang='a')
            if isinstance(chunk_audio, torch.Tensor):
                chunk_audio = chunk_audio.cpu().numpy()
            audio_chunks.append(chunk_audio)

        if not audio_chunks:
            return None
        final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
        return (TTS_SAMPLE_RATE, final_audio)
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        return None


# --- Asynchronous Processing ---
# get_running_loop() replaces get_event_loop(), which is deprecated inside coroutines.

async def async_web_search(query: str) -> List[Dict[str, str]]:
    """Run web search in the default executor without blocking the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, get_web_results, query)


async def async_answer_generation(prompt: str) -> str:
    """Run answer generation in the default executor without blocking the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, generate_answer, prompt)


async def async_speech_generation(text: str, voice_name: str) -> Tuple[int, np.ndarray] | None:
    """Run speech generation in the default executor without blocking the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, generate_speech, text, voice_name)


def process_query(query: str, history: List[List[str]], selected_voice: str = 'af'):
    """Search, answer and (optionally) speak *query*, yielding progressive UI states.

    Each yield is (answer_markdown, sources_html, button_label, chat_history, audio).
    *selected_voice* is a Kokoro voicepack id such as 'af'.
    """
    if history is None:
        history = []
    try:
        if not query or not query.strip():
            yield ("*Please enter a question to search.*", "", "Search", history, None)
            return

        current_history = history + [[query, "*Searching...*"]]
        yield (
            "*Searching & Thinking...*",
            '<div class="searching">Searching for results...</div>',
            "Searching...",
            current_history,
            None,
        )

        web_results = get_web_results(query)
        sources_html = format_sources(web_results)
        yield ("*Analyzing search results...*", sources_html, "Generating answer...", current_history, None)

        prompt = format_prompt(query, web_results)
        final_answer = _extract_answer(generate_answer(prompt))

        # History holds the answer text without any TTS status note appended below.
        updated_history = history + [[query, final_answer]]
        yield (final_answer, sources_html, "Generating audio...", updated_history, None)

        # Generate speech (but don't block if TTS is still initializing)
        audio = None
        if TTS_ENABLED and TTS_MODEL is not None:
            try:
                audio = generate_speech(final_answer, selected_voice)
                if audio is None:
                    final_answer += "\n\n*Audio generation failed. The voicepack may be missing or incompatible.*"
            except Exception as e:
                final_answer += f"\n\n*Error generating audio: {str(e)}*"
        else:
            final_answer += "\n\n*TTS is still initializing or is disabled. Try again in a moment.*"

        yield (final_answer, sources_html, "Search", updated_history, audio)

    except Exception as e:
        error_message = str(e)
        if "GPU quota" in error_message:
            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
        yield (
            f"Error: {error_message}",
            '<div class="error">An error occurred during search</div>',
            "Search",
            history + [[query, f"*Error: {error_message}*"]],
            None,
        )


# --- Improved UI ---
css = """
.gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; }
#header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; color: white; box-shadow: 0 8px 32px rgba(0,0,0,0.2); }
#header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; text-shadow: 0 2px 4px rgba(0,0,0,0.3); }
#header h3 { color: #a8a9ab; }
.search-container { background: linear-gradient(135deg, #1a1b1e, #2d2e32); border-radius: 12px; box-shadow: 0 4px 16px rgba(0,0,0,0.15); padding: 1.5rem; margin-bottom: 1.5rem; }
.search-box { padding: 1rem; background: #2c2d30; border-radius: 10px; margin-bottom: 1rem; box-shadow: inset 0 2px 4px rgba(0,0,0,0.1); }
.search-box input[type="text"] { background: #3a3b3e !important; border: 1px solid #4a4b4e !important; color: white !important; border-radius: 8px !important; transition: all 0.3s ease; }
.search-box input[type="text"]:focus { border-color: #60a5fa !important; box-shadow: 0 0 0 2px rgba(96, 165, 250, 0.3) !important; }
.search-box input[type="text"]::placeholder { color: #a8a9ab !important; }
.search-box button { background: #2563eb !important; border: none !important; box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; transition: all 0.3s ease !important; }
.search-box button:hover { background: #1d4ed8 !important; transform: translateY(-1px) !important; }
.search-box button:active { transform: translateY(1px) !important; }
.results-container { background: #2c2d30; border-radius: 10px; padding: 1.5rem; margin-top: 1.5rem; box-shadow: 0 4px 12px rgba(0,0,0,0.1); }
.answer-box { background: #3a3b3e; border-radius: 10px; padding: 1.5rem; color: white; margin-bottom: 1.5rem; box-shadow: 0 2px 8px rgba(0,0,0,0.15); transition: all 0.3s ease; }
.answer-box:hover { box-shadow: 0 4px 16px rgba(0,0,0,0.2); }
.answer-box p { color: #e5e7eb; line-height: 1.7; }
.answer-box code { background: #2c2d30; border-radius: 4px; padding: 2px 4px; }
.sources-container { margin-top: 1rem; background: #2c2d30; border-radius: 8px; padding: 1rem; }
.source-item { display: flex; padding: 12px; margin: 12px 0; background: #3a3b3e; border-radius: 8px; transition: all 0.2s; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
.source-item:hover { background: #4a4b4e; transform: translateY(-2px); box-shadow: 0 4px 8px rgba(0,0,0,0.15); }
.source-number { font-weight: bold; margin-right: 12px; color: #60a5fa; }
.source-content { flex: 1; }
.source-title { color: #60a5fa; font-weight: 500; text-decoration: none; display: block; margin-bottom: 6px; transition: all 0.2s; }
.source-title:hover { color: #93c5fd; text-decoration: underline; }
.source-date { color: #a8a9ab; font-size: 0.9em; margin-left: 8px; }
.source-snippet { color: #e5e7eb; font-size: 0.9em; line-height: 1.5; }
.chat-history { max-height: 400px; overflow-y: auto; padding: 1rem; background: #2c2d30; border-radius: 8px; margin-top: 1rem; scrollbar-width: thin; scrollbar-color: #4a4b4e #2c2d30; }
.chat-history::-webkit-scrollbar { width: 8px; }
.chat-history::-webkit-scrollbar-track { background: #2c2d30; }
.chat-history::-webkit-scrollbar-thumb { background-color: #4a4b4e; border-radius: 20px; }
.examples-container { background: #2c2d30; border-radius: 8px; padding: 1rem; margin-top: 1rem; }
.examples-container button { background: #3a3b3e !important; border: 1px solid #4a4b4e !important; color: #e5e7eb !important; transition: all 0.2s; margin: 4px !important; }
.examples-container button:hover { background: #4a4b4e !important; transform: translateY(-1px); }
.markdown-content { color: #e5e7eb !important; }
.markdown-content h1, .markdown-content h2, .markdown-content h3 { color: white !important; margin-top: 1.2em !important; margin-bottom: 0.8em !important; }
.markdown-content h1 { font-size: 1.7em !important; }
.markdown-content h2 { font-size: 1.5em !important; }
.markdown-content h3 { font-size: 1.3em !important; }
.markdown-content a { color: #60a5fa !important; text-decoration: none !important; transition: all 0.2s; }
.markdown-content a:hover { color: #93c5fd !important; text-decoration: underline !important; }
.markdown-content code { background: #2c2d30 !important; padding: 2px 6px !important; border-radius: 4px !important; font-family: monospace !important; }
.markdown-content pre { background: #2c2d30 !important; padding: 12px !important; border-radius: 8px !important; overflow-x: auto !important; }
.markdown-content blockquote { border-left: 4px solid #60a5fa !important; padding-left: 1em !important; margin-left: 0 !important; color: #a8a9ab !important; }
.markdown-content table { border-collapse: collapse !important; width: 100% !important; }
.markdown-content th, .markdown-content td { padding: 8px 12px !important; border: 1px solid #4a4b4e !important; }
.markdown-content th { background: #2c2d30 !important; }
.accordion { background: #2c2d30 !important; border-radius: 8px !important; margin-top: 1rem !important; box-shadow: 0 2px 8px rgba(0,0,0,0.1) !important; }
.voice-selector { margin-top: 1rem; background: #2c2d30; border-radius: 8px; padding: 0.5rem; }
.voice-selector select { background: #3a3b3e !important; color: white !important; border: 1px solid #4a4b4e !important; border-radius: 4px !important; padding: 8px !important; transition: all 0.2s; }
.voice-selector select:focus { border-color: #60a5fa !important; }
.audio-player { margin-top: 1rem; background: #2c2d30 !important; border-radius: 8px !important; padding: 0.5rem !important; }
.audio-player audio { width: 100% !important; }
.searching, .error { padding: 1rem; border-radius: 8px; text-align: center; margin: 1rem 0; }
.searching { background: rgba(96, 165, 250, 0.1); color: #60a5fa; }
.error { background: rgba(239, 68, 68, 0.1); color: #ef4444; }
.no-sources { padding: 1rem; text-align: center; color: #a8a9ab; background: #2c2d30; border-radius: 8px; }
@keyframes pulse { 0% { opacity: 0.6; } 50% { opacity: 1; } 100% { opacity: 0.6; } }
.searching { animation: pulse 1.5s infinite; }
"""

# --- Gradio Interface ---
# theme="dark" was removed: it is not a built-in Gradio theme name, so Gradio only
# attempted a hub lookup and fell back to the default theme; the CSS supplies the look.
with gr.Blocks(title="AI Search Assistant", css=css) as demo:
    chat_history = gr.State([])

    with gr.Column(elem_id="header"):
        gr.Markdown("# 🔍 AI Search Assistant")
        gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")

    with gr.Column(elem_classes="search-container"):
        with gr.Row(elem_classes="search-box"):
            search_input = gr.Textbox(label="", placeholder="Ask anything...", scale=5, container=False)
            voice_select = gr.Dropdown(
                choices=list(VOICE_CHOICES.keys()),
                value=list(VOICE_CHOICES.keys())[0],
                label="Voice",
                elem_classes="voice-selector",
                scale=1,
            )
            search_btn = gr.Button("Search", variant="primary", scale=1)

        with gr.Row(elem_classes="results-container"):
            with gr.Column(scale=2):
                with gr.Column(elem_classes="answer-box"):
                    answer_output = gr.Markdown(elem_classes="markdown-content")
                    with gr.Row():
                        audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
                with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                    chat_history_display = gr.Chatbot(elem_classes="chat-history")
            with gr.Column(scale=1):
                with gr.Column(elem_classes="sources-box"):
                    gr.Markdown("### Sources")
                    sources_output = gr.HTML()

        with gr.Row(elem_classes="examples-container"):
            gr.Examples(
                examples=[
                    "Latest news about artificial intelligence advances",
                    "How does blockchain technology work?",
                    "What are the best practices for sustainable living?",
                    "Compare electric vehicles and traditional cars",
                ],
                inputs=search_input,
                label="Try these examples",
            )

    # Handle voice selection mapping
    def get_voice_id(voice_name):
        """Map a dropdown label to its Kokoro voicepack id, defaulting to 'af'."""
        return VOICE_CHOICES.get(voice_name, 'af')

    def run_search(query, history, voice_label):
        """Event handler: translate the voice label, stream process_query, and persist history.

        Event inputs must be components, so the label -> id mapping happens here
        rather than as a lambda in the inputs list. The chat history is mirrored
        into the State output so it accumulates across searches.
        """
        for outputs in process_query(query, history, get_voice_id(voice_label)):
            yield (*outputs, outputs[3])

    search_outputs = [
        answer_output,
        sources_output,
        search_btn,
        chat_history_display,
        audio_output,
        chat_history,
    ]

    # Handle interactions
    search_btn.click(
        fn=run_search,
        inputs=[search_input, chat_history, voice_select],
        outputs=search_outputs,
    )

    # Also trigger search on Enter key
    search_input.submit(
        fn=run_search,
        inputs=[search_input, chat_history, voice_select],
        outputs=search_outputs,
    )

if __name__ == "__main__":
    # `concurrency_count` was removed in Gradio 4; `default_concurrency_limit` replaces it.
    demo.queue(default_concurrency_limit=5, max_size=20).launch(share=True)