Spaces:
Running
Running
| import gradio as gr | |
| import google.generativeai as genai | |
| from google.generativeai.types import generation_types | |
| from ragatouille import RAGPretrainedModel | |
| import arxiv | |
| import os | |
| import re | |
| from datetime import datetime | |
| from utils import get_md_text_abstract | |
| from huggingface_hub import snapshot_download | |
# --- Core Configuration ---
# Credentials and the index location are read from the environment; each of
# these is None when the corresponding variable is unset.
hf_token = os.environ.get("HF_TOKEN")
gemini_api_key = os.environ.get("GEMINI_API_KEY")
RAG_SOURCE = os.environ.get("RAG_SOURCE")

# Local directory that holds the retrieval index.
LOCAL_DATA_DIR = './rag_index_data'

# Model offered by default in the "LLM Model" dropdown.
DEFAULT_LLM_MODEL = 'google/gemma-3-4b-it'

# Dropdown entries; the literal string 'None' lets the user switch the LLM off.
LLM_MODELS_TO_CHOOSE = [
    DEFAULT_LLM_MODEL,
    'google/gemma-3-12b-it',
    'google/gemma-3-27b-it',
    'None',
]

# Number of documents requested from either search backend per query.
RETRIEVE_RESULTS = 20
# --- Gemini API Configuration ---
# The SDK is only configured when a key is present; otherwise a warning is
# printed and start-up continues.
if not gemini_api_key:
    print("CRITICAL WARNING: GEMINI_API_KEY environment variable not set. The application will not function without it.")
else:
    genai.configure(api_key=gemini_api_key)

# Sampling settings shared by every generation request.
GEMINI_GENERATION_CONFIG = genai.types.GenerationConfig(
    max_output_tokens=450,
    temperature=0.2,
    top_p=0.95,
)
# --- RAG & Data Source Setup ---
# Any failure below is reported as a warning and leaves RAG set to None, so the
# rest of the module must cope with a missing retriever.
try:
    gr.Info("Setting up the RAG retriever...")

    if os.path.exists(LOCAL_DATA_DIR):
        gr.Info(f"Found existing local index at {LOCAL_DATA_DIR}.")
    else:
        # No local copy yet: fetch the index from the Hugging Face dataset repo.
        if not (RAG_SOURCE and hf_token):
            raise ValueError("RAG index not found locally, and RAG_SOURCE or HF_TOKEN environment variables are not set. Cannot download index.")
        snapshot_download(
            repo_id=RAG_SOURCE,
            repo_type="dataset",
            token=hf_token,
            local_dir=LOCAL_DATA_DIR,
        )
        gr.Info("Index downloaded successfully.")

    # Load the retriever from the index directory inside the local data dir.
    _colbert_index_path = os.path.join(LOCAL_DATA_DIR, "arxiv_colbert")
    gr.Info(f"Loading index from {_colbert_index_path}...")
    RAG = RAGPretrainedModel.from_index(_colbert_index_path)

    # Warm-up query so the first real search does not pay the start-up cost.
    RAG.search("Test query", k=1)
    gr.Info("Retriever loaded successfully!")
except Exception as e:
    gr.Warning(f"Could not initialize the RAG retriever. The app may not function correctly. Error: {e}")
    RAG = None
# --- UI Text and Metadata ---
# NOTE(review): the 'π' glyph below looks like a mis-encoded emoji from the
# original source; it is kept as-is here — confirm the intended character.
MARKDOWN_SEARCH_RESULTS_HEADER = '# π Search Results\n'
APP_HEADER_TEXT = "# ArXiv CS RAG\n"
INDEX_INFO = "Semantic Search"

# Best-effort: read the index date from README.md and surface it in the header
# and in the label of the semantic-search source. Defaults are kept on failure.
try:
    # Decode explicitly as UTF-8 so the result does not depend on the host
    # locale; a README containing non-ASCII text would otherwise fail to decode
    # on some systems and the date would be silently dropped.
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
    date_match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if date_match:
        date = date_match.group(1)
        formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
        APP_HEADER_TEXT += f'Index Last Updated: {formatted_date}\n'
        INDEX_INFO = f"Semantic Search - up to {formatted_date}"
except (OSError, ValueError) as e:
    # OSError: file missing/unreadable. ValueError: undecodable bytes or a
    # date string that matches the pattern but is not a real calendar date.
    print(f"README.md not found or is invalid ({e}). Using default data source info.")

