PapersImpact / app.py
openfree's picture
Update app.py
822a9a7 verified
raw
history blame
14.7 kB
import gradio as gr
import spaces
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import torch.nn as nn
import re
import requests
from urllib.parse import urlparse
import xml.etree.ElementTree as ET
# Hugging Face Hub repository of the NAIP checkpoint loaded by predict().
model_path = "ssocean/NAIP"

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Lazily populated by predict() on its first call; None until then.
model = None
tokenizer = None
# arXiv export API endpoint and the Atom namespace used in its responses.
ARXIV_API_URL = "https://export.arxiv.org/api/query"
_ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"}
# Captures everything after the route segment (/abs/, /pdf/, /html/) of an arxiv.org URL.
_ARXIV_ROUTE_RE = re.compile(r"/(?:abs|pdf|html)/(.+)$")


def _arxiv_result(title="", abstract="", success=False, message=""):
    """Build the result dict shape returned by fetch_arxiv_paper()."""
    return {
        "title": title,
        "abstract": abstract,
        "success": success,
        "message": message,
    }


def _extract_arxiv_id(arxiv_input):
    """Return the bare arXiv ID from an arxiv.org URL or raw ID ('' if none is found).

    Old-style IDs keep their archive prefix (``.../abs/hep-th/9901001`` ->
    ``hep-th/9901001``). A trailing ``.pdf`` is dropped
    (``https://arxiv.org/pdf/2504.11651v2.pdf`` -> ``2504.11651v2``).
    """
    text = arxiv_input.strip()
    if "arxiv.org" not in text:
        # Treat anything that is not an arxiv.org link as a raw ID.
        return text
    path = urlparse(text).path
    # The leading "/" lets scheme-less input ("arxiv.org/abs/<id>") match as well.
    match = _ARXIV_ROUTE_RE.search("/" + path)
    arxiv_id = match.group(1).strip("/") if match else path.split("/")[-1]
    if arxiv_id.endswith(".pdf"):
        arxiv_id = arxiv_id[: -len(".pdf")]
    return arxiv_id


def fetch_arxiv_paper(arxiv_input, timeout=15):
    """Fetch a paper's title and abstract from the arXiv API.

    Args:
        arxiv_input: an arxiv.org URL (abs/pdf/html) or a bare arXiv ID.
        timeout: seconds to wait for the arXiv API before giving up.

    Returns:
        dict with keys "title", "abstract", "success" (bool) and "message"
        (a user-facing status string). On failure, title and abstract are "".
    """
    try:
        arxiv_id = _extract_arxiv_id(arxiv_input)
        if not arxiv_id:
            return _arxiv_result(message="Could not find an arXiv ID in the input")
        # requests URL-encodes the ID. The timeout stops a stalled API call
        # from hanging the UI indefinitely.
        resp = requests.get(
            ARXIV_API_URL, params={"id_list": arxiv_id}, timeout=timeout
        )
        if resp.status_code != 200:
            return _arxiv_result(message="Error fetching paper from arXiv API")
        # Parse raw bytes so the XML parser honours the document's declared encoding.
        root = ET.fromstring(resp.content)
        entry = root.find(".//atom:entry", _ATOM_NS)
        if entry is None:
            return _arxiv_result(message="Paper not found")
        # arXiv reports malformed IDs as a pseudo-entry whose <id> points at /api/errors.
        entry_id = entry.findtext("atom:id", default="", namespaces=_ATOM_NS)
        if "/api/errors" in entry_id:
            detail = entry.findtext(
                "atom:summary", default="", namespaces=_ATOM_NS
            ).strip()
            return _arxiv_result(
                message=f"arXiv API error: {detail or 'invalid arXiv ID'}"
            )
        # The Atom feed line-wraps long titles; collapse that whitespace.
        title = " ".join(entry.find("atom:title", _ATOM_NS).text.split())
        abstract = entry.find("atom:summary", _ATOM_NS).text.strip()
        return _arxiv_result(title, abstract, True, "Paper fetched successfully!")
    except Exception as e:
        # UI boundary: report any network/parse failure as a status message.
        return _arxiv_result(message=f"Error fetching paper: {e}")
