Spaces:
Running
on
Zero
Running
on
Zero
import gc  # For manual garbage collection
import html
import logging
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from duckduckgo_search import DDGS
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# --- Model / tokenizer initialisation (runs once at import time) ------------
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


def _load_generation_stack(checkpoint):
    """Load the config, tokenizer and causal-LM model for *checkpoint*.

    Everything is loaded for CPU inference in float32. Returns a
    ``(config, tokenizer, model)`` tuple.
    """
    cfg = AutoConfig.from_pretrained(checkpoint)
    cfg.use_cache = True  # keep KV-caching on for faster generation

    tok = AutoTokenizer.from_pretrained(
        checkpoint,
        model_max_length=256,  # short inputs keep CPU latency down
        padding_side="left",
        truncation_side="left",  # over-long prompts lose their *beginning*
    )
    tok.pad_token = tok.eos_token  # reuse EOS as the padding token

    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        config=cfg,
        device_map="cpu",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    )
    lm.eval()  # inference only
    return cfg, tok, lm


config, tokenizer, model = _load_generation_stack(model_name)
torch.set_num_threads(4)  # cap intra-op CPU threads at 4
def get_web_results(query, max_results=3):
    """Return up to *max_results* DuckDuckGo text results for *query*.

    Each hit is normalised to a dict with the keys ``title``, ``snippet``
    (the result body cut to 200 characters), ``url`` and ``date``. Missing
    fields become empty strings, so one incomplete hit no longer discards
    the whole batch.

    The search is best-effort: if it raises, the error is logged and an
    empty list is returned, so the caller can still answer without context.
    """
    try:
        with DDGS() as ddgs:
            raw_results = list(ddgs.text(query, max_results=max_results) or [])
    except Exception as exc:  # network / rate-limit errors must not break the UI
        logging.getLogger(__name__).warning(
            "Web search failed for %r: %s", query, exc
        )
        return []

    return [
        {
            "title": hit.get("title") or "",
            "snippet": (hit.get("body") or "")[:200],  # keep the prompt short
            "url": hit.get("href") or "",
            "date": hit.get("published") or "",
        }
        for hit in raw_results
    ]
def format_prompt(query, context):
    """Build the generation prompt from the user's query and web results.

    Args:
        query: the user's question.
        context: list of result dicts; each contributes its ``snippet`` as a
            numbered ``[i]`` line.

    Returns:
        A prompt string ending in ``"Answer:"``. The question is placed after
        the context, immediately before that marker. The tokenizer in this
        module truncates over-long prompts from the left, so this layout
        drops context rather than the question itself.
    """
    if context:
        context_block = "\n".join(
            f"[{i}] {res.get('snippet', '')}" for i, res in enumerate(context, 1)
        )
    else:
        context_block = "(no web results available)"
    return (
        f"Context:\n{context_block}\n\n"
        f"Answer this query using the context: {query}\n\n"
        "Answer:"
    )
def format_sources(web_results):
    """Render search results as the HTML shown in the "Sources" panel.

    Args:
        web_results: list of result dicts with ``title``, ``snippet``, ``url``
            and ``date`` keys (missing keys are tolerated).

    Returns:
        An HTML string: one numbered ``source-item`` per result wrapped in a
        ``sources-container`` div, or a placeholder div when there are none.

    Every field comes from the web, so it is HTML-escaped before being
    interpolated. Only http(s) links are emitted, which prevents a hostile
    result from injecting markup or ``javascript:`` URLs into the page.
    """
    if not web_results:
        return "<div class='no-sources'>No sources available</div>"
    items = "".join(
        _format_source_item(i, res) for i, res in enumerate(web_results, 1)
    )
    return f"<div class='sources-container'>{items}</div>"


