# Hugging Face Space status (scraped page header): Running on Zero (ZeroGPU hardware).
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import re | |
import requests | |
from urllib.parse import urlparse | |
import xml.etree.ElementTree as ET | |
# Model repository path (Hugging Face Hub repo id of the impact-prediction model)
model_path = "ssocean/NAIP"
# Use CUDA when PyTorch reports it available, otherwise fall back to CPU.
# NOTE(review): `spaces` is imported but no function in the visible code is decorated
# with @spaces.GPU; on ZeroGPU hardware GPU work is normally wrapped that way — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Global model/tokenizer variables: None until predict() loads them on its first call.
model = None
tokenizer = None
def _extract_arxiv_id(arxiv_input):
    """Return the arXiv identifier contained in an arXiv link or a bare ID.

    Links of the form ``arxiv.org/abs/<id>``, ``arxiv.org/pdf/<id>[.pdf]`` and
    ``arxiv.org/html/<id>`` are accepted with or without a scheme. Old-style IDs
    that contain a slash (e.g. ``hep-th/9901001``) are kept intact. Returns an
    empty string when a link contains no identifier.
    """
    text = arxiv_input.strip()
    if "arxiv.org" not in text:
        # Not a link: the whole (stripped) input is the identifier.
        return text
    if "://" not in text:
        # urlparse() only separates host from path when a scheme is present.
        text = "https://" + text
    path = urlparse(text).path.strip("/")
    for prefix in ("abs/", "pdf/", "html/"):
        if path.startswith(prefix):
            path = path[len(prefix):]
            break
    else:
        # Unrecognised link layout: fall back to the last path segment.
        path = path.split("/")[-1]
    if path.endswith(".pdf"):
        path = path[: -len(".pdf")]
    return path


def _collapse_whitespace(text):
    """Collapse runs of whitespace (including the feed's hard line-wraps) to single spaces."""
    return " ".join(text.split())


def fetch_arxiv_paper(arxiv_input, timeout=15):
    """
    Fetch paper details (title, abstract) from an arXiv URL or ID using requests.

    Args:
        arxiv_input: An arXiv abs/pdf/html link or a bare arXiv ID.
        timeout: Seconds to wait for the arXiv API before giving up.

    Returns:
        A dict with keys "title", "abstract", "success" and "message". On any
        failure "title" and "abstract" are empty strings and "success" is False;
        errors are reported through "message" rather than raised.
    """

    def failure(message):
        return {"title": "", "abstract": "", "success": False, "message": message}

    try:
        arxiv_id = _extract_arxiv_id(arxiv_input)
        if not arxiv_id:
            return failure("Could not find an arXiv ID in the input")
        # Pass the ID via `params` so requests URL-encodes user input, and bound
        # the wait so a slow API cannot hang the request handler forever.
        resp = requests.get(
            "https://export.arxiv.org/api/query",
            params={"id_list": arxiv_id},
            timeout=timeout,
        )
        if resp.status_code != 200:
            return failure("Error fetching paper from arXiv API")
        # The API answers with an Atom feed.
        root = ET.fromstring(resp.text)
        ns = {"atom": "http://www.w3.org/2005/Atom"}
        entry = root.find(".//atom:entry", ns)
        if entry is None:
            return failure("Paper not found")
        # Problems such as a malformed ID come back as a regular-looking entry
        # whose <id> points at http://arxiv.org/api/errors#...
        if "/api/errors" in entry.findtext("atom:id", default="", namespaces=ns):
            detail = _collapse_whitespace(
                entry.findtext("atom:summary", default="", namespaces=ns)
            )
            return failure(f"arXiv API error: {detail}" if detail else "Paper not found")
        # The feed hard-wraps titles and abstracts; rejoin them into single lines.
        title = _collapse_whitespace(entry.findtext("atom:title", default="", namespaces=ns))
        abstract = _collapse_whitespace(
            entry.findtext("atom:summary", default="", namespaces=ns)
        )
        if not title:
            return failure("Paper not found")
        return {
            "title": title,
            "abstract": abstract,
            "success": True,
            "message": "Paper fetched successfully!",
        }
    except Exception as e:
        # Deliberate catch-all: the UI only displays "message", so any network,
        # XML or input error is reported instead of raised.
        return failure(f"Error fetching paper: {e}")