def _load_model_once():
    """Load the NAIP classifier and tokenizer into the module globals on first use."""
    global model, tokenizer
    if model is not None:
        return
    cfg = AutoConfig.from_pretrained(model_path)
    cfg.quantization_config = None  # load plain (unquantized) weights
    cfg.num_labels = 1  # a single output logit
    # Author's note: num_labels goes on the config rather than into
    # from_pretrained() for LLaMA-based models.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        config=cfg,
        torch_dtype=torch.float32,  # full precision
        device_map=None,  # placement is done manually with .to(device) below
        low_cpu_mem_usage=False,
    )
    model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)


@spaces.GPU(duration=60, enable_queue=True)
def predict(title, abstract):
    """Score a paper's predicted academic impact from its title and abstract.

    The model and tokenizer are loaded on the first call (errors while loading
    propagate to the caller) and reused afterwards.

    Returns:
        float: sigmoid(logit) + 0.05, capped at 1.0 and rounded to 4 decimals;
        0.0 if tokenization or the forward pass raises.
    """
    _load_model_once()

    prompt = "\n".join([
        "Given a certain paper,",
        f"Title: {title.strip()}",
        f"Abstract: {abstract.strip()}",
        "Predict its normalized academic impact (0~1):",
    ])

    try:
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logit = model(**encoded).logits
        raw = torch.sigmoid(logit).item()
        # Fixed +0.05 offset chosen by the original author, clipped at 1.0.
        return round(min(1.0, raw + 0.05), 4)
    except Exception as err:
        print(f"Prediction error: {err}")
        return 0.0
# Grade tiers from highest to lowest; the first threshold the score reaches wins.
_GRADE_TIERS = (
    (0.900, "AAA 🌟"),
    (0.800, "AA ⭐"),
    (0.650, "A ✨"),
    (0.600, "BBB 🔵"),
    (0.550, "BB 📘"),
    (0.500, "B 📖"),
    (0.400, "CCC 📝"),
    (0.300, "CC ✏️"),
)
_LOWEST_GRADE = "C 📑"


def get_grade_and_emoji(score):
    """Convert a 0–1 impact score into a tier grade with an emoji indicator.

    Args:
        score: predicted impact score (higher is better).

    Returns:
        str: a grade label such as ``"AA ⭐"``. Scores below 0.300 map to
        ``"C 📑"``, as does NaN, which fails every comparison.
    """
    for threshold, label in _GRADE_TIERS:
        if score >= threshold:
            return label
    return _LOWEST_GRADE
def validate_input(title, abstract):
    """Check that a title/abstract pair is acceptable for prediction.

    Rules are evaluated in order and stop at the first failure: the title has at
    least 3 whitespace-separated words, the abstract has at least 50, then the
    title and the abstract each contain only ASCII characters.

    Returns:
        (ok, message): (False, <first failing rule's message>) or
        (True, "Inputs look good.").
    """
    # Each rule is a lazy "did it fail?" check so later rules never run once one fails.
    rules = (
        (lambda: len(title.split()) < 3, "Title must be at least 3 words."),
        (lambda: len(abstract.split()) < 50, "Abstract must be at least 50 words."),
        (lambda: not title.isascii(), "Title contains non-ASCII characters."),
        (lambda: not abstract.isascii(), "Abstract contains non-ASCII characters."),
    )
    for failed, message in rules:
        if failed():
            return False, message
    return True, "Inputs look good."
def update_button_status(title, abstract):
    """Refresh the validation message and enable/disable the Predict button.

    Returns:
        (status update, button update): the status text is the validation message,
        prefixed with "Error: " on failure. The button is interactive only when
        the inputs are valid.
    """
    is_valid, message = validate_input(title, abstract)
    status_text = message if is_valid else "Error: " + message
    return gr.update(value=status_text), gr.update(interactive=is_valid)
def process_arxiv_input(arxiv_input):
    """Fill the title/abstract fields from an arXiv link or ID.

    Returns:
        (title, abstract, status_message). Title and abstract are "" when the
        input is blank or the fetch fails.
    """
    if arxiv_input.strip() == "":
        return "", "", "Please enter an arXiv URL or ID"
    fetched = fetch_arxiv_paper(arxiv_input)
    if not fetched["success"]:
        return "", "", fetched["message"]
    return fetched["title"], fetched["abstract"], fetched["message"]
