Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,11 +33,12 @@ except Exception as e:
|
|
| 33 |
print(f"Warning: Initial setup error: {str(e)}")
|
| 34 |
print("Continuing with limited functionality...")
|
| 35 |
|
|
|
|
| 36 |
# --- Initialization (Do this ONCE) ---
|
|
|
|
| 37 |
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
|
| 38 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 39 |
tokenizer.pad_token = tokenizer.eos_token
|
| 40 |
-
|
| 41 |
# Initialize DeepSeek model
|
| 42 |
model = AutoModelForCausalLM.from_pretrained(
|
| 43 |
model_name,
|
|
@@ -78,12 +79,12 @@ try:
|
|
| 78 |
TTS_ENABLED = True
|
| 79 |
else:
|
| 80 |
print("Warning: Kokoro-82M directory not found. TTS disabled.")
|
| 81 |
-
|
| 82 |
except Exception as e:
|
| 83 |
print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
|
| 84 |
TTS_ENABLED = False
|
| 85 |
|
| 86 |
|
|
|
|
| 87 |
def get_web_results(query: str, max_results: int = 5) -> List[Dict[str, str]]:
|
| 88 |
"""Get web search results using DuckDuckGo"""
|
| 89 |
try:
|
|
@@ -99,19 +100,27 @@ def get_web_results(query: str, max_results: int = 5) -> List[Dict[str, str]]:
|
|
| 99 |
print(f"Error in web search: {e}")
|
| 100 |
return []
|
| 101 |
|
|
|
|
| 102 |
def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
    """Build the LLM prompt for *query* from web search results.

    Each entry of *context* is rendered as a numbered line
    (``[1] Title: snippet``) so that the ``[1], [2], ...`` citations the
    prompt asks for can be matched to a specific result, in the order given.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        context: Search results; the ``"title"`` and ``"snippet"`` keys are
            read. A missing key falls back to a placeholder instead of
            raising ``KeyError``.

    Returns:
        The full prompt text, ending in ``Answer:`` for the model to continue.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if context:
        # Number the results: the instructions below ask for [n]-style
        # citations, which are meaningless unless each source carries an index.
        context_lines = '\n'.join(
            f'[{idx}] {res.get("title", "Untitled")}: {res.get("snippet", "")}'
            for idx, res in enumerate(context, start=1)
        )
    else:
        # Say so explicitly rather than leaving the section blank, since the
        # model is told to rely only on the provided context.
        context_lines = "(no web results were found)"
    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
Current Time: {current_time}
Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
Query: {query}
Web Context:
{context_lines}
Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
Answer:"""
|
| 114 |
|
|
|
|
| 115 |
def format_sources(web_results: List[Dict[str, str]]) -> str:
|
| 116 |
"""Format sources with more details"""
|
| 117 |
if not web_results:
|
|
@@ -134,6 +143,7 @@ def format_sources(web_results: List[Dict[str, str]]) -> str:
|
|
| 134 |
sources_html += "</div>"
|
| 135 |
return sources_html
|
| 136 |
|
|
|
|
| 137 |
@spaces.GPU(duration=30)
|
| 138 |
def generate_answer(prompt: str) -> str:
|
| 139 |
"""Generate answer using the DeepSeek model"""
|
|
@@ -158,10 +168,11 @@ def generate_answer(prompt: str) -> str:
|
|
| 158 |
)
|
| 159 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 160 |
|
|
|
|
|
|
|
| 161 |
@spaces.GPU(duration=30)
|
| 162 |
def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model = TTS_MODEL, voicepack = VOICEPACK) -> Tuple[int, np.ndarray] | None:
|
| 163 |
"""Generate speech from text using Kokoro TTS model."""
|
| 164 |
-
|
| 165 |
if not TTS_ENABLED or tts_model is None:
|
| 166 |
print("TTS is not enabled or model is not loaded.")
|
| 167 |
return None
|
|
@@ -172,7 +183,6 @@ def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model = TTS_
|
|
| 172 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 173 |
voicepack = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', map_location=device, weights_only=True)
|
| 174 |
|
| 175 |
-
|
| 176 |
# Clean the text
|
| 177 |
clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
|
| 178 |
clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
|
|
@@ -211,12 +221,14 @@ def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model = TTS_
|
|
| 211 |
else:
|
| 212 |
return None
|
| 213 |
|
| 214 |
-
|
| 215 |
except Exception as e:
|
| 216 |
print(f"Error generating speech: {str(e)}")
|
| 217 |
import traceback
|
| 218 |
traceback.print_exc()
|
| 219 |
return None
|
|
|
|
|
|
|
|
|
|
| 220 |
def process_query(query: str, history: List[List[str]], selected_voice: str = 'af') -> Dict[str, Any]:
|
| 221 |
"""Process user query with streaming effect"""
|
| 222 |
try:
|
|
@@ -228,6 +240,7 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
|
|
| 228 |
sources_html = format_sources(web_results)
|
| 229 |
|
| 230 |
current_history = history + [[query, "*Searching...*"]]
|
|
|
|
| 231 |
yield {
|
| 232 |
answer_output: gr.Markdown("*Searching & Thinking...*"),
|
| 233 |
sources_output: gr.HTML(sources_html),
|
|
@@ -244,6 +257,7 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
|
|
| 244 |
# Update history *before* TTS (important for correct display)
|
| 245 |
updated_history = history + [[query, final_answer]]
|
| 246 |
|
|
|
|
| 247 |
# Generate speech from the answer (only if enabled)
|
| 248 |
if TTS_ENABLED:
|
| 249 |
yield { # Intermediate update before TTS
|
|
@@ -261,6 +275,8 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
|
|
| 261 |
else:
|
| 262 |
audio = None
|
| 263 |
|
|
|
|
|
|
|
| 264 |
yield {
|
| 265 |
answer_output: gr.Markdown(final_answer),
|
| 266 |
sources_output: gr.HTML(sources_html),
|
|
@@ -273,7 +289,6 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
|
|
| 273 |
error_message = str(e)
|
| 274 |
if "GPU quota" in error_message:
|
| 275 |
error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
|
| 276 |
-
|
| 277 |
yield {
|
| 278 |
answer_output: gr.Markdown(f"Error: {error_message}"),
|
| 279 |
sources_output: gr.HTML(sources_html), #Still show sources on error
|
|
|
|
| 33 |
print(f"Warning: Initial setup error: {str(e)}")
|
| 34 |
print("Continuing with limited functionality...")
|
| 35 |
|
| 36 |
+
|
| 37 |
# --- Initialization (Do this ONCE) ---
|
| 38 |
+
|
| 39 |
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
|
| 40 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 41 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
| 42 |
# Initialize DeepSeek model
|
| 43 |
model = AutoModelForCausalLM.from_pretrained(
|
| 44 |
model_name,
|
|
|
|
| 79 |
TTS_ENABLED = True
|
| 80 |
else:
|
| 81 |
print("Warning: Kokoro-82M directory not found. TTS disabled.")
|
|
|
|
| 82 |
except Exception as e:
|
| 83 |
print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
|
| 84 |
TTS_ENABLED = False
|
| 85 |
|
| 86 |
|
| 87 |
+
|
| 88 |
def get_web_results(query: str, max_results: int = 5) -> List[Dict[str, str]]:
|
| 89 |
"""Get web search results using DuckDuckGo"""
|
| 90 |
try:
|
|
|
|
| 100 |
print(f"Error in web search: {e}")
|
| 101 |
return []
|
| 102 |
|
| 103 |
+
|
| 104 |
def format_prompt(query: str, context: List[Dict[str, str]]) -> str:
    """Build the LLM prompt for *query* from web search results.

    Each entry of *context* is rendered as a numbered line
    (``[1] Title: snippet``) so that the ``[1], [2], ...`` citations the
    prompt asks for can be matched to a specific result, in the order given.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        context: Search results; the ``"title"`` and ``"snippet"`` keys are
            read. A missing key falls back to a placeholder instead of
            raising ``KeyError``.

    Returns:
        The full prompt text, ending in ``Answer:`` for the model to continue.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if context:
        # Number the results: the instructions below ask for [n]-style
        # citations, which are meaningless unless each source carries an index.
        context_lines = '\n'.join(
            f'[{idx}] {res.get("title", "Untitled")}: {res.get("snippet", "")}'
            for idx, res in enumerate(context, start=1)
        )
    else:
        # Say so explicitly rather than leaving the section blank, since the
        # model is told to rely only on the provided context.
        context_lines = "(no web results were found)"
    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.

Current Time: {current_time}

Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.

Query: {query}

Web Context:
{context_lines}

Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.

Answer:"""
|
| 122 |
|
| 123 |
+
|
| 124 |
def format_sources(web_results: List[Dict[str, str]]) -> str:
|
| 125 |
"""Format sources with more details"""
|
| 126 |
if not web_results:
|
|
|
|
| 143 |
sources_html += "</div>"
|
| 144 |
return sources_html
|
| 145 |
|
| 146 |
+
|
| 147 |
@spaces.GPU(duration=30)
|
| 148 |
def generate_answer(prompt: str) -> str:
|
| 149 |
"""Generate answer using the DeepSeek model"""
|
|
|
|
| 168 |
)
|
| 169 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 170 |
|
| 171 |
+
|
| 172 |
+
|
| 173 |
@spaces.GPU(duration=30)
|
| 174 |
def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model = TTS_MODEL, voicepack = VOICEPACK) -> Tuple[int, np.ndarray] | None:
|
| 175 |
"""Generate speech from text using Kokoro TTS model."""
|
|
|
|
| 176 |
if not TTS_ENABLED or tts_model is None:
|
| 177 |
print("TTS is not enabled or model is not loaded.")
|
| 178 |
return None
|
|
|
|
| 183 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 184 |
voicepack = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', map_location=device, weights_only=True)
|
| 185 |
|
|
|
|
| 186 |
# Clean the text
|
| 187 |
clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
|
| 188 |
clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
|
|
|
|
| 221 |
else:
|
| 222 |
return None
|
| 223 |
|
|
|
|
| 224 |
except Exception as e:
|
| 225 |
print(f"Error generating speech: {str(e)}")
|
| 226 |
import traceback
|
| 227 |
traceback.print_exc()
|
| 228 |
return None
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
def process_query(query: str, history: List[List[str]], selected_voice: str = 'af') -> Dict[str, Any]:
|
| 233 |
"""Process user query with streaming effect"""
|
| 234 |
try:
|
|
|
|
| 240 |
sources_html = format_sources(web_results)
|
| 241 |
|
| 242 |
current_history = history + [[query, "*Searching...*"]]
|
| 243 |
+
|
| 244 |
yield {
|
| 245 |
answer_output: gr.Markdown("*Searching & Thinking...*"),
|
| 246 |
sources_output: gr.HTML(sources_html),
|
|
|
|
| 257 |
# Update history *before* TTS (important for correct display)
|
| 258 |
updated_history = history + [[query, final_answer]]
|
| 259 |
|
| 260 |
+
|
| 261 |
# Generate speech from the answer (only if enabled)
|
| 262 |
if TTS_ENABLED:
|
| 263 |
yield { # Intermediate update before TTS
|
|
|
|
| 275 |
else:
|
| 276 |
audio = None
|
| 277 |
|
| 278 |
+
|
| 279 |
+
|
| 280 |
yield {
|
| 281 |
answer_output: gr.Markdown(final_answer),
|
| 282 |
sources_output: gr.HTML(sources_html),
|
|
|
|
| 289 |
error_message = str(e)
|
| 290 |
if "GPU quota" in error_message:
|
| 291 |
error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
|
|
|
|
| 292 |
yield {
|
| 293 |
answer_output: gr.Markdown(f"Error: {error_message}"),
|
| 294 |
sources_output: gr.HTML(sources_html), #Still show sources on error
|