Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -6,32 +6,61 @@ import torch.nn.functional as F
|
|
6 |
import torch.nn as nn
|
7 |
import re
|
8 |
import requests
|
9 |
-
import
|
|
|
10 |
|
11 |
model_path = r'ssocean/NAIP'
|
12 |
-
device = 'cuda
|
13 |
|
14 |
global model, tokenizer
|
15 |
model = None
|
16 |
tokenizer = None
|
17 |
|
18 |
def fetch_arxiv_paper(arxiv_input):
|
19 |
-
"""Fetch paper details from arXiv URL or ID."""
|
20 |
try:
|
21 |
# Extract arXiv ID from URL or use directly
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
paper = next(search.results())
|
31 |
|
32 |
return {
|
33 |
-
"title":
|
34 |
-
"abstract":
|
35 |
"success": True,
|
36 |
"message": "Paper fetched successfully!"
|
37 |
}
|
@@ -50,10 +79,12 @@ def predict(title, abstract):
|
|
50 |
global model, tokenizer
|
51 |
if model is None:
|
52 |
model = AutoModelForSequenceClassification.from_pretrained(
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
56 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
|
|
57 |
model.eval()
|
58 |
text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
|
59 |
inputs = tokenizer(text, return_tensors="pt").to(device)
|
@@ -177,6 +208,17 @@ css = """
|
|
177 |
border-radius: 1rem;
|
178 |
margin-top: 2rem;
|
179 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
"""
|
181 |
|
182 |
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
@@ -190,12 +232,14 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
190 |
with gr.Row():
|
191 |
with gr.Column(elem_classes="input-section"):
|
192 |
# arXiv Input
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
|
|
199 |
|
200 |
gr.Markdown("### π Or Enter Paper Details Manually")
|
201 |
|
@@ -213,8 +257,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
213 |
submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")
|
214 |
|
215 |
with gr.Column(elem_classes="result-section"):
|
216 |
-
|
217 |
-
|
|
|
218 |
|
219 |
with gr.Row(elem_classes="methodology-section"):
|
220 |
gr.Markdown(
|
|
|
6 |
import torch.nn as nn
|
7 |
import re
|
8 |
import requests
|
9 |
+
from urllib.parse import urlparse
|
10 |
+
import xml.etree.ElementTree as ET
|
11 |
|
12 |
model_path = r'ssocean/NAIP'
|
13 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
|
15 |
global model, tokenizer
|
16 |
model = None
|
17 |
tokenizer = None
|
18 |
|
19 |
def fetch_arxiv_paper(arxiv_input):
|
20 |
+
"""Fetch paper details from arXiv URL or ID using requests."""
|
21 |
try:
|
22 |
# Extract arXiv ID from URL or use directly
|
23 |
+
if 'arxiv.org' in arxiv_input:
|
24 |
+
parsed = urlparse(arxiv_input)
|
25 |
+
path = parsed.path
|
26 |
+
arxiv_id = path.split('/')[-1].replace('.pdf', '')
|
27 |
+
else:
|
28 |
+
arxiv_id = arxiv_input.strip()
|
29 |
+
|
30 |
+
# Fetch metadata using arXiv API
|
31 |
+
api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'
|
32 |
+
response = requests.get(api_url)
|
33 |
+
|
34 |
+
if response.status_code != 200:
|
35 |
+
return {
|
36 |
+
"title": "",
|
37 |
+
"abstract": "",
|
38 |
+
"success": False,
|
39 |
+
"message": "Error fetching paper from arXiv API"
|
40 |
+
}
|
41 |
+
|
42 |
+
# Parse the response XML
|
43 |
+
root = ET.fromstring(response.text)
|
44 |
+
|
45 |
+
# ArXiv API uses namespaces
|
46 |
+
ns = {'arxiv': 'http://www.w3.org/2005/Atom'}
|
47 |
+
|
48 |
+
# Extract title and abstract
|
49 |
+
entry = root.find('.//arxiv:entry', ns)
|
50 |
+
if entry is None:
|
51 |
+
return {
|
52 |
+
"title": "",
|
53 |
+
"abstract": "",
|
54 |
+
"success": False,
|
55 |
+
"message": "Paper not found"
|
56 |
+
}
|
57 |
|
58 |
+
title = entry.find('arxiv:title', ns).text.strip()
|
59 |
+
abstract = entry.find('arxiv:summary', ns).text.strip()
|
|
|
60 |
|
61 |
return {
|
62 |
+
"title": title,
|
63 |
+
"abstract": abstract,
|
64 |
"success": True,
|
65 |
"message": "Paper fetched successfully!"
|
66 |
}
|
|
|
79 |
global model, tokenizer
|
80 |
if model is None:
|
81 |
model = AutoModelForSequenceClassification.from_pretrained(
|
82 |
+
model_path,
|
83 |
+
num_labels=1,
|
84 |
+
torch_dtype=torch.float32 if device == 'cpu' else torch.float16
|
85 |
+
)
|
86 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
87 |
+
model.to(device)
|
88 |
model.eval()
|
89 |
text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
|
90 |
inputs = tokenizer(text, return_tensors="pt").to(device)
|
|
|
208 |
border-radius: 1rem;
|
209 |
margin-top: 2rem;
|
210 |
}
|
211 |
+
.grade-display {
|
212 |
+
font-size: 3rem;
|
213 |
+
text-align: center;
|
214 |
+
margin: 1rem 0;
|
215 |
+
}
|
216 |
+
.arxiv-input {
|
217 |
+
margin-bottom: 1.5rem;
|
218 |
+
padding: 1rem;
|
219 |
+
background: #f3f4f6;
|
220 |
+
border-radius: 0.5rem;
|
221 |
+
}
|
222 |
"""
|
223 |
|
224 |
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
|
232 |
with gr.Row():
|
233 |
with gr.Column(elem_classes="input-section"):
|
234 |
# arXiv Input
|
235 |
+
with gr.Group(elem_classes="arxiv-input"):
|
236 |
+
gr.Markdown("### π Import from arXiv")
|
237 |
+
arxiv_input = gr.Textbox(
|
238 |
+
lines=1,
|
239 |
+
placeholder="Enter arXiv URL or ID (e.g., 2006.16236 or https://arxiv.org/abs/2006.16236)",
|
240 |
+
label="arXiv Paper URL/ID"
|
241 |
+
)
|
242 |
+
fetch_button = gr.Button("π Fetch Paper Details", variant="secondary")
|
243 |
|
244 |
gr.Markdown("### π Or Enter Paper Details Manually")
|
245 |
|
|
|
257 |
submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")
|
258 |
|
259 |
with gr.Column(elem_classes="result-section"):
|
260 |
+
with gr.Group():
|
261 |
+
score_output = gr.Number(label="🎯 Impact Score")
|
262 |
+
grade_output = gr.Textbox(label="π Grade", value="", elem_classes="grade-display")
|
263 |
|
264 |
with gr.Row(elem_classes="methodology-section"):
|
265 |
gr.Markdown(
|