# Entries of the "Search Source" dropdown: the local semantic index first, then
# the live arXiv API.
DATABASE_CHOICES = [INDEX_INFO, 'Arxiv Search - Latest - (EXPERIMENTAL)']

# Single arXiv API client reused for every live search.
ARX_CLIENT = arxiv.Client()
| # --- Helper Functions --- | |
def get_prompt_text(question, context):
    """Build the full LLM prompt from the user's question and the abstracts.

    Args:
        question: The user's query text.
        context: The abstracts to answer from, already formatted as text.

    Returns:
        A single string: the fixed answering instruction, then the abstracts,
        then the question, each separated by a blank line.
    """
    instruction_sentences = [
        "Based on the provided scientific paper abstracts, provide a comprehensive answer of 6-7 lines. ",
        "Synthesize information from multiple sources if possible. Your answer must be grounded in the ",
        "details found in the abstracts. Cite the titles of the papers you use as sources in your answer.",
    ]
    sections = [
        "".join(instruction_sentences),
        f"Abstracts:\n{context}",
        f"Question: {question}",
    ]
    return "\n\n".join(sections)
def update_with_rag_md(message, llm_results_use, database_choice):
    """Search for papers, render them as markdown, and build the LLM prompt.

    Args:
        message: The user's search query.
        llm_results_use: How many of the top-ranked results have their
            abstracts included in the LLM prompt. All results are shown in the
            markdown regardless of this value.
        database_choice: The selected entry of ``DATABASE_CHOICES``.

    Returns:
        A ``(markdown, prompt)`` tuple. ``markdown`` always starts with the
        search-results header. ``prompt`` is the empty string when no abstracts
        could be collected, so that the downstream empty-prompt check in
        ``ask_gemma_llm`` reports the problem instead of querying the model
        with no context.
    """
    rag_out = []
    source_used = database_choice

    def _semantic_search():
        return RAG.search(message, k=RETRIEVE_RESULTS)

    try:
        if database_choice == INDEX_INFO and RAG:
            rag_out = _semantic_search()
        else:
            if database_choice == INDEX_INFO:
                # The semantic index was requested but the retriever is not
                # loaded. The results below come from the arXiv API, so they
                # must be labelled as such for get_md_text_abstract.
                # NOTE(review): assumes get_md_text_abstract picks its result
                # format from `source` — confirm against utils.
                gr.Warning("Semantic search index is unavailable. Using live Arxiv search instead.")
                source_used = DATABASE_CHOICES[1]
            rag_out = list(ARX_CLIENT.results(arxiv.Search(
                query=message,
                max_results=RETRIEVE_RESULTS,
                sort_by=arxiv.SortCriterion.Relevance,
            )))
            if not rag_out:
                gr.Warning("Live Arxiv search returned no results. Falling back to semantic search.")
                if RAG:
                    rag_out = _semantic_search()
                    source_used = INDEX_INFO
    except Exception as e:
        gr.Warning(f"An error occurred during search: {e}. Falling back to semantic search.")
        if RAG:
            # The fallback may fail too (e.g. when the semantic search itself
            # was what raised); degrade to an empty result list, not a crash.
            try:
                rag_out = _semantic_search()
                source_used = INDEX_INFO
            except Exception as fallback_error:
                rag_out = []
                gr.Warning(f"Semantic search fallback also failed: {fallback_error}")

    md_parts = [MARKDOWN_SEARCH_RESULTS_HEADER]
    context_parts = []
    for i, rag_answer in enumerate(rag_out):
        md_text_paper, prompt_text = get_md_text_abstract(
            rag_answer, source=source_used, return_prompt_formatting=True
        )
        # Only the top `llm_results_use` abstracts are fed to the LLM.
        if i < llm_results_use:
            context_parts.append(f"{i+1}. {prompt_text}\n")
        md_parts.append(md_text_paper)

    prompt_context = "".join(context_parts)
    final_prompt = get_prompt_text(message, prompt_context) if prompt_context else ""
    return "".join(md_parts), final_prompt