# Custom CSS for styling; passed to gr.Blocks(css=...) when the interface is built.
# Selectors target classes attached to components via elem_classes=... in the UI.
# NOTE(review): .main-title and .sub-title are not used by any elem_classes in this file.
css = """
.gradio-container { font-family: Arial, sans-serif; }
.main-title {
text-align: center;
color: #2563eb;
font-size: 2.5rem !important;
margin-bottom: 1rem !important;
background: linear-gradient(45deg, #2563eb, #1d4ed8);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.sub-title {
text-align: center;
color: #4b5563;
font-size: 1.5rem !important;
margin-bottom: 2rem !important;
}
.input-section {
background: white;
padding: 2rem;
border-radius: 1rem;
box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
}
.result-section {
background: #f8fafc;
padding: 2rem;
border-radius: 1rem;
margin-top: 2rem;
}
.methodology-section {
background: #ecfdf5;
padding: 2rem;
border-radius: 1rem;
margin-top: 2rem;
}
.example-section {
background: #fff7ed;
padding: 2rem;
border-radius: 1rem;
margin-top: 2rem;
}
.grade-display {
font-size: 3rem;
text-align: center;
margin: 1rem 0;
}
.arxiv-input {
margin-bottom: 1.5rem;
padding: 1rem;
background: #f3f4f6;
border-radius: 0.5rem;
}
.arxiv-link {
color: #2563eb;
text-decoration: underline;
font-size: 0.9em;
margin-top: 0.5em;
}
.arxiv-note {
color: #666;
font-size: 0.9em;
margin-top: 0.5em;
margin-bottom: 0.5em;
}
"""
# Example papers shown in the "Example Papers" section of the UI.
# "score" values are fixed display values stored here, not model output computed at runtime.
example_papers = [
    {
        "title": "Attention Is All You Need",
        "abstract": (
            "The dominant sequence transduction models are based on complex recurrent or "
            "convolutional neural networks that include an encoder and a decoder. The best performing "
            "models also connect the encoder and decoder through an attention mechanism. We propose a "
            "new simple network architecture, the Transformer, based solely on attention mechanisms, "
            "dispensing with recurrence and convolutions entirely. Experiments on two machine "
            "translation tasks show these models to be superior in quality while being more "
            "parallelizable and requiring significantly less time to train."
        ),
        "score": 0.982,
        "note": "💫 Revolutionary paper that introduced the Transformer architecture.",
    },
    {
        "title": "Language Models are Few-Shot Learners",
        "abstract": (
            "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by "
            "pre-training on a large corpus of text followed by fine-tuning on a specific task. While "
            "typically task-agnostic in architecture, this method still requires task-specific "
            "fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans "
            "can generally perform a new language task from only a few examples or from simple "
            "instructions - something which current NLP systems still largely struggle to do. Here we "
            "show that scaling up language models greatly improves task-agnostic, few-shot "
            "performance, sometimes even reaching competitiveness with prior state-of-the-art "
            "fine-tuning approaches."
        ),
        "score": 0.956,
        "note": "🚀 Groundbreaking GPT-3 paper that demonstrated the power of large language models.",
    },
    {
        "title": "An Empirical Study of Neural Network Training Protocols",
        "abstract": (
            "This paper presents a comparative analysis of different training protocols for neural "
            "networks across various architectures. We examine the effects of learning rate schedules, "
            "batch size selection, and optimization algorithms on model convergence and final "
            "performance. Our experiments span multiple datasets and model sizes, providing practical "
            "insights for deep learning practitioners."
        ),
        "score": 0.623,
        "note": "📚 Solid research paper with useful findings but more limited scope and impact.",
    },
]
# Build Gradio interface
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
    # Page header and visitor-counter badge.
    gr.Markdown(
        """
        # Papers Impact: AI-Powered Research Impact Predictor
        ## https://discord.gg/openfreeai
        """
    )
    gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space">
    <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space&countColor=%23263759" />
    </a>""")

    with gr.Row():
        # Left column: paper input, either imported from arXiv or typed manually.
        with gr.Column(elem_classes="input-section"):
            with gr.Group(elem_classes="arxiv-input"):
                gr.Markdown("### 📑 Import from arXiv")
                arxiv_input = gr.Textbox(
                    lines=1,
                    placeholder="Enter arXiv URL or ID (e.g., 2504.11651)",
                    label="arXiv Paper URL/ID",
                    value="2504.11651",
                )
                gr.Markdown(
                    """
                    <p class="arxiv-note">
                    Click input field to use example paper or browse papers at
                    <a href="https://arxiv.org" target="_blank" class="arxiv-link">arxiv.org</a>
                    </p>
                    """
                )
                fetch_button = gr.Button("🔍 Fetch Paper Details", variant="secondary")

            gr.Markdown("### 📝 Or Enter Paper Details Manually")
            title_input = gr.Textbox(
                lines=2,
                placeholder="Enter Paper Title (minimum 3 words)...",
                label="Paper Title",
            )
            abstract_input = gr.Textbox(
                lines=5,
                placeholder="Enter Paper Abstract (minimum 50 words)...",
                label="Paper Abstract",
            )
            validation_status = gr.Textbox(label="✔️ Validation Status", interactive=False)
            # Starts disabled; update_button_status() enables it once the inputs validate.
            submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")

        # Right column: prediction output.
        with gr.Column(elem_classes="result-section"):
            with gr.Group():
                score_output = gr.Number(label="🎯 Impact Score")
                grade_output = gr.Textbox(label="🏆 Grade", value="", elem_classes="grade-display")

    with gr.Row(elem_classes="methodology-section"):
        gr.Markdown(
            """
            ### 🔬 Scientific Methodology
            - **Training Data**: Model trained on extensive dataset of published papers from CS.CV, CS.CL(NLP), and CS.AI fields
            - **Optimization**: NDCG optimization with Sigmoid activation and MSE loss function
            - **Validation**: Cross-validated against historical paper impact data
            - **Architecture**: Advanced transformer-based deep textual analysis
            - **Metrics**: Quantitative analysis of citation patterns and research influence
            """
        )

    # Thresholds here mirror those used by get_grade_and_emoji().
    with gr.Row():
        gr.Markdown(
            """
            ### 📊 Rating Scale
            | Grade | Score Range | Description | Indicator |
            |-------|-------------|-------------|-----------|
            | AAA | 0.900-1.000 | Exceptional Impact | 🌟 |
            | AA | 0.800-0.899 | Very High Impact | ⭐ |
            | A | 0.650-0.799 | High Impact | ✨ |
            | BBB | 0.600-0.649 | Above Average Impact | 🔵 |
            | BB | 0.550-0.599 | Moderate Impact | 📘 |
            | B | 0.500-0.549 | Average Impact | 📖 |
            | CCC | 0.400-0.499 | Below Average Impact | 📝 |
            | CC | 0.300-0.399 | Low Impact | ✏️ |
            | C | < 0.300 | Limited Impact | 📑 |
            """
        )

    with gr.Row(elem_classes="example-section"):
        gr.Markdown("### 📋 Example Papers")
        for paper in example_papers:
            # Blank lines keep each part its own paragraph. Without the blank line
            # before "---", the note would render as a setext heading.
            gr.Markdown(
                f"""
                #### {paper['title']}

                **Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}

                {paper['abstract']}

                *{paper['note']}*

                ---
                """
            )

    # Re-validate and toggle the Predict button whenever either text field changes.
    for field in (title_input, abstract_input):
        field.change(
            update_button_status,
            inputs=[title_input, abstract_input],
            outputs=[validation_status, submit_button],
        )

    # Fill title/abstract from arXiv; the status box shows the fetch result.
    fetch_button.click(
        process_arxiv_input,
        inputs=[arxiv_input],
        outputs=[title_input, abstract_input, validation_status],
    )

    def process_prediction(title, abstract):
        """Run the model and pair its score with the matching grade label."""
        score = predict(title, abstract)
        return score, get_grade_and_emoji(score)

    submit_button.click(
        process_prediction,
        inputs=[title_input, abstract_input],
        outputs=[score_output, grade_output],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    iface.launch()