import json
import os

import faiss
import gradio as gr
import pandas as pd
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient, hf_hub_download
from huggingface_hub import login as hf_hub_login
from huggingface_hub import upload_file
from sentence_transformers import SentenceTransformer

from arxiv_stuff import ARXIV_CATEGORIES_FLAT

# Get HF_TOKEN from environment variables
HF_TOKEN = os.getenv("HF_TOKEN")

# Login to Hugging Face Hub
hf_hub_login(token=HF_TOKEN, add_to_git_credential=True)

# Dataset details
dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"

# Embedding model details
embedding_model_name = "nomadicsynth/research-compass-arxiv-abstracts-embedding-model"
embedding_model_revision = "2025-01-28_23-06-17-1epochs-12batch-32eval-512embed-final"

# Analysis model details
# Settings for Llama-3.3-70B-Instruct
reasoning_model_id = "meta-llama/Llama-3.3-70B-Instruct"
max_length = 1024 * 4
temperature = None
top_p = None
presence_penalty = None

# Settings for QwQ-32B
# reasoning_model_id = "Qwen/QwQ-32B"
# reasoning_start_tag = "<think>"
# reasoning_end_tag = "</think>"
# max_length = 1024 * 4
# temperature = 0.6
# top_p = 0.95
# presence_penalty = 0.1

# Global variables
dataset = None
embedding_model = None
reasoning_model = None


def save_faiss_index_to_hub():
    """Save the FAISS index to the Hub for easy access"""
    global dataset, local_index_path

    # 1. Save the index to a local file
    dataset["train"].save_faiss_index("embedding", local_index_path)
    print(f"FAISS index saved locally to {local_index_path}")

    # 2. Upload the index file to the Hub
    remote_path = upload_file(
        path_or_fileobj=local_index_path,
        path_in_repo=local_index_path,  # Same name on the Hub
        repo_id=dataset_name,  # Use your dataset repo
        token=HF_TOKEN,
        repo_type="dataset",  # This is a dataset file
        revision=dataset_revision,  # Use the same revision as the dataset
        commit_message="Add FAISS index",  # Commit message
    )
    print(f"FAISS index uploaded to Hub at {remote_path}")

    # Remove the local file. It's now stored on the Hub.
    os.remove(local_index_path)


def setup_dataset():
    """Load dataset with FAISS index"""
    global dataset
    print("Loading dataset from Hugging Face...")

    # Load dataset
    dataset = load_dataset(
        dataset_name,
        revision=dataset_revision,
    )

    # Try to load the index from the Hub
    try:
        print("Downloading pre-built FAISS index...")
        index_path = hf_hub_download(
            repo_id=dataset_name,
            filename="arxiv_faiss_index.faiss",
            revision=dataset_revision,
            token=HF_TOKEN,
            repo_type="dataset",
        )
        print("Loading pre-built FAISS index...")
        dataset["train"].load_faiss_index("embedding", index_path)
        print("Pre-built FAISS index loaded successfully")
    except Exception as e:
        print(f"Could not load pre-built index: {e}")
        print("Building new FAISS index...")

        # Add FAISS index if it doesn't exist
        if not dataset["train"].features.get("embedding"):
            print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
            raise ValueError("Dataset doesn't have 'embedding' column")

        dataset["train"].add_faiss_index(
            column="embedding",
            metric_type=faiss.METRIC_INNER_PRODUCT,
            string_factory="HNSW,RFlat",  # Using reranking
        )

        # Save the FAISS index to the Hub
        save_faiss_index_to_hub()

    print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")


def init_embedding_model(model_name_or_path: str, model_revision: str = None) -> SentenceTransformer:
    """Load the sentence-embedding model onto the available device."""
    global embedding_model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embedding_model = SentenceTransformer(
        model_name_or_path,
        revision=model_revision,
        token=HF_TOKEN,
        device=device,
    )
    return embedding_model


def init_reasoning_model(model_name: str) -> InferenceClient:
    """Create an InferenceClient for the reasoning/analysis model."""
    global reasoning_model
    reasoning_model = InferenceClient(
        model=model_name,
        provider="hf-inference",
        api_key=HF_TOKEN,
    )
    return reasoning_model


def generate(messages: list[dict[str, str]]) -> str:
    """
    Generate a response to a list of messages.

    Args:
        messages: A list of message dictionaries with a "role" and "content" key.

    Returns:
        The generated response as a string.
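    Example (illustrative only; the actual text depends on the configured
    model and the JSON schema enforced below):

        >>> generate([{"role": "user", "content": "Analyze these two papers..."}])
        '{"reasoning": "...", "rating": 4, "confidence": 0.8, ...}'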
""" global reasoning_model system_message = { "role": "system", "content": "You are an expert in evaluating connections between research papers.", } messages.insert(0, system_message) response_schema = r"""{ "$schema": "http://json-schema.org/draft-07/schema#", "title": "Generated schema for Root", "type": "object", "properties": { "reasoning": { "type": "string" }, "key_connections": { "type": "array", "items": { "type": "object", "properties": { "connection": { "type": "string" }, "description": { "type": "string" } }, "required": [ "connection", "description" ] } }, "synergies_and_complementarities": { "type": "array", "items": { "type": "object", "properties": { "type": { "type": "array", "items": { "type": "string" } }, "description": { "type": "string" } }, "required": [ "type", "description" ] } }, "research_potential": { "type": "array", "items": { "type": "object", "properties": { "potential": { "type": "string" }, "description": { "type": "string" } }, "required": [ "potential", "description" ] } }, "rating": { "type": "number" }, "confidence": { "type": "number" } }, "required": [ "reasoning", "key_connections", "synergies_and_complementarities", "research_potential", "rating", "confidence" ] }""" response_format = { "type": "json", "value": response_schema, } result = reasoning_model.chat.completions.create( messages=messages, max_tokens=max_length, temperature=temperature, presence_penalty=presence_penalty, response_format=response_format, top_p=top_p, ) output = result.choices[0].message.content.strip() return output @spaces.GPU def embed_text(text: str | list[str]) -> torch.Tensor: global embedding_model # Strip any leading/trailing whitespace text = text.strip() if isinstance(text, str) else [t.strip() for t in text] embed_text = embedding_model.encode(text, normalize_embeddings=True) # Ensure vectors are normalized return embed_text def analyse_abstracts(query_abstract: str, compare_abstract: dict) -> str: """Analyze the relationship between two abstracts and return formatted analysis""" # Highlight the synergies in thesede papers that would justify further research messages = [ { "role": "user", "content": f"""You are trained in evaluating connections between research papers. Please **identify and analyze the links** between these two papers: Paper 1 Abstract: {query_abstract} Paper 2 Abstract: {compare_abstract["abstract"]} Consider the following aspects in your evaluation: * **Methodological Cross-Pollination**: How do the methods or approaches from one paper **directly enhance or inform** the other? * **Principle or Mechanism Extension**: Do the papers **share underlying principles or mechanisms** that can be **combined or extended** to yield new insights? * **Interdisciplinary Connections**: Are there **clear opportunities** for interdisciplinary collaborations or knowledge transfer between the two papers? * **Solution or Application Bridge**: Can the solutions or applications presented in one paper be **directly adapted or integrated** with the other to create **novel, actionable outcomes**? Consider the connections in either direction, that is, from Paper 1 -> Paper 2, or vice versa, from Paper 2 -> Paper 1 Return a valid JSON object with this structure: {{ "reasoning": "Step-by-step analysis of the papers, highlighting **key established connections**, identified synergies, and **concrete complementarities**. 
    # Main connecting concepts, methods, or principles
    "key_connections": [
        {{
            "connection": "connection 1",
            "description": "Brief description (1-2 sentences) for the **established connection**, explaining its **direct relevance** to the synergy analysis."
        }},
        ...
    ],
    "synergies_and_complementarities": [
        {{
            "type": ["Methodological Cross-Pollination", "Principle or Mechanism Extension", "Interdisciplinary Connections", "Solution or Application Bridge"],  # Choose only one type per entry, and only include types relevant to this analysis
            "description": "Brief explanation (1-2 sentences) of the **identified, concrete synergy** or **complementarity**, and a **specific, actionable example** to illustrate the concept."
        }},
        ...
    ],
    # Novel, actionable outcomes or applications emerging from the synergies
    "research_potential": [
        {{
            "potential": "Actionable outcome or application 1",
            "description": "Brief description (1-2 sentences) of the **concrete potential outcome** or **application**, and a **specific scenario** to illustrate its **direct impact**."
        }},
        ...
    ],
    "rating": 1-5,  # Overall rating of the papers' synergy potential, where:
    # 1 = **No synergy or connection** (definitely no link between the papers)
    # 2 = **Low potential for synergy** (some vague or speculative connection, but highly uncertain)
    # 3 = **Plausible synergy potential** (some potential connections, but requiring further investigation to confirm)
    # 4 = **Established synergy with potential for growth** (clear connections with opportunities for further development)
    # 5 = **High established synergy with direct, clear opportunities** (strong, concrete links with immediate, actionable outcomes)
    "confidence": 0.0-1.0,  # Confidence in your analysis, as a floating-point value representing the probability of your assessment being accurate
}}

Return only the JSON object, with double quotes around key names and all string values.""",
        },
    ]

    # Generate analysis
    try:
        output = generate(messages)
    except Exception as e:
        return f"Error: {e}"

    # Parse the JSON output
    try:
        output = json.loads(output)
    except Exception as e:
        return f"Error: {e}"

    # Format the output as markdown for better display
    key_connections = ""
    synergies_and_complementarities = ""
    research_potential = ""

    if "key_connections" in output:
        for connection in output["key_connections"]:
            key_connections += f"- {connection['connection']}: {connection['description']}\n"

    if "synergies_and_complementarities" in output:
        for synergy in output["synergies_and_complementarities"]:
            synergies_and_complementarities += f"- {', '.join(synergy['type'])}: {synergy['description']}\n"

    if "research_potential" in output:
        for potential in output["research_potential"]:
            research_potential += f"- {potential['potential']}: {potential['description']}\n"

    formatted_output = f"""## Synergy Analysis

**Rating**: {'★' * output['rating']}{'☆' * (5 - output['rating'])} **Confidence**: {'★' * round(output['confidence'] * 5)}{'☆' * round((1 - output['confidence']) * 5)}

### Key Connections

{key_connections}
### Synergies and Complementarities

{synergies_and_complementarities}
### Research Potential

{research_potential}
### Reasoning

{output['reasoning']}
"""
    return formatted_output
    # return '```"""\n' + output + '\n"""```'


# arXiv Embedding Dataset Details
# DatasetDict({
#     train: Dataset({
#         features: ['id', 'submitter', 'authors', 'title', 'comments', 'journal-ref', 'doi', 'report-no', 'categories', 'license',
#                   'abstract', 'update_date', 'embedding', 'timestamp', 'embedding_model'],
#         num_rows: 2689088
#     })
# })


def find_synergistic_papers(abstract: str, limit=25) -> list[dict]:
    """Find papers synergistic with the given abstract using FAISS with cosine similarity"""
    global dataset

    # Generate embedding for the query abstract (normalized for cosine similarity)
    abstract_embedding = embed_text(abstract)

    # Search for similar papers using FAISS with inner product (cosine similarity for normalized vectors)
    scores, examples = dataset["train"].get_nearest_examples("embedding", abstract_embedding, k=limit)

    papers = []
    for i in range(len(scores)):
        # With cosine similarity, higher scores are better (closer to 1)
        paper_dict = {
            "id": examples["id"][i],
            "title": examples["title"][i],
            "authors": examples["authors"][i],
            "categories": examples["categories"][i],
            "abstract": examples["abstract"][i],
            "update_date": examples["update_date"][i],
            "synergy_score": float(scores[i]),  # Convert to float for serialization
        }
        papers.append(paper_dict)

    return papers


def format_search_results(abstract: str) -> tuple[pd.DataFrame, list[dict]]:
    """Format search results as a DataFrame for display"""
    # Find papers synergistic with the given abstract
    papers = find_synergistic_papers(abstract)

    # Convert to DataFrame for display
    df = pd.DataFrame(
        [
            {
                "Title": p["title"],
                "Authors": p["authors"][:50] + "..." if len(p["authors"]) > 50 else p["authors"],
                "Categories": p["categories"],
                "Date": p["update_date"],
                "Match Score": f"{int(p['synergy_score'] * 100)}%",
                "ID": p["id"],  # Hidden column for reference
            }
            for p in papers
        ]
    )

    return df, papers  # Return both DataFrame and original data


def format_paper_as_markdown(paper: dict) -> str:
    """Render a paper dict as markdown for the details pane."""
    # Convert category codes to full names, handling unknown categories
    subjects = []
    for subject in paper["categories"].split():
        if subject in ARXIV_CATEGORIES_FLAT:
            subjects.append(ARXIV_CATEGORIES_FLAT[subject])
        else:
            subjects.append(f"Unknown Category ({subject})")

    paper["title"] = paper["title"].replace("\n", " ").strip()
    paper["authors"] = paper["authors"].replace("\n", " ").strip()

    return f"""# {paper["title"]}
### {paper["authors"]}
#### {', '.join(subjects)} | {paper["update_date"]} | **Score**: {int(paper['synergy_score'] * 100)}%
**[arxiv:{paper["id"]}](https://arxiv.org/abs/{paper["id"]})** - [PDF](https://arxiv.org/pdf/{paper["id"]})

{paper["abstract"]}
"""
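# Example of the markdown produced above, with purely illustrative values
# (not a real search result):
#
#   # An Example Paper Title
#   ### A. Author, B. Author
#   #### Machine Learning | 2024-01-01 | **Score**: 87%
#   **[arxiv:2401.00000](https://arxiv.org/abs/2401.00000)** - [PDF](https://arxiv.org/pdf/2401.00000)
#
#   An example abstract...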
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    # {"left": "$", "right": "$", "display": False},
    # {"left": "\\(", "right": "\\)", "display": False},
    # {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
    # {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
    # {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
    # {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
    # {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    # {"left": "\\[", "right": "\\]", "display": True},
    # {"left": "\\underline{", "right": "}", "display": False},
    # {"left": "\\textit{", "right": "}", "display": False},
    # {"left": "{", "right": "}", "display": False},
]


def create_interface():
    with gr.Blocks(
        css="""
        .cell-menu-button {
            display: none;
        }"""
    ) as demo:
        gr.HTML(
            """
            <div style="text-align: center;">
                <h1>Research Compass</h1>
                <p>Find synergistic papers to enrich your research</p>
                <p><em>An experiment in AI-driven research synergy analysis</em></p>
            </div>
""" ) with gr.Accordion(label="Instructions", open=False): gr.Markdown( """ 1. **Enter Abstract**: Paste an abstract or describe your research details in the text box. 2. **Search for Synergistic Papers**: Click the button to find papers with similar themes. 3. **Select a Paper**: Click on a row in the results table to view paper details. 4. **Analyze Connection Potential**: Click the button to analyze the synergy potential between the papers. 5. **Synergy Analysis**: View the detailed analysis of the connection potential between the papers. """ ) abstract_input = gr.Textbox( label="Paper Abstract or Description", placeholder="Paste an abstract or describe research details...", lines=8, key="abstract", ) search_btn = gr.Button("Search for Synergistic Papers", variant="primary") # Store full paper data paper_data_state = gr.State([]) # Store query abstract query_abstract_state = gr.State("") # Store selected paper selected_paper_state = gr.State(None) # Use Dataframe for results results_df = gr.Dataframe( headers=["Title", "Authors", "Categories", "Date", "Match Score"], datatype=["markdown", "markdown", "str", "date", "str"], latex_delimiters=latex_delimiters, label="Synergistic Papers", interactive=False, wrap=False, line_breaks=False, column_widths=["40%", "20%", "20%", "10%", "10%", "0%"], # Hide ID column key="results", ) with gr.Row(): with gr.Column(scale=1): paper_details_output = gr.Markdown( value="# Paper Details", label="Paper Details", latex_delimiters=latex_delimiters, show_copy_button=True, key="paper_details", ) analyze_btn = gr.Button("Analyze Connection Potential", variant="primary", interactive=False) with gr.Column(scale=1): # Analysis output analysis_output = gr.Markdown( value="# Synergy Analysis", label="Synergy Analysis", latex_delimiters=latex_delimiters, show_copy_button=True, key="analysis_output", ) # Display paper details when row is selected def on_select(evt: gr.SelectData, papers, query): selected_index = evt.index[0] # Get the row index selected = papers[selected_index] # Format paper details details_md = format_paper_as_markdown(selected) return details_md, selected # Connect search button to the search function search_btn.click( format_search_results, inputs=[abstract_input], outputs=[results_df, paper_data_state], api_name=False, ).then( lambda x: x, # Identity function to pass through the abstract inputs=[abstract_input], outputs=[query_abstract_state], api_name=False, ).then( lambda: None, # Reset selected paper outputs=[selected_paper_state], api_name=False, ).then( lambda: gr.update(interactive=False), # Disable analyze button until paper selected outputs=[analyze_btn], api_name=False, ).then( lambda: "# Synergy Analysis", # Clear previous analysis outputs=[analysis_output], api_name=False, ) # Use built-in select event from Dataframe results_df.select( on_select, inputs=[paper_data_state, query_abstract_state], outputs=[paper_details_output, selected_paper_state], api_name=False, ).then( lambda: gr.update(interactive=True), # Enable analyze button when paper selected outputs=[analyze_btn], api_name=False, ) # Connect analyze button to run analysis analyze_btn.click( analyse_abstracts, inputs=[query_abstract_state, selected_paper_state], outputs=[analysis_output], show_progress_on=[paper_details_output, analysis_output], api_name=False, ) return demo if __name__ == "__main__": # Load dataset with FAISS index setup_dataset() # Initialize the embedding model init_embedding_model(embedding_model_name, embedding_model_revision) # Initialize the reasoning 
    # Initialize the reasoning model
    reasoning_model = init_reasoning_model(reasoning_model_id)

    demo = create_interface()
    demo.queue(api_open=False).launch(ssr_mode=False, show_api=False)
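# A minimal way to run this Space locally (a sketch, assuming HF_TOKEN is set in the
# environment and grants access to the dataset and models above; the file name and
# unpinned package list are assumptions, not part of the original repo):
#   pip install gradio spaces torch faiss-cpu datasets sentence-transformers huggingface_hub pandas
#   python app.py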