def ask_gemma_llm(prompt, llm_model_picked, stream_outputs):
    """Send a prompt to the Google Gemini API and yield the answer text.

    This is a generator. In streaming mode it yields the accumulated answer
    after every received chunk; otherwise it yields the complete answer once.
    Problems are yielded as human-readable messages rather than raised.

    Args:
        prompt: The full prompt text. Empty or whitespace-only prompts are
            rejected before any API call.
        llm_model_picked: An entry of ``LLM_MODELS_TO_CHOOSE``. The string
            ``'None'`` disables the LLM. Otherwise the part after the last
            ``/`` is used as the API model name.
        stream_outputs: Whether to request and yield a streamed response.

    Yields:
        Answer text, or a message describing why no answer is available.
    """
    if not prompt or not prompt.strip():
        yield "Error: The generated prompt is empty. Please try a different query."
        return
    if llm_model_picked == 'None':
        yield "LLM Model is disabled."
        return
    if not gemini_api_key:
        yield "Error: GEMINI_API_KEY is not configured. Cannot contact the LLM."
        return

    try:
        safety_settings = [
            {"category": category, "threshold": "BLOCK_ONLY_HIGH"}
            for category in (
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]
        gemini_model_name = llm_model_picked.split('/')[-1]
        model = genai.GenerativeModel(gemini_model_name)
        response = model.generate_content(
            prompt,
            generation_config=GEMINI_GENERATION_CONFIG,
            stream=stream_outputs,
            safety_settings=safety_settings,
        )

        if stream_outputs:
            output = ""
            for chunk in response:
                try:
                    text = chunk.parts[0].text
                except (IndexError, AttributeError, ValueError):
                    # Empty chunks can occur at the end of a stream. The SDK's
                    # `parts` accessor is expected to raise ValueError when a
                    # chunk has no candidate (e.g. blocked content) — verify
                    # against the installed google-generativeai version.
                    continue
                output += text
                yield output
            if not output:
                yield "Model returned an empty or blocked stream. This may be due to the safety settings or the nature of the prompt."
        else:
            # Handle non-streaming responses.
            try:
                answer = response.parts[0].text
            except (IndexError, AttributeError, ValueError):
                # Work out why there is no text so the user gets an actionable
                # message rather than a bare "blocked" notice.
                reason = "UNKNOWN"
                feedback = getattr(response, "prompt_feedback", None)
                block_reason = getattr(feedback, "block_reason", None)
                candidates = getattr(response, "candidates", None)
                if block_reason:
                    reason = getattr(block_reason, "name", str(block_reason))
                elif candidates and not candidates[0].content.parts:
                    finish_reason = candidates[0].finish_reason
                    reason = getattr(finish_reason, "name", str(finish_reason))
                yield f"Model returned an empty or blocked response. Reason: {reason}"
            else:
                yield answer
    except Exception as e:
        error_message = f"An error occurred with the Gemini API: {e}"
        print(error_message)  # Server side log
        gr.Warning("An error occurred with the Gemini API. Check the server logs for details.")
        yield error_message
# --- Gradio User Interface ---
# NOTE(review): the container nesting below (search box and accordion inside
# the group, answer/results outside it) was reconstructed from a source whose
# indentation was lost — confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(APP_HEADER_TEXT)

    with gr.Group():
        msg = gr.Textbox(label='Search', placeholder='e.g., What is Mixtral?')
        with gr.Accordion("Advanced Settings", open=False):
            llm_model = gr.Dropdown(
                choices=LLM_MODELS_TO_CHOOSE,
                value=DEFAULT_LLM_MODEL,
                label='LLM Model',
            )
            llm_results = gr.Slider(5, 20, value=10, step=1, label="Top n results as context")
            database_src = gr.Dropdown(
                choices=DATABASE_CHOICES,
                value=INDEX_INFO,
                label='Search Source',
            )
            stream_results = gr.Checkbox(value=True, label="Stream output")

    output_text = gr.Textbox(
        label='LLM Answer',
        placeholder="The model's answer will appear here...",
        interactive=False,
        lines=8,
    )
    # Hidden field that carries the prompt from the search step to the LLM step.
    input_prompt = gr.Textbox(visible=False)
    gr_md = gr.Markdown(MARKDOWN_SEARCH_RESULTS_HEADER)

    # Step 1 on submit: run the search, render results, and build the prompt.
    search_event = msg.submit(
        fn=update_with_rag_md,
        inputs=[msg, llm_results, database_src],
        outputs=[gr_md, input_prompt],
    )
    # Step 2, chained after the search: send the prompt to the LLM.
    search_event.then(
        fn=ask_gemma_llm,
        inputs=[input_prompt, llm_model, stream_results],
        outputs=[output_text],
    )

if __name__ == "__main__":
    # Launch the app
    demo.queue().launch()