def _load_model_and_tokenizer():
    """Load the NAIP scoring model (moved to ``device``, in eval mode) and its tokenizer."""
    # Load model config, disable quantization, and request a single output logit.
    config = AutoConfig.from_pretrained(model_path)
    config.quantization_config = None
    config.num_labels = 1  # For classification/logit output
    # IMPORTANT: Do not pass num_labels directly into from_pretrained for LLaMA-based models
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        config=config,
        torch_dtype=torch.float32,  # Use full-precision float32
        device_map=None,  # We'll move it manually
        low_cpu_mem_usage=False,
    )
    loaded_model.to(device)
    loaded_model.eval()
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    return loaded_model, loaded_tokenizer


def predict(title, abstract):
    """
    Predict a normalized academic impact score (0-1) given the paper title & abstract.

    The model and tokenizer are loaded on the first call and cached in the
    module-level ``model`` / ``tokenizer`` globals.

    Args:
        title: Paper title; surrounding whitespace is stripped (None -> "").
        abstract: Paper abstract; surrounding whitespace is stripped (None -> "").

    Returns:
        min(1.0, sigmoid(logit) + 0.05) rounded to 4 decimals, or 0.0 if
        tokenization/inference raises (the error is printed). Errors raised
        while loading the model/tokenizer are not caught and propagate.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        # Publish both globals together, only after both loaded successfully, so
        # a failed tokenizer download cannot leave `model` set with `tokenizer`
        # None (which would make every later call fail and return 0.0).
        model, tokenizer = _load_model_and_tokenizer()
    # NOTE(review): presumably the prompt format the model was trained on — keep it stable.
    text = (
        "Given a certain paper,\n"
        f"Title: {(title or '').strip()}\n"
        f"Abstract: {(abstract or '').strip()}\n"
        "Predict its normalized academic impact (0~1):"
    )
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        prob = torch.sigmoid(logits).item()
        score = min(1.0, prob + 0.05)  # +0.05 offset, capped at 1.0
        return round(score, 4)
    except Exception as e:
        print(f"Prediction error: {e}")
        return 0.0  # Return 0 in case of any error
def get_grade_and_emoji(score):
    """Map a 0-1 impact score onto its letter tier plus emoji indicator.

    Tiers are tried from the highest floor downward; the first floor the score
    reaches (inclusive) wins. Anything below 0.300, or a NaN score, is "C".
    """
    tiers = (
        (0.900, "AAA π"),
        (0.800, "AA β"),
        (0.650, "A β¨"),
        (0.600, "BBB π΅"),
        (0.550, "BB π"),
        (0.500, "B π"),
        (0.400, "CCC π"),
        (0.300, "CC βοΈ"),
    )
    return next((label for floor, label in tiers if score >= floor), "C π")
def validate_input(title, abstract):
    """
    Check that the title has at least 3 words, the abstract at least 50 words,
    and that both contain only ASCII characters.

    Rules are evaluated lazily in that order; the first failing rule decides the
    returned (False, message). Returns (True, "Inputs look good.") otherwise.
    """
    non_ascii = re.compile(r"[^\x00-\x7F]")
    rules = (
        (lambda: len(title.split()) >= 3, "Title must be at least 3 words."),
        (lambda: len(abstract.split()) >= 50, "Abstract must be at least 50 words."),
        (lambda: not non_ascii.search(title), "Title contains non-ASCII characters."),
        (lambda: not non_ascii.search(abstract), "Abstract contains non-ASCII characters."),
    )
    for passes, complaint in rules:
        if not passes():
            return False, complaint
    return True, "Inputs look good."
def update_button_status(title, abstract):
    """Refresh the validation message and enable the submit button only for valid input.

    Returns two gr.update objects: one for the status textbox (prefixed with
    "Error: " on failure) and one toggling the button's `interactive` flag.
    """
    is_valid, message = validate_input(title, abstract)
    status_text = message if is_valid else "Error: " + message
    return gr.update(value=status_text), gr.update(interactive=is_valid)
def process_arxiv_input(arxiv_input):
    """
    Resolve an arXiv link/ID into (title, abstract, status message) for the UI fields.

    Blank input short-circuits with a prompt; when fetching fails both text
    fields are cleared and only the fetcher's message is shown.
    """
    if arxiv_input.strip():
        paper = fetch_arxiv_paper(arxiv_input)
        if paper["success"]:
            return paper["title"], paper["abstract"], paper["message"]
        return "", "", paper["message"]
    return "", "", "Please enter an arXiv URL or ID"
# Custom CSS for styling, passed to gr.Blocks(css=...) when the interface is built.
# Class names here are attached to components via elem_classes= or used in raw HTML markup.
# NOTE(review): .main-title and .sub-title do not appear to be referenced in the visible
# code — confirm before removing.
css = """
.gradio-container { font-family: Arial, sans-serif; }
.main-title {
text-align: center;
color: #2563eb;
font-size: 2.5rem !important;
margin-bottom: 1rem !important;
background: linear-gradient(45deg, #2563eb, #1d4ed8);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.sub-title {
text-align: center;
color: #4b5563;
font-size: 1.5rem !important;
margin-bottom: 2rem !important;
}
.input-section {
background: white;
padding: 2rem;
border-radius: 1rem;
box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
}
.result-section {
background: #f8fafc;
padding: 2rem;
border-radius: 1rem;
margin-top: 2rem;
}
.methodology-section {
background: #ecfdf5;
padding: 2rem;
border-radius: 1rem;
margin-top: 2rem;
}
.example-section {
background: #fff7ed;
padding: 2rem;
border-radius: 1rem;
margin-top: 2rem;
}
.grade-display {
font-size: 3rem;
text-align: center;
margin: 1rem 0;
}
.arxiv-input {
margin-bottom: 1.5rem;
padding: 1rem;
background: #f3f4f6;
border-radius: 0.5rem;
}
.arxiv-link {
color: #2563eb;
text-decoration: underline;
font-size: 0.9em;
margin-top: 0.5em;
}
.arxiv-note {
color: #666;
font-size: 0.9em;
margin-top: 0.5em;
margin-bottom: 0.5em;
}
"""
# Example papers shown in the "Example Papers" section of the UI.
# Each entry has "title", "abstract", "score" and "note". The "score" values are
# hard-coded literals used for display and grading; they are not computed by
# predict() at runtime.
example_papers = [
    {
        "title": "Attention Is All You Need",
        "abstract": (
            "The dominant sequence transduction models are based on complex recurrent or "
            "convolutional neural networks that include an encoder and a decoder. The best performing "
            "models also connect the encoder and decoder through an attention mechanism. We propose a "
            "new simple network architecture, the Transformer, based solely on attention mechanisms, "
            "dispensing with recurrence and convolutions entirely. Experiments on two machine "
            "translation tasks show these models to be superior in quality while being more "
            "parallelizable and requiring significantly less time to train."
        ),
        "score": 0.982,
        "note": "π« Revolutionary paper that introduced the Transformer architecture."
    },
    {
        "title": "Language Models are Few-Shot Learners",
        "abstract": (
            "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by "
            "pre-training on a large corpus of text followed by fine-tuning on a specific task. While "
            "typically task-agnostic in architecture, this method still requires task-specific "
            "fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans "
            "can generally perform a new language task from only a few examples or from simple "
            "instructions - something which current NLP systems still largely struggle to do. Here we "
            "show that scaling up language models greatly improves task-agnostic, few-shot "
            "performance, sometimes even reaching competitiveness with prior state-of-the-art "
            "fine-tuning approaches."
        ),
        "score": 0.956,
        "note": "π Groundbreaking GPT-3 paper that demonstrated the power of large language models."
    },
    {
        "title": "An Empirical Study of Neural Network Training Protocols",
        "abstract": (
            "This paper presents a comparative analysis of different training protocols for neural "
            "networks across various architectures. We examine the effects of learning rate schedules, "
            "batch size selection, and optimization algorithms on model convergence and final "
            "performance. Our experiments span multiple datasets and model sizes, providing practical "
            "insights for deep learning practitioners."
        ),
        "score": 0.623,
        "note": "π Solid research paper with useful findings but more limited scope and impact."
    }
]
# Build Gradio interface: layout first, then event wiring (listeners must be
# attached inside the Blocks context).
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
    gr.Markdown(
        """
        # Papers Impact: AI-Powered Research Impact Predictor
        ## https://discord.gg/openfreeai
        """
    )
    # Visitor-counter badge.
    gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space">
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space&countColor=%23263759" />
</a>""")

    with gr.Row():
        with gr.Column(elem_classes="input-section"):
            # arXiv import group
            with gr.Group(elem_classes="arxiv-input"):
                gr.Markdown("### π Import from arXiv")
                arxiv_input = gr.Textbox(
                    lines=1,
                    placeholder="Enter arXiv URL or ID (e.g., 2504.11651)",
                    label="arXiv Paper URL/ID",
                    value="2504.11651",
                )
                gr.Markdown(
                    """
                    <p class="arxiv-note">
                    Click input field to use example paper or browse papers at
                    <a href="https://arxiv.org" target="_blank" class="arxiv-link">arxiv.org</a>
                    </p>
                    """
                )
                fetch_button = gr.Button("π Fetch Paper Details", variant="secondary")

            gr.Markdown("### π Or Enter Paper Details Manually")
            title_input = gr.Textbox(
                lines=2,
                placeholder="Enter Paper Title (minimum 3 words)...",
                label="Paper Title",
            )
            abstract_input = gr.Textbox(
                lines=5,
                placeholder="Enter Paper Abstract (minimum 50 words)...",
                label="Paper Abstract",
            )
            validation_status = gr.Textbox(label="βοΈ Validation Status", interactive=False)
            # Disabled until update_button_status() reports valid input.
            submit_button = gr.Button("π― Predict Impact", interactive=False, variant="primary")

        with gr.Column(elem_classes="result-section"):
            with gr.Group():
                score_output = gr.Number(label="π― Impact Score")
                grade_output = gr.Textbox(label="π Grade", value="", elem_classes="grade-display")

    with gr.Row(elem_classes="methodology-section"):
        gr.Markdown(
            """
            ### π¬ Scientific Methodology
            - **Training Data**: Model trained on extensive dataset of published papers from CS.CV, CS.CL(NLP), and CS.AI fields
            - **Optimization**: NDCG optimization with Sigmoid activation and MSE loss function
            - **Validation**: Cross-validated against historical paper impact data
            - **Architecture**: Advanced transformer-based deep textual analysis
            - **Metrics**: Quantitative analysis of citation patterns and research influence
            """
        )

    # Score ranges must mirror the thresholds in get_grade_and_emoji(); the C row
    # covers everything below 0.300 (it previously read "< 0.299").
    with gr.Row():
        gr.Markdown(
            """
            ### π Rating Scale
            | Grade | Score Range | Description | Indicator |
            |-------|-------------|-------------|-----------|
            | AAA | 0.900-1.000 | Exceptional Impact | π |
            | AA | 0.800-0.899 | Very High Impact | β |
            | A | 0.650-0.799 | High Impact | β¨ |
            | BBB | 0.600-0.649 | Above Average Impact | π΅ |
            | BB | 0.550-0.599 | Moderate Impact | π |
            | B | 0.500-0.549 | Average Impact | π |
            | CCC | 0.400-0.499 | Below Average Impact | π |
            | CC | 0.300-0.399 | Low Impact | βοΈ |
            | C | < 0.300 | Limited Impact | π |
            """
        )

    with gr.Row(elem_classes="example-section"):
        gr.Markdown("### π Example Papers")
        for paper in example_papers:
            # Blank lines keep heading, score line, abstract and note as separate
            # paragraphs; without one, the "---" directly under the note would
            # turn the note into a setext heading.
            gr.Markdown(
                f"""
                #### {paper['title']}

                **Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}

                {paper['abstract']}

                *{paper['note']}*

                ---
                """
            )

    # Validate button status on input changes
    title_input.change(
        update_button_status,
        inputs=[title_input, abstract_input],
        outputs=[validation_status, submit_button],
    )
    abstract_input.change(
        update_button_status,
        inputs=[title_input, abstract_input],
        outputs=[validation_status, submit_button],
    )

    # Fetch from arXiv: fills title/abstract (which re-triggers validation above).
    fetch_button.click(
        process_arxiv_input,
        inputs=[arxiv_input],
        outputs=[title_input, abstract_input, validation_status],
    )

    # Predict callback
    def process_prediction(title, abstract):
        """Score the paper and return (score, grade label) for the result widgets."""
        score = predict(title, abstract)
        grade = get_grade_and_emoji(score)
        return score, grade

    submit_button.click(
        process_prediction,
        inputs=[title_input, abstract_input],
        outputs=[score_output, grade_output],
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()