def _format_source_item(index, res):
    """Return the escaped HTML for a single numbered source entry."""
    title = html.escape(res.get("title") or "Source")

    url = (res.get("url") or "").strip()
    # Anything that is not a plain web link (e.g. "javascript:") is neutralised.
    if not url.lower().startswith(("http://", "https://")):
        url = "#"

    date = res.get("date")
    date_html = (
        f"<span class='source-date'>{html.escape(str(date))}</span>" if date else ""
    )

    # Truncate *before* escaping so an entity such as "&amp;" is never cut.
    snippet = html.escape((res.get("snippet") or "")[:150])

    return f"""
<div class='source-item'>
<div class='source-number'>[{index}]</div>
<div class='source-content'>
<a href="{html.escape(url, quote=True)}" target="_blank" rel="noopener noreferrer" class='source-title'>{title}</a>
{date_html}
<div class='source-snippet'>{snippet}...</div>
</div>
</div>
"""
def generate_answer(prompt):
    """Generate an answer to *prompt* with the module-level model.

    The prompt is tokenized (truncated to at most 256 tokens) and up to 100
    new tokens are sampled. Only the newly generated tokens are decoded: the
    echoed prompt is removed by slicing at the prompt length rather than by
    searching for the last "Answer:" marker. As a result, an answer that
    itself contains "Answer:" is no longer cut short.

    Returns:
        The stripped answer text, or ``"Error generating response: ..."`` if
        tokenization or generation raises.
    """
    try:
        # Release cached memory from previous requests before generating.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            return_attention_mask=True,
        )
        prompt_length = inputs.input_ids.shape[-1]

        with torch.no_grad():  # inference only, no autograd bookkeeping
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=100,  # kept small for CPU latency
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                num_beams=1,
                # early_stopping / length_penalty were dropped: they only
                # apply to beam search and just trigger warnings here.
                no_repeat_ngram_size=3,
            )

        # For a causal LM, generate() returns prompt ids followed by new ids.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
def process_query(query, history):
    """Answer *query* using web context, streaming UI updates (generator).

    Each yielded dict is a partial Gradio update keyed by the module-level
    components ``answer_output``, ``sources_output``, ``search_btn`` and
    ``chat_history_display``. The sequence of updates is:

    1. a "searching" status with the button disabled, sent *before* the web
       search so the UI reacts immediately;
    2. the rendered sources, as soon as the search returns;
    3. the final answer, with the button re-enabled.

    Args:
        query: the user's question. An empty or whitespace-only query
            short-circuits with a hint.
        history: list of ``[user, bot]`` pairs shown so far; ``None`` is
            treated as empty.

    Any exception is shown in the answer panel and chat history instead of
    propagating to Gradio.
    """
    history = list(history) if history else []

    # An empty box would otherwise send "" to both the search and the model.
    if not query or not query.strip():
        yield {
            answer_output: gr.Markdown("*Please enter a question to search.*"),
            search_btn: gr.Button("Search", interactive=True),
        }
        return

    try:
        # Lock the button and show progress before the (blocking) web search.
        yield {
            answer_output: gr.Markdown("*Searching and generating response...*"),
            sources_output: gr.HTML("<div class='no-sources'>Searching...</div>"),
            search_btn: gr.Button("Please wait...", interactive=False),
            chat_history_display: history + [[query, "*Processing...*"]],
        }

        web_results = get_web_results(query)
        sources_html = format_sources(web_results)
        # Sources are ready long before generation finishes; show them now.
        yield {sources_output: gr.HTML(sources_html)}

        prompt = format_prompt(query, web_results)
        answer = generate_answer(prompt)

        yield {
            answer_output: gr.Markdown(answer),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: history + [[query, answer]],
        }
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        yield {
            answer_output: gr.Markdown(error_msg),
            sources_output: gr.HTML("<div>Error fetching sources</div>"),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: history + [[query, error_msg]],
        }
