import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
from duckduckgo_search import DDGS
import torch
from datetime import datetime
import os
import subprocess
import numpy as np
from typing import List, Dict, Tuple, Optional
from functools import lru_cache
import asyncio
import threading
# --- Configuration ---
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEARCH_RESULTS = 5
TTS_SAMPLE_RATE = 24000
MAX_TTS_CHARS = 1000
GPU_DURATION = 30 # for spaces.GPU decorator
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95
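# TEMPERATURE and TOP_P configure nucleus sampling in generate_answer:
# moderate temperature with top-p 0.95 favors natural, varied answers.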
# --- Initialization ---
# Load model and tokenizer, failing fast with a clear message on error
try:
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
device_map="auto",
offload_folder="offload",
low_cpu_mem_usage=True,
torch_dtype=torch.float16
)
print("Model and tokenizer loaded successfully")
except Exception as e:
print(f"Error initializing model: {str(e)}")
raise
# --- TTS Setup ---
VOICE_CHOICES = {
'πŸ‡ΊπŸ‡Έ Female (Default)': 'af',
'πŸ‡ΊπŸ‡Έ Bella': 'af_bella',
'πŸ‡ΊπŸ‡Έ Sarah': 'af_sarah',
'πŸ‡ΊπŸ‡Έ Nicole': 'af_nicole'
}
TTS_ENABLED = False
TTS_MODEL = None
VOICEPACKS = {} # Cache voice packs
# Initialize Kokoro TTS in a separate thread to avoid blocking startup
def setup_tts():
global TTS_ENABLED, TTS_MODEL, VOICEPACKS
try:
# Install dependencies first
subprocess.run(['git', 'lfs', 'install'], check=True)
if not os.path.exists('Kokoro-82M'):
subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
# Install espeak
try:
subprocess.run(['apt-get', 'update'], check=True)
subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
except subprocess.CalledProcessError:
try:
subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
except subprocess.CalledProcessError:
print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
# Set up Kokoro TTS
if os.path.exists('Kokoro-82M'):
import sys
sys.path.append('Kokoro-82M')
from models import build_model
from kokoro import generate
# Make these functions accessible globally
globals()['build_model'] = build_model
globals()['generate_tts'] = generate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
# Preload default voice
default_voice = 'af'
VOICEPACKS[default_voice] = torch.load(f'Kokoro-82M/voices/{default_voice}.pt',
map_location=device,
weights_only=True)
# Preload other common voices to reduce latency
for voice_name in ['af_bella', 'af_sarah', 'af_nicole']:
try:
voice_path = f'Kokoro-82M/voices/{voice_name}.pt'
if os.path.exists(voice_path):
VOICEPACKS[voice_name] = torch.load(voice_path,
map_location=device,
weights_only=True)
except Exception as e:
print(f"Warning: Could not preload voice {voice_name}: {str(e)}")
TTS_ENABLED = True
print("TTS setup completed successfully")
else:
print("Warning: Kokoro-82M directory not found. TTS disabled.")
except Exception as e:
print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
TTS_ENABLED = False
# Start TTS setup in a background daemon thread so model loading and the UI
# start immediately; TTS_ENABLED flips to True only after setup succeeds
threading.Thread(target=setup_tts, daemon=True).start()
# --- Search and Generation Functions ---
@lru_cache(maxsize=128)
def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
"""Get web search results using DuckDuckGo with caching for improved performance"""
try:
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=max_results))
return [{
"title": result.get("title", ""),
"snippet": result.get("body", ""),
"url": result.get("href", ""),
"date": result.get("published", "")
} for result in results]
except Exception as e:
print(f"Error in web search: {e}")
return []
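# Illustrative result shape (keys mapped from DDGS text results):
#   [{"title": "...", "snippet": "...", "url": "https://...", "date": ""}, ...]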
def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
"""Format the prompt with web context"""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
Current Time: {current_time}
Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
Query: {query}
Web Context:
{context_lines}
Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
Answer:"""
def format_sources(web_results: List[Dict[str, str]]) -> str:
"""Format sources with more details"""
if not web_results:
return "<div class='no-sources'>No sources available</div>"
sources_html = "<div class='sources-container'>"
for i, res in enumerate(web_results, 1):
title = res["title"] or "Source"
date = f"<span class='source-date'>{res['date']}</span>" if res.get('date') else ""
snippet = res.get("snippet", "")[:150] + "..." if res.get("snippet") else ""
sources_html += f"""
<div class='source-item'>
<div class='source-number'>[{i}]</div>
<div class='source-content'>
<a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
{date}
<div class='source-snippet'>{snippet}</div>
</div>
</div>
"""
sources_html += "</div>"
return sources_html
@spaces.GPU(duration=GPU_DURATION)
def generate_answer(prompt: str) -> str:
"""Generate answer using the DeepSeek model with optimized settings"""
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
return_attention_mask=True
).to(model.device)
with torch.no_grad(): # Disable gradient calculation for inference
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
pad_token_id=tokenizer.eos_token_id,
            do_sample=True
        )
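    # decode() returns the prompt followed by the completion; the caller
    # strips everything before the trailing "Answer:" marker.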
return tokenizer.decode(outputs[0], skip_special_tokens=True)
@spaces.GPU(duration=GPU_DURATION)
def generate_speech(text: str, voice_name: str = 'af') -> Optional[Tuple[int, np.ndarray]]:
"""Generate speech from text using Kokoro TTS model with improved error handling and caching."""
global VOICEPACKS, TTS_MODEL, TTS_ENABLED
if not TTS_ENABLED or TTS_MODEL is None:
return None
try:
from kokoro import generate as generate_tts
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load voicepack if needed
if voice_name not in VOICEPACKS:
voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
if not os.path.exists(voice_file):
print(f"Voicepack {voice_name}.pt not found. Falling back to default 'af'.")
voice_name = 'af'
# Check if default is already loaded
if voice_name not in VOICEPACKS:
voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
if os.path.exists(voice_file):
VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
else:
print("Default voicepack 'af.pt' not found. Cannot generate audio.")
return None
else:
VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
        # Clean the text for TTS: drop markdown heading lines, then strip
        # citation brackets and emphasis markers so they are not read aloud
clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
        # Split long text into sentence-aligned chunks no longer than
        # MAX_TTS_CHARS, so each TTS call stays within a safe input length
max_chars = MAX_TTS_CHARS
chunks = []
if len(clean_text) > max_chars:
sentences = clean_text.split('.')
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) + 1 < max_chars:
current_chunk += sentence + "."
else:
chunks.append(current_chunk.strip())
current_chunk = sentence + "."
if current_chunk:
chunks.append(current_chunk.strip())
else:
chunks = [clean_text]
# Generate audio for each chunk
audio_chunks = []
for chunk in chunks:
if chunk.strip():
chunk_audio, _ = generate_tts(TTS_MODEL, chunk, VOICEPACKS[voice_name], lang='a')
if isinstance(chunk_audio, torch.Tensor):
chunk_audio = chunk_audio.cpu().numpy()
audio_chunks.append(chunk_audio)
# Concatenate chunks
if audio_chunks:
final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
return (TTS_SAMPLE_RATE, final_audio)
return None
except Exception as e:
print(f"Error generating speech: {str(e)}")
return None
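# Illustrative usage: generate_speech("Hello world", "af_bella") returns
# (TTS_SAMPLE_RATE, np.ndarray) on success, or None if TTS is unavailable.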
# --- Asynchronous Processing ---
async def async_web_search(query: str) -> List[Dict[str, str]]:
"""Run web search in a non-blocking way"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, get_web_results, query)
async def async_answer_generation(prompt: str) -> str:
"""Run answer generation in a non-blocking way"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, generate_answer, prompt)
async def async_speech_generation(text: str, voice_name: str) -> Optional[Tuple[int, np.ndarray]]:
"""Run speech generation in a non-blocking way"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, generate_speech, text, voice_name)
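# Note: these async wrappers are not currently wired into the Gradio handlers;
# process_query below calls the synchronous functions directly. They are kept
# for a future fully non-blocking pipeline.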
def process_query(query: str, history: List[List[str]], selected_voice: str = 'af'):
"""Process user query with streaming effect and non-blocking operations"""
try:
if history is None:
history = []
# Start the search task
current_history = history + [[query, "*Searching...*"]]
# Yield initial searching state
yield (
"*Searching & Thinking...*", # answer_output (Markdown)
"<div class='searching'>Searching for results...</div>", # sources_output (HTML)
"Searching...", # search_btn (Button)
current_history, # chat_history_display (Chatbot)
None # audio_output (Audio)
)
# Get web results
web_results = get_web_results(query)
sources_html = format_sources(web_results)
# Update with the search results obtained
yield (
"*Analyzing search results...*", # answer_output
sources_html, # sources_output
"Generating answer...", # search_btn
current_history, # chat_history_display
None # audio_output
)
# Generate answer
prompt = format_prompt(query, web_results)
answer = generate_answer(prompt)
final_answer = answer.split("Answer:")[-1].strip()
# Update history before TTS
updated_history = history + [[query, final_answer]]
# Update with the answer before generating speech
yield (
final_answer, # answer_output
sources_html, # sources_output
"Generating audio...", # search_btn
updated_history, # chat_history_display
None # audio_output
)
# Generate speech (but don't block if TTS is still initializing)
audio = None
if TTS_ENABLED and TTS_MODEL is not None:
try:
audio = generate_speech(final_answer, selected_voice)
if audio is None:
final_answer += "\n\n*Audio generation failed. The voicepack may be missing or incompatible.*"
except Exception as e:
final_answer += f"\n\n*Error generating audio: {str(e)}*"
else:
final_answer += "\n\n*TTS is still initializing or is disabled. Try again in a moment.*"
# Yield final result
yield (
final_answer, # answer_output
sources_html, # sources_output
"Search", # search_btn
updated_history, # chat_history_display
audio # audio_output
)
except Exception as e:
error_message = str(e)
if "GPU quota" in error_message:
error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
yield (
f"Error: {error_message}", # answer_output
"<div class='error'>An error occurred during search</div>", # sources_output
"Search", # search_btn
history + [[query, f"*Error: {error_message}*"]], # chat_history_display
None # audio_output
)
# --- Improved UI ---
css = """
.gradio-container {
max-width: 1200px !important;
background-color: #f7f7f8 !important;
}
#header {
text-align: center;
margin-bottom: 2rem;
padding: 2rem 0;
background: linear-gradient(135deg, #1a1b1e, #2d2e32);
border-radius: 12px;
color: white;
box-shadow: 0 8px 32px rgba(0,0,0,0.2);
}
#header h1 {
color: white;
font-size: 2.5rem;
margin-bottom: 0.5rem;
text-shadow: 0 2px 4px rgba(0,0,0,0.3);
}
#header h3 {
color: #a8a9ab;
}
.search-container {
background: linear-gradient(135deg, #1a1b1e, #2d2e32);
border-radius: 12px;
box-shadow: 0 4px 16px rgba(0,0,0,0.15);
padding: 1.5rem;
margin-bottom: 1.5rem;
}
.search-box {
padding: 1rem;
background: #2c2d30;
border-radius: 10px;
margin-bottom: 1rem;
box-shadow: inset 0 2px 4px rgba(0,0,0,0.1);
}
.search-box input[type="text"] {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: white !important;
border-radius: 8px !important;
transition: all 0.3s ease;
}
.search-box input[type="text"]:focus {
border-color: #60a5fa !important;
box-shadow: 0 0 0 2px rgba(96, 165, 250, 0.3) !important;
}
.search-box input[type="text"]::placeholder {
color: #a8a9ab !important;
}
.search-box button {
background: #2563eb !important;
border: none !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
transition: all 0.3s ease !important;
}
.search-box button:hover {
background: #1d4ed8 !important;
transform: translateY(-1px) !important;
}
.search-box button:active {
transform: translateY(1px) !important;
}
.results-container {
background: #2c2d30;
border-radius: 10px;
padding: 1.5rem;
margin-top: 1.5rem;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}
.answer-box {
background: #3a3b3e;
border-radius: 10px;
padding: 1.5rem;
color: white;
margin-bottom: 1.5rem;
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
transition: all 0.3s ease;
}
.answer-box:hover {
box-shadow: 0 4px 16px rgba(0,0,0,0.2);
}
.answer-box p {
color: #e5e7eb;
line-height: 1.7;
}
.answer-box code {
background: #2c2d30;
border-radius: 4px;
padding: 2px 4px;
}
.sources-container {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
}
.source-item {
display: flex;
padding: 12px;
margin: 12px 0;
background: #3a3b3e;
border-radius: 8px;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.source-item:hover {
background: #4a4b4e;
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.15);
}
.source-number {
font-weight: bold;
margin-right: 12px;
color: #60a5fa;
}
.source-content {
flex: 1;
}
.source-title {
color: #60a5fa;
font-weight: 500;
text-decoration: none;
display: block;
margin-bottom: 6px;
transition: all 0.2s;
}
.source-title:hover {
color: #93c5fd;
text-decoration: underline;
}
.source-date {
color: #a8a9ab;
font-size: 0.9em;
margin-left: 8px;
}
.source-snippet {
color: #e5e7eb;
font-size: 0.9em;
line-height: 1.5;
}
.chat-history {
max-height: 400px;
overflow-y: auto;
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-top: 1rem;
scrollbar-width: thin;
scrollbar-color: #4a4b4e #2c2d30;
}
.chat-history::-webkit-scrollbar {
width: 8px;
}
.chat-history::-webkit-scrollbar-track {
background: #2c2d30;
}
.chat-history::-webkit-scrollbar-thumb {
background-color: #4a4b4e;
border-radius: 20px;
}
.examples-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.examples-container button {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: #e5e7eb !important;
transition: all 0.2s;
margin: 4px !important;
}
.examples-container button:hover {
background: #4a4b4e !important;
transform: translateY(-1px);
}
.markdown-content {
color: #e5e7eb !important;
}
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
color: white !important;
margin-top: 1.2em !important;
margin-bottom: 0.8em !important;
}
.markdown-content h1 {
font-size: 1.7em !important;
}
.markdown-content h2 {
font-size: 1.5em !important;
}
.markdown-content h3 {
font-size: 1.3em !important;
}
.markdown-content a {
color: #60a5fa !important;
text-decoration: none !important;
transition: all 0.2s;
}
.markdown-content a:hover {
color: #93c5fd !important;
text-decoration: underline !important;
}
.markdown-content code {
background: #2c2d30 !important;
padding: 2px 6px !important;
border-radius: 4px !important;
font-family: monospace !important;
}
.markdown-content pre {
background: #2c2d30 !important;
padding: 12px !important;
border-radius: 8px !important;
overflow-x: auto !important;
}
.markdown-content blockquote {
border-left: 4px solid #60a5fa !important;
padding-left: 1em !important;
margin-left: 0 !important;
color: #a8a9ab !important;
}
.markdown-content table {
border-collapse: collapse !important;
width: 100% !important;
}
.markdown-content th, .markdown-content td {
padding: 8px 12px !important;
border: 1px solid #4a4b4e !important;
}
.markdown-content th {
background: #2c2d30 !important;
}
.accordion {
background: #2c2d30 !important;
border-radius: 8px !important;
margin-top: 1rem !important;
box-shadow: 0 2px 8px rgba(0,0,0,0.1) !important;
}
.voice-selector {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 0.5rem;
}
.voice-selector select {
background: #3a3b3e !important;
color: white !important;
border: 1px solid #4a4b4e !important;
border-radius: 4px !important;
padding: 8px !important;
transition: all 0.2s;
}
.voice-selector select:focus {
border-color: #60a5fa !important;
}
.audio-player {
margin-top: 1rem;
background: #2c2d30 !important;
border-radius: 8px !important;
padding: 0.5rem !important;
}
.audio-player audio {
width: 100% !important;
}
.searching, .error {
padding: 1rem;
border-radius: 8px;
text-align: center;
margin: 1rem 0;
}
.searching {
background: rgba(96, 165, 250, 0.1);
color: #60a5fa;
}
.error {
background: rgba(239, 68, 68, 0.1);
color: #ef4444;
}
.no-sources {
padding: 1rem;
text-align: center;
color: #a8a9ab;
background: #2c2d30;
border-radius: 8px;
}
@keyframes pulse {
0% { opacity: 0.6; }
50% { opacity: 1; }
100% { opacity: 0.6; }
}
.searching {
animation: pulse 1.5s infinite;
}
"""
# --- Gradio Interface ---
with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
with gr.Column(elem_id="header"):
gr.Markdown("# πŸ” AI Search Assistant")
gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
with gr.Column(elem_classes="search-container"):
with gr.Row(elem_classes="search-box"):
search_input = gr.Textbox(
label="",
placeholder="Ask anything...",
scale=5,
container=False
)
voice_select = gr.Dropdown(
choices=list(VOICE_CHOICES.keys()),
value=list(VOICE_CHOICES.keys())[0],
label="Voice",
elem_classes="voice-selector",
scale=1
)
search_btn = gr.Button("Search", variant="primary", scale=1)
with gr.Row(elem_classes="results-container"):
with gr.Column(scale=2):
with gr.Column(elem_classes="answer-box"):
answer_output = gr.Markdown(elem_classes="markdown-content")
with gr.Row():
audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
chat_history_display = gr.Chatbot(elem_classes="chat-history")
with gr.Column(scale=1):
with gr.Column(elem_classes="sources-box"):
gr.Markdown("### Sources")
sources_output = gr.HTML()
with gr.Row(elem_classes="examples-container"):
gr.Examples(
examples=[
"Latest news about artificial intelligence advances",
"How does blockchain technology work?",
"What are the best practices for sustainable living?",
"Compare electric vehicles and traditional cars"
],
inputs=search_input,
label="Try these examples"
)
    # Map the voice dropdown's display name to its Kokoro voice id
    def get_voice_id(voice_name):
        return VOICE_CHOICES.get(voice_name, 'af')

    # Gradio inputs must be components, not callables, so the voice mapping
    # happens inside this wrapper. The Chatbot component also serves as the
    # history source: its value is refreshed by every yield of process_query.
    def handle_query(query, history, voice_display_name):
        yield from process_query(query, history, get_voice_id(voice_display_name))

    # Handle interactions
    search_btn.click(
        fn=handle_query,
        inputs=[search_input, chat_history_display, voice_select],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
    )

    # Also trigger search on Enter key
    search_input.submit(
        fn=handle_query,
        inputs=[search_input, chat_history_display, voice_select],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
    )
if __name__ == "__main__":
    # Start the app with optimized queue settings. Gradio 4.x renamed the
    # queue's `concurrency_count` argument; on Gradio 3.x use
    # demo.queue(concurrency_count=5, max_size=20) instead.
    demo.queue(default_concurrency_limit=5, max_size=20).launch(share=True)