openfree commited on
Commit
a12abfd
Β·
verified Β·
1 Parent(s): 3544bdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -24
app.py CHANGED
@@ -6,32 +6,61 @@ import torch.nn.functional as F
6
  import torch.nn as nn
7
  import re
8
  import requests
9
- import arxiv
 
10
 
11
  model_path = r'ssocean/NAIP'
12
- device = 'cuda:0'
13
 
14
  global model, tokenizer
15
  model = None
16
  tokenizer = None
17
 
18
  def fetch_arxiv_paper(arxiv_input):
19
- """Fetch paper details from arXiv URL or ID."""
20
  try:
21
  # Extract arXiv ID from URL or use directly
22
- arxiv_id = arxiv_input.split('/')[-1]
23
- if 'abs' in arxiv_id:
24
- arxiv_id = arxiv_id.split('abs/')[-1]
25
- if '.pdf' in arxiv_id:
26
- arxiv_id = arxiv_id.replace('.pdf', '')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Search for the paper
29
- search = arxiv.Search(id_list=[arxiv_id])
30
- paper = next(search.results())
31
 
32
  return {
33
- "title": paper.title,
34
- "abstract": paper.summary,
35
  "success": True,
36
  "message": "Paper fetched successfully!"
37
  }
@@ -50,10 +79,12 @@ def predict(title, abstract):
50
  global model, tokenizer
51
  if model is None:
52
  model = AutoModelForSequenceClassification.from_pretrained(
53
- model_path,
54
- num_labels=1,
55
- load_in_8bit=True,)
 
56
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
57
  model.eval()
58
  text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
59
  inputs = tokenizer(text, return_tensors="pt").to(device)
@@ -177,6 +208,17 @@ css = """
177
  border-radius: 1rem;
178
  margin-top: 2rem;
179
  }
 
 
 
 
 
 
 
 
 
 
 
180
  """
181
 
182
  with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
@@ -190,12 +232,14 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
190
  with gr.Row():
191
  with gr.Column(elem_classes="input-section"):
192
  # arXiv Input
193
- arxiv_input = gr.Textbox(
194
- lines=1,
195
- placeholder="Enter arXiv URL or ID (e.g., 2006.16236 or https://arxiv.org/abs/2006.16236)",
196
- label="πŸ“‘ arXiv Paper URL/ID"
197
- )
198
- fetch_button = gr.Button("πŸ” Fetch Paper Details", variant="secondary")
 
 
199
 
200
  gr.Markdown("### πŸ“ Or Enter Paper Details Manually")
201
 
@@ -213,8 +257,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
213
  submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")
214
 
215
  with gr.Column(elem_classes="result-section"):
216
- score_output = gr.Number(label="🎯 Impact Score")
217
- grade_output = gr.Textbox(label="πŸ† Grade", value="")
 
218
 
219
  with gr.Row(elem_classes="methodology-section"):
220
  gr.Markdown(
 
6
  import torch.nn as nn
7
  import re
8
  import requests
9
+ from urllib.parse import urlparse
10
+ import xml.etree.ElementTree as ET
11
 
12
  model_path = r'ssocean/NAIP'
13
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
 
15
  global model, tokenizer
16
  model = None
17
  tokenizer = None
18
 
19
  def fetch_arxiv_paper(arxiv_input):
20
+ """Fetch paper details from arXiv URL or ID using requests."""
21
  try:
22
  # Extract arXiv ID from URL or use directly
23
+ if 'arxiv.org' in arxiv_input:
24
+ parsed = urlparse(arxiv_input)
25
+ path = parsed.path
26
+ arxiv_id = path.split('/')[-1].replace('.pdf', '')
27
+ else:
28
+ arxiv_id = arxiv_input.strip()
29
+
30
+ # Fetch metadata using arXiv API
31
+ api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'
32
+ response = requests.get(api_url)
33
+
34
+ if response.status_code != 200:
35
+ return {
36
+ "title": "",
37
+ "abstract": "",
38
+ "success": False,
39
+ "message": "Error fetching paper from arXiv API"
40
+ }
41
+
42
+ # Parse the response XML
43
+ root = ET.fromstring(response.text)
44
+
45
+ # ArXiv API uses namespaces
46
+ ns = {'arxiv': 'http://www.w3.org/2005/Atom'}
47
+
48
+ # Extract title and abstract
49
+ entry = root.find('.//arxiv:entry', ns)
50
+ if entry is None:
51
+ return {
52
+ "title": "",
53
+ "abstract": "",
54
+ "success": False,
55
+ "message": "Paper not found"
56
+ }
57
 
58
+ title = entry.find('arxiv:title', ns).text.strip()
59
+ abstract = entry.find('arxiv:summary', ns).text.strip()
 
60
 
61
  return {
62
+ "title": title,
63
+ "abstract": abstract,
64
  "success": True,
65
  "message": "Paper fetched successfully!"
66
  }
 
79
  global model, tokenizer
80
  if model is None:
81
  model = AutoModelForSequenceClassification.from_pretrained(
82
+ model_path,
83
+ num_labels=1,
84
+ torch_dtype=torch.float32 if device == 'cpu' else torch.float16
85
+ )
86
  tokenizer = AutoTokenizer.from_pretrained(model_path)
87
+ model.to(device)
88
  model.eval()
89
  text = f'''Given a certain paper, Title: {title}\n Abstract: {abstract}. \n Predict its normalized academic impact (between 0 and 1):'''
90
  inputs = tokenizer(text, return_tensors="pt").to(device)
 
208
  border-radius: 1rem;
209
  margin-top: 2rem;
210
  }
211
+ .grade-display {
212
+ font-size: 3rem;
213
+ text-align: center;
214
+ margin: 1rem 0;
215
+ }
216
+ .arxiv-input {
217
+ margin-bottom: 1.5rem;
218
+ padding: 1rem;
219
+ background: #f3f4f6;
220
+ border-radius: 0.5rem;
221
+ }
222
  """
223
 
224
  with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
 
232
  with gr.Row():
233
  with gr.Column(elem_classes="input-section"):
234
  # arXiv Input
235
+ with gr.Group(elem_classes="arxiv-input"):
236
+ gr.Markdown("### πŸ“‘ Import from arXiv")
237
+ arxiv_input = gr.Textbox(
238
+ lines=1,
239
+ placeholder="Enter arXiv URL or ID (e.g., 2006.16236 or https://arxiv.org/abs/2006.16236)",
240
+ label="arXiv Paper URL/ID"
241
+ )
242
+ fetch_button = gr.Button("πŸ” Fetch Paper Details", variant="secondary")
243
 
244
  gr.Markdown("### πŸ“ Or Enter Paper Details Manually")
245
 
 
257
  submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")
258
 
259
  with gr.Column(elem_classes="result-section"):
260
+ with gr.Group():
261
+ score_output = gr.Number(label="🎯 Impact Score")
262
+ grade_output = gr.Textbox(label="πŸ† Grade", value="", elem_classes="grade-display")
263
 
264
  with gr.Row(elem_classes="methodology-section"):
265
  gr.Markdown(