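"""AI Search Assistant: a Gradio app that answers user queries with the
DeepSeek-R1-Distill-Qwen-1.5B model, grounded in real-time DuckDuckGo results."""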
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import spaces  # Hugging Face Spaces helper; unused in this CPU-only app but harmless
from duckduckgo_search import DDGS
import torch
import gc  # for manual garbage collection between generations

# Initialize model and tokenizer with optimizations
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Load config first to set optimal parameters
config = AutoConfig.from_pretrained(model_name)
config.use_cache = True  # Enable KV-caching for faster inference

# Initialize tokenizer with optimizations
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=256,  # Reduced for faster processing
    padding_side="left",
    truncation_side="left",
)
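# Causal-LM tokenizers often ship without a dedicated pad token; reuse EOS so padding works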
tokenizer.pad_token = tokenizer.eos_token

# Load model with optimizations
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    device_map="cpu",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# Enable model optimizations
model.eval()  # Set to evaluation mode
torch.set_num_threads(4)  # Limit CPU threads for better performance

def get_web_results(query, max_results=3):
    """Get web search results using DuckDuckGo."""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
        # Each result dict carries "title", "body", "href" and sometimes "published"
        return [{
            "title": result.get("title", ""),
            "snippet": result.get("body", "")[:200],  # Limit snippet length
            "url": result.get("href", ""),
            "date": result.get("published", ""),
        } for result in results]
    except Exception:
        # Fail soft: the app can still answer without web context
        return []

def format_prompt(query, context):
    """Format the prompt with numbered web snippets as context."""
    context_lines = '\n'.join(
        f'[{i + 1}] {res["snippet"]}' for i, res in enumerate(context)
    )
    return f"Answer this query using the context: {query}\n\nContext:\n{context_lines}\n\nAnswer:"

def format_sources(web_results):
    """Render web results as an HTML sources panel."""
    if not web_results:
        return "<div class='no-sources'>No sources available</div>"
    sources_html = "<div class='sources-container'>"
    for i, res in enumerate(web_results, 1):
        title = res["title"] or "Source"
        date = f"<span class='source-date'>{res['date']}</span>" if res['date'] else ""
        sources_html += f"""
        <div class='source-item'>
            <div class='source-number'>[{i}]</div>
            <div class='source-content'>
                <a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
                {date}
                <div class='source-snippet'>{res['snippet'][:150]}...</div>
            </div>
        </div>
        """
    sources_html += "</div>"
    return sources_html

def generate_answer(prompt):
    """Generate an answer with the DeepSeek model."""
    try:
        # Free memory before generating (the CUDA branch is a no-op on CPU)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            return_attention_mask=True,
        )
        with torch.no_grad():  # Disable gradient tracking for inference
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=100,  # Kept small for CPU speed
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                num_beams=1,  # Sampling path; beam-only options such as early_stopping do not apply
                no_repeat_ngram_size=3,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the "Answer:" marker from the prompt template
        return response.split('Answer:')[-1].strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
def process_query(query, history):
    """Process a user query, streaming status updates to the UI."""
    try:
        if history is None:
            history = []
        # Get web results first
        web_results = get_web_results(query)
        sources_html = format_sources(web_results)
        # Intermediate yield: show a searching status while the model runs
        yield {
            answer_output: gr.Markdown("*Searching and generating response...*"),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Please wait...", interactive=False),
            chat_history_display: history + [[query, "*Processing...*"]],
        }
        # Generate the answer from the web-augmented prompt
        prompt = format_prompt(query, web_results)
        answer = generate_answer(prompt)
        # Final yield: answer plus updated history, written back to state so it persists
        final_history = history + [[query, answer]]
        yield {
            answer_output: gr.Markdown(answer),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: final_history,
            chat_history: final_history,
        }
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        yield {
            answer_output: gr.Markdown(error_msg),
            sources_output: gr.HTML("<div>Error fetching sources</div>"),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: history + [[query, error_msg]],
            chat_history: history + [[query, error_msg]],
        }

# Dark-theme CSS for contrast and readability
css = """
.gradio-container {
max-width: 1200px !important;
background-color: #f7f7f8 !important;
}
#header {
text-align: center;
margin-bottom: 2rem;
padding: 2rem 0;
background: #1a1b1e;
border-radius: 12px;
color: white;
}
#header h1 {
color: white;
font-size: 2.5rem;
margin-bottom: 0.5rem;
}
#header h3 {
color: #a8a9ab;
}
.search-container {
background: #1a1b1e;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
padding: 1rem;
margin-bottom: 1rem;
}
.search-box {
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-bottom: 1rem;
}
/* Style the input textbox */
.search-box input[type="text"] {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: white !important;
border-radius: 8px !important;
}
.search-box input[type="text"]::placeholder {
color: #a8a9ab !important;
}
/* Style the search button */
.search-box button {
background: #2563eb !important;
border: none !important;
}
/* Results area styling */
.results-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.answer-box {
background: #3a3b3e;
border-radius: 8px;
padding: 1.5rem;
color: white;
margin-bottom: 1rem;
}
.answer-box p {
color: #e5e7eb;
line-height: 1.6;
}
.sources-container {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
}
.source-item {
display: flex;
padding: 12px;
margin: 8px 0;
background: #3a3b3e;
border-radius: 8px;
transition: all 0.2s;
}
.source-item:hover {
background: #4a4b4e;
}
.source-number {
font-weight: bold;
margin-right: 12px;
color: #60a5fa;
}
.source-content {
flex: 1;
}
.source-title {
color: #60a5fa;
font-weight: 500;
text-decoration: none;
display: block;
margin-bottom: 4px;
}
.source-date {
color: #a8a9ab;
font-size: 0.9em;
margin-left: 8px;
}
.source-snippet {
color: #e5e7eb;
font-size: 0.9em;
line-height: 1.4;
}
.chat-history {
max-height: 400px;
overflow-y: auto;
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-top: 1rem;
}
.examples-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.examples-container button {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: #e5e7eb !important;
}
/* Markdown content styling */
.markdown-content {
color: #e5e7eb !important;
}
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
color: white !important;
}
.markdown-content a {
color: #60a5fa !important;
}
/* Accordion styling */
.accordion {
background: #2c2d30 !important;
border-radius: 8px !important;
margin-top: 1rem !important;
}
"""

# Gradio interface layout
with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
    chat_history = gr.State([])

    with gr.Column(elem_id="header"):
        gr.Markdown("# 🔍 AI Search Assistant")
        gr.Markdown("### Powered by DeepSeek & Real-time Web Results")

    with gr.Column(elem_classes="search-container"):
        with gr.Row(elem_classes="search-box"):
            search_input = gr.Textbox(
                label="",
                placeholder="Ask anything...",
                scale=5,
                container=False,
            )
            search_btn = gr.Button("Search", variant="primary", scale=1)

        with gr.Row(elem_classes="results-container"):
            with gr.Column(scale=2):
                with gr.Column(elem_classes="answer-box"):
                    answer_output = gr.Markdown(elem_classes="markdown-content")
                with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                    chat_history_display = gr.Chatbot(elem_classes="chat-history")
            with gr.Column(scale=1):
                with gr.Column(elem_classes="sources-box"):
                    gr.Markdown("### Sources")
                    sources_output = gr.HTML()

        with gr.Row(elem_classes="examples-container"):
            gr.Examples(
                examples=[
                    "What are the latest developments in quantum computing?",
                    "Explain the impact of AI on healthcare",
                    "What are the best practices for sustainable living?",
                    "How is climate change affecting ocean ecosystems?",
                ],
                inputs=search_input,
                label="Try these examples",
            )
    # Handle interactions; chat_history is listed in outputs so state updates persist
    search_btn.click(
        fn=process_query,
        inputs=[search_input, chat_history],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, chat_history],
    )
    # Also trigger search on the Enter key
    search_input.submit(
        fn=process_query,
        inputs=[search_input, chat_history],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, chat_history],
    )

if __name__ == "__main__":
    demo.queue()  # generator callbacks need the queue; default in recent Gradio, explicit for older versions
    demo.launch(share=True)