# Custom CSS for the dark-styled search UI. The #header / .search-container /
# .search-box / .results-container / .answer-box / .markdown-content /
# .accordion / .chat-history / .examples-container selectors target the
# elem_id / elem_classes given to components in the gr.Blocks layout; the
# .sources-container / .source-* selectors target the HTML that
# format_sources() emits.
# NOTE(review): .gradio-container gets a light background (#f7f7f8) while all
# panels are dark — confirm that contrast is intended.
# NOTE(review): the layout also uses elem_classes="sources-box", which has no
# rule here.
# NOTE(review): recent Gradio Textbox versions may render a <textarea> rather
# than <input type="text">, in which case the .search-box input rules would
# not match — verify against the installed Gradio version.
css = """
.gradio-container {
max-width: 1200px !important;
background-color: #f7f7f8 !important;
}
#header {
text-align: center;
margin-bottom: 2rem;
padding: 2rem 0;
background: #1a1b1e;
border-radius: 12px;
color: white;
}
#header h1 {
color: white;
font-size: 2.5rem;
margin-bottom: 0.5rem;
}
#header h3 {
color: #a8a9ab;
}
.search-container {
background: #1a1b1e;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
padding: 1rem;
margin-bottom: 1rem;
}
.search-box {
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-bottom: 1rem;
}
/* Style the input textbox */
.search-box input[type="text"] {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: white !important;
border-radius: 8px !important;
}
.search-box input[type="text"]::placeholder {
color: #a8a9ab !important;
}
/* Style the search button */
.search-box button {
background: #2563eb !important;
border: none !important;
}
/* Results area styling */
.results-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.answer-box {
background: #3a3b3e;
border-radius: 8px;
padding: 1.5rem;
color: white;
margin-bottom: 1rem;
}
.answer-box p {
color: #e5e7eb;
line-height: 1.6;
}
.sources-container {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
}
.source-item {
display: flex;
padding: 12px;
margin: 8px 0;
background: #3a3b3e;
border-radius: 8px;
transition: all 0.2s;
}
.source-item:hover {
background: #4a4b4e;
}
.source-number {
font-weight: bold;
margin-right: 12px;
color: #60a5fa;
}
.source-content {
flex: 1;
}
.source-title {
color: #60a5fa;
font-weight: 500;
text-decoration: none;
display: block;
margin-bottom: 4px;
}
.source-date {
color: #a8a9ab;
font-size: 0.9em;
margin-left: 8px;
}
.source-snippet {
color: #e5e7eb;
font-size: 0.9em;
line-height: 1.4;
}
.chat-history {
max-height: 400px;
overflow-y: auto;
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-top: 1rem;
}
.examples-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.examples-container button {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: #e5e7eb !important;
}
/* Markdown content styling */
.markdown-content {
color: #e5e7eb !important;
}
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
color: white !important;
}
.markdown-content a {
color: #60a5fa !important;
}
/* Accordion styling */
.accordion {
background: #2c2d30 !important;
border-radius: 8px !important;
margin-top: 1rem !important;
}
"""
# Gradio interface: header, search bar, answer + chat-history column,
# sources column, example queries, and the search event wiring.
with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
    # NOTE(review): "dark" does not appear to be one of Gradio's built-in
    # theme names, so Gradio may try to resolve it as a Hub theme — confirm.

    # Kept so existing references to ``chat_history`` keep working. The
    # running history is read back from the Chatbot itself (see the event
    # wiring below), because this State was never written to.
    chat_history = gr.State([])

    with gr.Column(elem_id="header"):
        gr.Markdown("# 🔍 AI Search Assistant")
        gr.Markdown("### Powered by DeepSeek & Real-time Web Results")

    with gr.Column(elem_classes="search-container"):
        with gr.Row(elem_classes="search-box"):
            search_input = gr.Textbox(
                label="",
                placeholder="Ask anything...",
                scale=5,
                container=False,
            )
            search_btn = gr.Button("Search", variant="primary", scale=1)

        with gr.Row(elem_classes="results-container"):
            with gr.Column(scale=2):
                with gr.Column(elem_classes="answer-box"):
                    answer_output = gr.Markdown(elem_classes="markdown-content")
                with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                    chat_history_display = gr.Chatbot(elem_classes="chat-history")
            with gr.Column(scale=1):
                with gr.Column(elem_classes="sources-box"):
                    gr.Markdown("### Sources")
                    sources_output = gr.HTML()

        with gr.Row(elem_classes="examples-container"):
            gr.Examples(
                examples=[
                    "What are the latest developments in quantum computing?",
                    "Explain the impact of AI on healthcare",
                    "What are the best practices for sustainable living?",
                    "How is climate change affecting ocean ecosystems?",
                ],
                inputs=search_input,
                label="Try these examples",
            )

    # The Chatbot's current value is passed in as ``history``, so each new
    # exchange is appended to the earlier ones. Previously the never-updated
    # State was passed, which made history always empty.
    _query_inputs = [search_input, chat_history_display]
    _query_outputs = [answer_output, sources_output, search_btn, chat_history_display]

    # Both the button and pressing Enter in the textbox run the same handler.
    search_btn.click(fn=process_query, inputs=_query_inputs, outputs=_query_outputs)
    search_input.submit(fn=process_query, inputs=_query_inputs, outputs=_query_outputs)

if __name__ == "__main__":
    demo.launch(share=True)