import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from duckduckgo_search import DDGS
import torch
import gc  # For manual garbage collection

# Initialize model and tokenizer with optimizations
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Load config first to set optimal parameters
config = AutoConfig.from_pretrained(model_name)
config.use_cache = True  # Enable KV-caching for faster inference

# Initialize tokenizer with optimizations
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=256,  # Reduced for faster processing
    padding_side="left",
    truncation_side="left",
)
tokenizer.pad_token = tokenizer.eos_token

# Load model with optimizations
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    device_map="cpu",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# Enable model optimizations
model.eval()  # Set to evaluation mode
torch.set_num_threads(4)  # Limit CPU threads for better performance


def get_web_results(query, max_results=3):  # Reduced max results for speed
    """Get web search results using DuckDuckGo."""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
            return [{
                "title": result.get("title", ""),
                "snippet": result["body"][:200],  # Limit snippet length
                "url": result["href"],
                "date": result.get("published", "")
            } for result in results]
    except Exception:
        return []


def format_prompt(query, context):
    """Format the prompt with web context."""
    context_lines = '\n'.join(f'[{i+1}] {res["snippet"]}' for i, res in enumerate(context))
    return f"Answer this query using the context: {query}\n\nContext:\n{context_lines}\n\nAnswer:"


def format_sources(web_results):
    """Format sources as an HTML list with title, date, and snippet."""
    if not web_results:
        return "<div>No sources available</div>"

    sources_html = "<div class='sources-container'>"
" for i, res in enumerate(web_results, 1): title = res["title"] or "Source" date = f"{res['date']}" if res['date'] else "" sources_html += f"""
[{i}]
{title} {date}
{res['snippet'][:150]}...
""" sources_html += "
" return sources_html def generate_answer(prompt): """Generate answer using the DeepSeek model - optimized version""" try: # Clear CUDA cache and garbage collect if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() inputs = tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=256, return_attention_mask=True ) with torch.no_grad(): # Disable gradient calculation outputs = model.generate( inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=100, # Further reduced for speed temperature=0.7, top_p=0.95, pad_token_id=tokenizer.eos_token_id, do_sample=True, num_beams=1, early_stopping=True, no_repeat_ngram_size=3, length_penalty=1.0 ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) return response.split('Answer:')[-1].strip() except Exception as e: return f"Error generating response: {str(e)}" def process_query(query, history): """Process user query with optimized streaming effect""" try: if history is None: history = [] # Get web results first web_results = get_web_results(query) sources_html = format_sources(web_results) # Show searching status yield { answer_output: gr.Markdown("*Searching and generating response...*"), sources_output: gr.HTML(sources_html), search_btn: gr.Button("Please wait...", interactive=False), chat_history_display: history + [[query, "*Processing...*"]] } # Generate answer with timeout protection prompt = format_prompt(query, web_results) answer = generate_answer(prompt) # Update with final answer final_history = history + [[query, answer]] yield { answer_output: gr.Markdown(answer), sources_output: gr.HTML(sources_html), search_btn: gr.Button("Search", interactive=True), chat_history_display: final_history } except Exception as e: error_msg = f"Error: {str(e)}" yield { answer_output: gr.Markdown(error_msg), sources_output: gr.HTML("
Error fetching sources
"), search_btn: gr.Button("Search", interactive=True), chat_history_display: history + [[query, error_msg]] } # Update the CSS for better contrast and readability css = """ .gradio-container { max-width: 1200px !important; background-color: #f7f7f8 !important; } #header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; background: #1a1b1e; border-radius: 12px; color: white; } #header h1 { color: white; font-size: 2.5rem; margin-bottom: 0.5rem; } #header h3 { color: #a8a9ab; } .search-container { background: #1a1b1e; border-radius: 12px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); padding: 1rem; margin-bottom: 1rem; } .search-box { padding: 1rem; background: #2c2d30; border-radius: 8px; margin-bottom: 1rem; } /* Style the input textbox */ .search-box input[type="text"] { background: #3a3b3e !important; border: 1px solid #4a4b4e !important; color: white !important; border-radius: 8px !important; } .search-box input[type="text"]::placeholder { color: #a8a9ab !important; } /* Style the search button */ .search-box button { background: #2563eb !important; border: none !important; } /* Results area styling */ .results-container { background: #2c2d30; border-radius: 8px; padding: 1rem; margin-top: 1rem; } .answer-box { background: #3a3b3e; border-radius: 8px; padding: 1.5rem; color: white; margin-bottom: 1rem; } .answer-box p { color: #e5e7eb; line-height: 1.6; } .sources-container { margin-top: 1rem; background: #2c2d30; border-radius: 8px; padding: 1rem; } .source-item { display: flex; padding: 12px; margin: 8px 0; background: #3a3b3e; border-radius: 8px; transition: all 0.2s; } .source-item:hover { background: #4a4b4e; } .source-number { font-weight: bold; margin-right: 12px; color: #60a5fa; } .source-content { flex: 1; } .source-title { color: #60a5fa; font-weight: 500; text-decoration: none; display: block; margin-bottom: 4px; } .source-date { color: #a8a9ab; font-size: 0.9em; margin-left: 8px; } .source-snippet { color: #e5e7eb; font-size: 0.9em; line-height: 1.4; } .chat-history { max-height: 400px; overflow-y: auto; padding: 1rem; background: #2c2d30; border-radius: 8px; margin-top: 1rem; } .examples-container { background: #2c2d30; border-radius: 8px; padding: 1rem; margin-top: 1rem; } .examples-container button { background: #3a3b3e !important; border: 1px solid #4a4b4e !important; color: #e5e7eb !important; } /* Markdown content styling */ .markdown-content { color: #e5e7eb !important; } .markdown-content h1, .markdown-content h2, .markdown-content h3 { color: white !important; } .markdown-content a { color: #60a5fa !important; } /* Accordion styling */ .accordion { background: #2c2d30 !important; border-radius: 8px !important; margin-top: 1rem !important; } """ # Update the Gradio interface layout with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo: chat_history = gr.State([]) with gr.Column(elem_id="header"): gr.Markdown("# 🔍 AI Search Assistant") gr.Markdown("### Powered by DeepSeek & Real-time Web Results") with gr.Column(elem_classes="search-container"): with gr.Row(elem_classes="search-box"): search_input = gr.Textbox( label="", placeholder="Ask anything...", scale=5, container=False ) search_btn = gr.Button("Search", variant="primary", scale=1) with gr.Row(elem_classes="results-container"): with gr.Column(scale=2): with gr.Column(elem_classes="answer-box"): answer_output = gr.Markdown(elem_classes="markdown-content") with gr.Accordion("Chat History", open=False, elem_classes="accordion"): chat_history_display = gr.Chatbot(elem_classes="chat-history") 
            with gr.Column(scale=1):
                with gr.Column(elem_classes="sources-box"):
                    gr.Markdown("### Sources")
                    sources_output = gr.HTML()

        with gr.Row(elem_classes="examples-container"):
            gr.Examples(
                examples=[
                    "What are the latest developments in quantum computing?",
                    "Explain the impact of AI on healthcare",
                    "What are the best practices for sustainable living?",
                    "How is climate change affecting ocean ecosystems?"
                ],
                inputs=search_input,
                label="Try these examples"
            )

    # Handle interactions
    search_btn.click(
        fn=process_query,
        inputs=[search_input, chat_history],
        outputs=[answer_output, sources_output, search_btn, chat_history_display]
    )

    # Also trigger search on Enter key
    search_input.submit(
        fn=process_query,
        inputs=[search_input, chat_history],
        outputs=[answer_output, sources_output, search_btn, chat_history_display]
    )

if __name__ == "__main__":
    demo.launch(share=True)
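
# Suggested dependencies for running this script (assumed from the imports above,
# not pinned anywhere in the source):
#   pip install gradio torch transformers duckduckgo_search
# Note: in float32 on CPU, the 1.5B-parameter model needs roughly 6 GB of RAM
# for the weights alone (1.5B params x 4 bytes each).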