import json
import os
import faiss
import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient, hf_hub_download
from huggingface_hub import login as hf_hub_login
from huggingface_hub import upload_file
from sentence_transformers import SentenceTransformer
from arxiv_stuff import ARXIV_CATEGORIES_FLAT
# Get HF_TOKEN from environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
# Login to Hugging Face Hub
hf_hub_login(token=HF_TOKEN, add_to_git_credential=True)
# Dataset details
dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"
# Embedding model details
embedding_model_name = "nomadicsynth/research-compass-arxiv-abstracts-embedding-model"
embedding_model_revision = "2025-01-28_23-06-17-1epochs-12batch-32eval-512embed-final"
# Analysis model details
# Settings for Llama-3.3-70B-Instruct
reasoning_model_id = "meta-llama/Llama-3.3-70B-Instruct"
max_length = 1024 * 4
temperature = None
top_p = None
presence_penalty = None
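# None leaves each sampling parameter at the inference provider's default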
# Settings for QwQ-32B
# reasoning_model_id = "Qwen/QwQ-32B"
# reasoning_start_tag = "<think>"
# reasoning_end_tag = "</think>"
# max_length = 1024 * 4
# temperature = 0.6
# top_p = 0.95
# presence_penalty = 0.1
# Global variables
dataset = None
embedding_model = None
reasoning_model = None
def save_faiss_index_to_hub():
"""Save the FAISS index to the Hub for easy access"""
global dataset, local_index_path
# 1. Save the index to a local file
dataset["train"].save_faiss_index("embedding", local_index_path)
print(f"FAISS index saved locally to {local_index_path}")
# 2. Upload the index file to the Hub
remote_path = upload_file(
path_or_fileobj=local_index_path,
path_in_repo=local_index_path, # Same name on the Hub
repo_id=dataset_name, # Use your dataset repo
token=HF_TOKEN,
repo_type="dataset", # This is a dataset file
revision=dataset_revision, # Use the same revision as the dataset
commit_message="Add FAISS index", # Commit message
)
print(f"FAISS index uploaded to Hub at {remote_path}")
# Remove the local file. It's now stored on the Hub.
os.remove(local_index_path)
def setup_dataset():
"""Load dataset with FAISS index"""
global dataset
print("Loading dataset from Hugging Face...")
# Load dataset
dataset = load_dataset(
dataset_name,
revision=dataset_revision,
)
# Try to load the index from the Hub
try:
print("Downloading pre-built FAISS index...")
index_path = hf_hub_download(
repo_id=dataset_name,
filename="arxiv_faiss_index.faiss",
revision=dataset_revision,
token=HF_TOKEN,
repo_type="dataset",
)
print("Loading pre-built FAISS index...")
dataset["train"].load_faiss_index("embedding", index_path)
print("Pre-built FAISS index loaded successfully")
except Exception as e:
print(f"Could not load pre-built index: {e}")
print("Building new FAISS index...")
# Add FAISS index if it doesn't exist
if not dataset["train"].features.get("embedding"):
print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
raise ValueError("Dataset doesn't have 'embedding' column")
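        # "HNSW,RFlat" builds an approximate HNSW index, then refines the
        # candidate list with exact (flat) distance computations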
dataset["train"].add_faiss_index(
column="embedding",
metric_type=faiss.METRIC_INNER_PRODUCT,
string_factory="HNSW,RFlat", # Using reranking
)
# Save the FAISS index to the Hub
save_faiss_index_to_hub()
print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
def init_embedding_model(model_name_or_path: str, model_revision: str = None) -> SentenceTransformer:
    """Load the sentence-embedding model, using the GPU when available."""
    global embedding_model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embedding_model = SentenceTransformer(
        model_name_or_path,
        revision=model_revision,
        token=HF_TOKEN,
        device=device,
    )
    return embedding_model
def init_reasoning_model(model_name: str) -> InferenceClient:
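    """Create an InferenceClient for the reasoning/analysis model."""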
global reasoning_model
reasoning_model = InferenceClient(
model=model_name,
provider="hf-inference",
api_key=HF_TOKEN,
)
return reasoning_model
def generate(messages: list[dict[str, str]]) -> str:
"""
Generate a response to a list of messages.
Args:
messages: A list of message dictionaries with a "role" and "content" key.
Returns:
The generated response as a string.
"""
    global reasoning_model
    system_message = {
        "role": "system",
        "content": "You are an expert in evaluating connections between research papers.",
    }
    # Prepend the system message without mutating the caller's message list
    messages = [system_message] + messages
response_schema = r"""{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Generated schema for Root",
"type": "object",
"properties": {
"reasoning": {
"type": "string"
},
"key_connections": {
"type": "array",
"items": {
"type": "object",
"properties": {
"connection": {
"type": "string"
},
"description": {
"type": "string"
}
},
"required": [
"connection",
"description"
]
}
},
"synergies_and_complementarities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {
"type": "array",
"items": {
"type": "string"
}
},
"description": {
"type": "string"
}
},
"required": [
"type",
"description"
]
}
},
"research_potential": {
"type": "array",
"items": {
"type": "object",
"properties": {
"potential": {
"type": "string"
},
"description": {
"type": "string"
}
},
"required": [
"potential",
"description"
]
}
},
"rating": {
"type": "number"
},
"confidence": {
"type": "number"
}
},
"required": [
"reasoning",
"key_connections",
"synergies_and_complementarities",
"research_potential",
"rating",
"confidence"
]
}"""
response_format = {
"type": "json",
"value": response_schema,
}
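    # Request grammar-constrained JSON output. This {"type": "json", "value": ...}
    # grammar shape is specific to HF's TGI backend; other providers may expect
    # a different response_format (e.g. {"type": "json_object"}).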
result = reasoning_model.chat.completions.create(
messages=messages,
max_tokens=max_length,
temperature=temperature,
presence_penalty=presence_penalty,
response_format=response_format,
top_p=top_p,
)
output = result.choices[0].message.content.strip()
return output
@spaces.GPU
def embed_text(text: str | list[str]) -> np.ndarray:
    """Embed one or more texts as L2-normalized vectors."""
    global embedding_model
    # Strip any leading/trailing whitespace
    text = text.strip() if isinstance(text, str) else [t.strip() for t in text]
    # Normalized vectors make FAISS inner-product search equivalent to cosine similarity
    embeddings = embedding_model.encode(text, normalize_embeddings=True)
    return embeddings
def analyse_abstracts(query_abstract: str, compare_abstract: dict) -> str:
"""Analyze the relationship between two abstracts and return formatted analysis"""
    # Highlight the synergies in these papers that would justify further research
messages = [
{
"role": "user",
"content": f"""You are trained in evaluating connections between research papers. Please **identify and analyze the links** between these two papers:
Paper 1 Abstract:
{query_abstract}
Paper 2 Abstract:
{compare_abstract["abstract"]}
Consider the following aspects in your evaluation:
* **Methodological Cross-Pollination**: How do the methods or approaches from one paper **directly enhance or inform** the other?
* **Principle or Mechanism Extension**: Do the papers **share underlying principles or mechanisms** that can be **combined or extended** to yield new insights?
* **Interdisciplinary Connections**: Are there **clear opportunities** for interdisciplinary collaborations or knowledge transfer between the two papers?
* **Solution or Application Bridge**: Can the solutions or applications presented in one paper be **directly adapted or integrated** with the other to create **novel, actionable outcomes**?
Consider connections in either direction, that is, from Paper 1 -> Paper 2 or from Paper 2 -> Paper 1.
Return a valid JSON object with this structure:
{{
"reasoning": "Step-by-step analysis of the papers, highlighting **key established connections**, identified synergies, and **concrete complementarities**. Emphasize the most **critical, actionable insights** or **key takeaways** from the analysis using markdown bold.",
# Main connecting concepts, methods, or principles
"key_connections": [
{{
"connection": "connection 1",
"description": "Brief description (1-2 sentences) for the **established connection**, explaining its **direct relevance** to the synergy analysis."
}},
...
],
"synergies_and_complementarities": [
{{
"type": ["Methodological Cross-Pollination", "Principle or Mechanism Extension", "Interdisciplinary Connections", "Solution or Application Bridge"], # Choose only one type per entry, and only include relevant types to this analysis
"description": "Brief explanation (1-2 sentences) of the **identified, concrete synergy** or **complementarity**, and a **specific, actionable example** to illustrate the concept."
}},
...
],
# Novel, actionable outcomes or applications emerging from the synergies
"research_potential": [
{{
"potential": "Actionable outcome or application 1",
"description": "Brief description (1-2 sentences) of the **concrete potential outcome** or **application**, and a **specific scenario** to illustrate its **direct impact**."
}},
...
],
"rating": 1-5, # Overall rating of the papers' synergy potential, where:
# 1 = **No synergy or connection** (definitely no link between the papers)
# 2 = **Low potential for synergy** (some vague or speculative connection, but highly uncertain)
# 3 = **Plausible synergy potential** (some potential connections, but requiring further investigation to confirm)
# 4 = **Established synergy with potential for growth** (clear connections with opportunities for further development)
# 5 = **High established synergy with direct, clear opportunities** (strong, concrete links with immediate, actionable outcomes)
"confidence": 0.0-1.0, # Confidence in your analysis, as a floating-point value representing the probability of your assessment being accurate
}}
Return only the JSON object, with double quotes around key names and all string values.""",
},
]
# Generate analysis
try:
output = generate(messages)
except Exception as e:
return f"Error: {e}"
# Parse the JSON output
try:
output = json.loads(output)
except Exception as e:
return f"Error: {e}"
# Format the output as markdown for better display
key_connections = ""
synergies_and_complementarities = ""
research_potential = ""
if "key_connections" in output:
for connection in output["key_connections"]:
key_connections += f"- {connection['connection']}: {connection['description']}\n"
if "synergies_and_complementarities" in output:
for synergy in output["synergies_and_complementarities"]:
synergies_and_complementarities += f"- {', '.join(synergy['type'])}: {synergy['description']}\n"
if "research_potential" in output:
for potential in output["research_potential"]:
research_potential += f"- {potential['potential']}: {potential['description']}\n"
formatted_output = f"""## Synergy Analysis
**Rating**: {'★' * int(output['rating'])}{'☆' * (5 - int(output['rating']))} **Confidence**: {'★' * round(output['confidence'] * 5)}{'☆' * (5 - round(output['confidence'] * 5))}
### Key Connections
{key_connections}
### Synergies and Complementarities
{synergies_and_complementarities}
### Research Potential
{research_potential}
### Reasoning
{output['reasoning']}
"""
return formatted_output
# return '```"""\n' + output + '\n"""```'
# arXiv Embedding Dataset Details
# DatasetDict({
# train: Dataset({
# features: ['id', 'submitter', 'authors', 'title', 'comments', 'journal-ref', 'doi', 'report-no', 'categories', 'license', 'abstract', 'update_date', 'embedding', 'timestamp', 'embedding_model'],
# num_rows: 2689088
# })
# })
def find_synergistic_papers(abstract: str, limit=25) -> list[dict]:
"""Find papers synergistic with the given abstract using FAISS with cosine similarity"""
global dataset
# Generate embedding for the query abstract (normalized for cosine similarity)
abstract_embedding = embed_text(abstract)
# Search for similar papers using FAISS with inner product (cosine similarity for normalized vectors)
scores, examples = dataset["train"].get_nearest_examples("embedding", abstract_embedding, k=limit)
papers = []
for i in range(len(scores)):
# With cosine similarity, higher scores are better (closer to 1)
paper_dict = {
"id": examples["id"][i],
"title": examples["title"][i],
"authors": examples["authors"][i],
"categories": examples["categories"][i],
"abstract": examples["abstract"][i],
"update_date": examples["update_date"][i],
"synergy_score": float(scores[i]), # Convert to float for serialization
}
papers.append(paper_dict)
return papers
def format_search_results(abstract: str) -> tuple[pd.DataFrame, list[dict]]:
"""Format search results as a DataFrame for display"""
# Find papers synergistic with the given abstract
papers = find_synergistic_papers(abstract)
# Convert to DataFrame for display
df = pd.DataFrame(
[
{
"Title": p["title"],
"Authors": p["authors"][:50] + "..." if len(p["authors"]) > 50 else p["authors"],
"Categories": p["categories"],
"Date": p["update_date"],
"Match Score": f"{int(p['synergy_score'] * 100)}%",
"ID": p["id"], # Hidden column for reference
}
for p in papers
]
)
return df, papers # Return both DataFrame and original data
def format_paper_as_markdown(paper: dict) -> str:
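    """Render a paper record as markdown with linked arXiv IDs."""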
# Convert category codes to full names, handling unknown categories
subjects = []
for subject in paper["categories"].split():
if subject in ARXIV_CATEGORIES_FLAT:
subjects.append(ARXIV_CATEGORIES_FLAT[subject])
else:
subjects.append(f"Unknown Category ({subject})")
paper["title"] = paper["title"].replace("\n", " ").strip()
paper["authors"] = paper["authors"].replace("\n", " ").strip()
return f"""# {paper["title"]}
### {paper["authors"]}
#### {', '.join(subjects)} | {paper["update_date"]} | **Score**: {int(paper['synergy_score'] * 100)}%
**[arxiv:{paper["id"]}](https://arxiv.org/abs/{paper["id"]})** - [PDF](https://arxiv.org/pdf/{paper["id"]})
{paper["abstract"]}
"""
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
# {"left": "$", "right": "$", "display": False},
# {"left": "\\(", "right": "\\)", "display": False},
# {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
# {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
# {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
# {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
# {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
# {"left": "\\[", "right": "\\]", "display": True},
# {"left": "\\underline{", "right": "}", "display": False},
# {"left": "\\textit{", "right": "}", "display": False},
# {"left": "\\textit{", "right": "}", "display": False},
# {"left": "{", "right": "}", "display": False},
]
def create_interface():
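    """Build the Gradio UI: abstract input, results table, details and analysis panes."""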
with gr.Blocks(
css="""
.cell-menu-button {
display: none;
}"""
) as demo:
        gr.HTML(
            """
            <h1>Research Compass</h1>
            <p>Find synergistic papers to enrich your research</p>
            <p><em>An experiment in AI-driven research synergy analysis</em></p>
            """
        )
with gr.Accordion(label="Instructions", open=False):
gr.Markdown(
"""
1. **Enter Abstract**: Paste an abstract or describe your research details in the text box.
2. **Search for Synergistic Papers**: Click the button to find papers with similar themes.
3. **Select a Paper**: Click on a row in the results table to view paper details.
4. **Analyze Connection Potential**: Click the button to analyze the synergy potential between the papers.
5. **Synergy Analysis**: View the detailed analysis of the connection potential between the papers.
"""
)
abstract_input = gr.Textbox(
label="Paper Abstract or Description",
placeholder="Paste an abstract or describe research details...",
lines=8,
key="abstract",
)
search_btn = gr.Button("Search for Synergistic Papers", variant="primary")
# Store full paper data
paper_data_state = gr.State([])
# Store query abstract
query_abstract_state = gr.State("")
# Store selected paper
selected_paper_state = gr.State(None)
# Use Dataframe for results
results_df = gr.Dataframe(
headers=["Title", "Authors", "Categories", "Date", "Match Score"],
datatype=["markdown", "markdown", "str", "date", "str"],
latex_delimiters=latex_delimiters,
label="Synergistic Papers",
interactive=False,
wrap=False,
line_breaks=False,
column_widths=["40%", "20%", "20%", "10%", "10%", "0%"], # Hide ID column
key="results",
)
with gr.Row():
with gr.Column(scale=1):
paper_details_output = gr.Markdown(
value="# Paper Details",
label="Paper Details",
latex_delimiters=latex_delimiters,
show_copy_button=True,
key="paper_details",
)
analyze_btn = gr.Button("Analyze Connection Potential", variant="primary", interactive=False)
with gr.Column(scale=1):
# Analysis output
analysis_output = gr.Markdown(
value="# Synergy Analysis",
label="Synergy Analysis",
latex_delimiters=latex_delimiters,
show_copy_button=True,
key="analysis_output",
)
# Display paper details when row is selected
def on_select(evt: gr.SelectData, papers, query):
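            """Show details for the clicked row and remember the selection."""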
selected_index = evt.index[0] # Get the row index
selected = papers[selected_index]
# Format paper details
details_md = format_paper_as_markdown(selected)
return details_md, selected
# Connect search button to the search function
search_btn.click(
format_search_results,
inputs=[abstract_input],
outputs=[results_df, paper_data_state],
api_name=False,
).then(
lambda x: x, # Identity function to pass through the abstract
inputs=[abstract_input],
outputs=[query_abstract_state],
api_name=False,
).then(
lambda: None, # Reset selected paper
outputs=[selected_paper_state],
api_name=False,
).then(
lambda: gr.update(interactive=False), # Disable analyze button until paper selected
outputs=[analyze_btn],
api_name=False,
).then(
lambda: "# Synergy Analysis", # Clear previous analysis
outputs=[analysis_output],
api_name=False,
)
# Use built-in select event from Dataframe
results_df.select(
on_select,
inputs=[paper_data_state, query_abstract_state],
outputs=[paper_details_output, selected_paper_state],
api_name=False,
).then(
lambda: gr.update(interactive=True), # Enable analyze button when paper selected
outputs=[analyze_btn],
api_name=False,
)
# Connect analyze button to run analysis
analyze_btn.click(
analyse_abstracts,
inputs=[query_abstract_state, selected_paper_state],
outputs=[analysis_output],
show_progress_on=[paper_details_output, analysis_output],
api_name=False,
)
return demo
if __name__ == "__main__":
# Load dataset with FAISS index
setup_dataset()
# Initialize the embedding model
init_embedding_model(embedding_model_name, embedding_model_revision)
# Initialize the reasoning model
reasoning_model = init_reasoning_model(reasoning_model_id)
demo = create_interface()
demo.queue(api_open=False).launch(ssr_mode=False, show_api=False)