openfree committed on
Commit
822a9a7
·
verified ·
1 Parent(s): a00a5a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -91
app.py CHANGED
@@ -9,29 +9,32 @@ import requests
9
  from urllib.parse import urlparse
10
  import xml.etree.ElementTree as ET
11
 
12
- model_path = r'ssocean/NAIP'
13
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
14
 
15
- global model, tokenizer
16
  model = None
17
  tokenizer = None
18
 
19
  def fetch_arxiv_paper(arxiv_input):
20
- """Fetch paper details from arXiv URL or ID using requests."""
 
 
21
  try:
22
- # Extract arXiv ID from URL or use directly
23
- if 'arxiv.org' in arxiv_input:
24
  parsed = urlparse(arxiv_input)
25
  path = parsed.path
26
- arxiv_id = path.split('/')[-1].replace('.pdf', '')
27
  else:
 
28
  arxiv_id = arxiv_input.strip()
29
 
30
- # Fetch metadata using arXiv API
31
- api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'
32
- response = requests.get(api_url)
33
-
34
- if response.status_code != 200:
35
  return {
36
  "title": "",
37
  "abstract": "",
@@ -39,14 +42,10 @@ def fetch_arxiv_paper(arxiv_input):
39
  "message": "Error fetching paper from arXiv API"
40
  }
41
 
42
- # Parse the response XML
43
- root = ET.fromstring(response.text)
44
-
45
- # ArXiv API uses Atom namespace
46
- ns = {'arxiv': 'http://www.w3.org/2005/Atom'}
47
-
48
- # Extract title and abstract
49
- entry = root.find('.//arxiv:entry', ns)
50
  if entry is None:
51
  return {
52
  "title": "",
@@ -54,10 +53,9 @@ def fetch_arxiv_paper(arxiv_input):
54
  "success": False,
55
  "message": "Paper not found"
56
  }
57
-
58
- title = entry.find('arxiv:title', ns).text.strip()
59
- abstract = entry.find('arxiv:summary', ns).text.strip()
60
-
61
  return {
62
  "title": title,
63
  "abstract": abstract,
@@ -74,34 +72,35 @@ def fetch_arxiv_paper(arxiv_input):
74
 
75
  @spaces.GPU(duration=60, enable_queue=True)
76
  def predict(title, abstract):
77
- """Predict a normalized academic impact score (0–1) from title & abstract."""
78
- title = title.replace("\n", " ").strip().replace("''", "'")
79
- abstract = abstract.replace("\n", " ").strip().replace("''", "'")
 
80
  global model, tokenizer
81
 
82
  if model is None:
83
- # Load config and disable any quantization
84
  config = AutoConfig.from_pretrained(model_path)
85
  config.quantization_config = None
 
86
 
87
- # Load model in full float32, then move to device
88
  model = AutoModelForSequenceClassification.from_pretrained(
89
  model_path,
90
  config=config,
91
- num_labels=1,
92
- torch_dtype=torch.float32,
93
- device_map=None,
94
  low_cpu_mem_usage=False
95
  )
96
  model.to(device)
 
97
 
98
  tokenizer = AutoTokenizer.from_pretrained(model_path)
99
- model.eval()
100
-
101
  text = (
102
  f"Given a certain paper,\n"
103
- f"Title: {title}\n"
104
- f"Abstract: {abstract}\n"
105
  f"Predict its normalized academic impact (0~1):"
106
  )
107
 
@@ -109,15 +108,17 @@ def predict(title, abstract):
109
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
110
  inputs = {k: v.to(device) for k, v in inputs.items()}
111
  with torch.no_grad():
112
- outputs = model(**inputs)
113
- prob = torch.sigmoid(outputs.logits).item()
114
- score = min(1.0, prob + 0.05)
 
115
  return round(score, 4)
116
  except Exception as e:
117
  print(f"Prediction error: {e}")
118
- return 0.0 # default on error
119
 
120
  def get_grade_and_emoji(score):
 
121
  if score >= 0.900: return "AAA 🌟"
122
  if score >= 0.800: return "AA ⭐"
123
  if score >= 0.650: return "A ✨"
@@ -128,60 +129,41 @@ def get_grade_and_emoji(score):
128
  if score >= 0.300: return "CC ✏️"
129
  return "C 📑"
130
 
131
- example_papers = [
132
- {
133
- "title": "Attention Is All You Need",
134
- "abstract": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train.",
135
- "score": 0.982,
136
- "note": "💫 Revolutionary paper that introduced the Transformer architecture, fundamentally changing NLP and deep learning."
137
- },
138
- {
139
- "title": "Language Models are Few-Shot Learners",
140
- "abstract": "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches.",
141
- "score": 0.956,
142
- "note": "🚀 Groundbreaking GPT-3 paper that demonstrated the power of large language models."
143
- },
144
- {
145
- "title": "An Empirical Study of Neural Network Training Protocols",
146
- "abstract": "This paper presents a comparative analysis of different training protocols for neural networks across various architectures. We examine the effects of learning rate schedules, batch size selection, and optimization algorithms on model convergence and final performance. Our experiments span multiple datasets and model sizes, providing practical insights for deep learning practitioners.",
147
- "score": 0.623,
148
- "note": "📚 Solid research paper with useful findings but more limited scope and impact."
149
- }
150
- ]
151
-
152
  def validate_input(title, abstract):
153
- title = title.replace("\n", " ").strip().replace("''", "'")
154
- abstract = abstract.replace("\n", " ").strip().replace("''", "'")
155
- non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
 
156
  if len(title.split()) < 3:
157
- return False, "The title must be at least 3 words long."
158
  if len(abstract.split()) < 50:
159
- return False, "The abstract must be at least 50 words long."
160
- if non_latin_pattern.search(title):
161
- return False, "The title contains invalid characters. Only English letters and symbols are allowed."
162
- if non_latin_pattern.search(abstract):
163
- return False, "The abstract contains invalid characters. Only English letters and symbols are allowed."
164
- return True, "Inputs are valid!"
165
 
166
  def update_button_status(title, abstract):
167
- valid, message = validate_input(title, abstract)
168
  if not valid:
169
- return gr.update(value="Error: " + message), gr.update(interactive=False)
170
- return gr.update(value=message), gr.update(interactive=True)
171
 
172
  def process_arxiv_input(arxiv_input):
 
 
 
173
  if not arxiv_input.strip():
174
  return "", "", "Please enter an arXiv URL or ID"
175
  result = fetch_arxiv_paper(arxiv_input)
176
  if result["success"]:
177
  return result["title"], result["abstract"], result["message"]
178
- else:
179
- return "", "", result["message"]
180
 
 
181
  css = """
182
- .gradio-container {
183
- font-family: 'Arial', sans-serif;
184
- }
185
  .main-title {
186
  text-align: center;
187
  color: #2563eb;
@@ -201,7 +183,7 @@ css = """
201
  background: white;
202
  padding: 2rem;
203
  border-radius: 1rem;
204
- box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1);
205
  }
206
  .result-section {
207
  background: #f8fafc;
@@ -246,6 +228,53 @@ css = """
246
  }
247
  """
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
250
  gr.Markdown(
251
  """
@@ -259,6 +288,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
259
 
260
  with gr.Row():
261
  with gr.Column(elem_classes="input-section"):
 
262
  with gr.Group(elem_classes="arxiv-input"):
263
  gr.Markdown("### 📑 Import from arXiv")
264
  arxiv_input = gr.Textbox(
@@ -267,16 +297,17 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
267
  label="arXiv Paper URL/ID",
268
  value="2504.11651"
269
  )
270
- gr.Markdown("""
271
- <p class="arxiv-note">
272
- Click input field to use example paper or browse papers at
273
- <a href="https://arxiv.org" target="_blank" class="arxiv-link">arxiv.org</a>
274
- </p>
275
- """)
 
 
276
  fetch_button = gr.Button("🔍 Fetch Paper Details", variant="secondary")
277
-
278
  gr.Markdown("### 📝 Or Enter Paper Details Manually")
279
-
280
  title_input = gr.Textbox(
281
  lines=2,
282
  placeholder="Enter Paper Title (minimum 3 words)...",
@@ -289,7 +320,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
289
  )
290
  validation_status = gr.Textbox(label="✔️ Validation Status", interactive=False)
291
  submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")
292
-
293
  with gr.Column(elem_classes="result-section"):
294
  with gr.Group():
295
  score_output = gr.Number(label="🎯 Impact Score")
@@ -330,7 +361,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
330
  for paper in example_papers:
331
  gr.Markdown(
332
  f"""
333
- #### {paper['title']}
334
  **Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}
335
  {paper['abstract']}
336
  *{paper['note']}*
@@ -338,6 +369,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
338
  """
339
  )
340
 
 
341
  title_input.change(
342
  update_button_status,
343
  inputs=[title_input, abstract_input],
@@ -348,13 +380,15 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
348
  inputs=[title_input, abstract_input],
349
  outputs=[validation_status, submit_button]
350
  )
351
-
 
352
  fetch_button.click(
353
  process_arxiv_input,
354
  inputs=[arxiv_input],
355
  outputs=[title_input, abstract_input, validation_status]
356
  )
357
 
 
358
  def process_prediction(title, abstract):
359
  score = predict(title, abstract)
360
  grade = get_grade_and_emoji(score)
 
9
  from urllib.parse import urlparse
10
  import xml.etree.ElementTree as ET
11
 
12
+ # Model repository path and device selection
13
+ model_path = "ssocean/NAIP"
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ # Global model/tokenizer variables
17
  model = None
18
  tokenizer = None
19
 
20
  def fetch_arxiv_paper(arxiv_input):
21
+ """
22
+ Fetch paper details (title, abstract) from an arXiv URL or ID using requests.
23
+ """
24
  try:
25
+ # If user passed a full arxiv.org link, parse out the ID
26
+ if "arxiv.org" in arxiv_input:
27
  parsed = urlparse(arxiv_input)
28
  path = parsed.path
29
+ arxiv_id = path.split("/")[-1].replace(".pdf", "")
30
  else:
31
+ # Otherwise just use the raw ID
32
  arxiv_id = arxiv_input.strip()
33
 
34
+ # ArXiv API query
35
+ api_url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
36
+ resp = requests.get(api_url)
37
+ if resp.status_code != 200:
 
38
  return {
39
  "title": "",
40
  "abstract": "",
 
42
  "message": "Error fetching paper from arXiv API"
43
  }
44
 
45
+ # Parse XML response
46
+ root = ET.fromstring(resp.text)
47
+ ns = {"arxiv": "http://www.w3.org/2005/Atom"}
48
+ entry = root.find(".//arxiv:entry", ns)
 
 
 
 
49
  if entry is None:
50
  return {
51
  "title": "",
 
53
  "success": False,
54
  "message": "Paper not found"
55
  }
56
+
57
+ title = entry.find("arxiv:title", ns).text.strip()
58
+ abstract = entry.find("arxiv:summary", ns).text.strip()
 
59
  return {
60
  "title": title,
61
  "abstract": abstract,
 
72
 
73
  @spaces.GPU(duration=60, enable_queue=True)
74
  def predict(title, abstract):
75
+ """
76
+ Predict a normalized academic impact score (0–1) given the paper title & abstract.
77
+ Loads the model once globally, then uses it for inference.
78
+ """
79
  global model, tokenizer
80
 
81
  if model is None:
82
+ # Load model config, disable quantization, and set number of labels if needed
83
  config = AutoConfig.from_pretrained(model_path)
84
  config.quantization_config = None
85
+ config.num_labels = 1 # For classification/logit output
86
 
87
+ # IMPORTANT: Do not pass num_labels directly into from_pretrained for LLaMA-based models
88
  model = AutoModelForSequenceClassification.from_pretrained(
89
  model_path,
90
  config=config,
91
+ torch_dtype=torch.float32, # Use full-precision float32
92
+ device_map=None, # We'll move it manually
 
93
  low_cpu_mem_usage=False
94
  )
95
  model.to(device)
96
+ model.eval()
97
 
98
  tokenizer = AutoTokenizer.from_pretrained(model_path)
99
+
 
100
  text = (
101
  f"Given a certain paper,\n"
102
+ f"Title: {title.strip()}\n"
103
+ f"Abstract: {abstract.strip()}\n"
104
  f"Predict its normalized academic impact (0~1):"
105
  )
106
 
 
108
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
109
  inputs = {k: v.to(device) for k, v in inputs.items()}
110
  with torch.no_grad():
111
+ output = model(**inputs)
112
+ logits = output.logits
113
+ prob = torch.sigmoid(logits).item()
114
+ score = min(1.0, prob + 0.05) # +0.05 offset, capped at 1.0
115
  return round(score, 4)
116
  except Exception as e:
117
  print(f"Prediction error: {e}")
118
+ return 0.0 # Return 0 in case of any error
119
 
120
  def get_grade_and_emoji(score):
121
+ """Convert a 0–1 score into a tier grade with emoji indicator."""
122
  if score >= 0.900: return "AAA 🌟"
123
  if score >= 0.800: return "AA ⭐"
124
  if score >= 0.650: return "A ✨"
 
129
  if score >= 0.300: return "CC ✏️"
130
  return "C 📑"
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def validate_input(title, abstract):
133
+ """
134
+ Ensure title >=3 words, abstract >=50 words, and only ASCII chars.
135
+ """
136
+ non_ascii = re.compile(r"[^\x00-\x7F]")
137
  if len(title.split()) < 3:
138
+ return False, "Title must be at least 3 words."
139
  if len(abstract.split()) < 50:
140
+ return False, "Abstract must be at least 50 words."
141
+ if non_ascii.search(title):
142
+ return False, "Title contains non-ASCII characters."
143
+ if non_ascii.search(abstract):
144
+ return False, "Abstract contains non-ASCII characters."
145
+ return True, "Inputs look good."
146
 
147
  def update_button_status(title, abstract):
148
+ valid, msg = validate_input(title, abstract)
149
  if not valid:
150
+ return gr.update(value="Error: " + msg), gr.update(interactive=False)
151
+ return gr.update(value=msg), gr.update(interactive=True)
152
 
153
  def process_arxiv_input(arxiv_input):
154
+ """
155
+ Helper to fill in title/abstract fields from an arXiv link/ID.
156
+ """
157
  if not arxiv_input.strip():
158
  return "", "", "Please enter an arXiv URL or ID"
159
  result = fetch_arxiv_paper(arxiv_input)
160
  if result["success"]:
161
  return result["title"], result["abstract"], result["message"]
162
+ return "", "", result["message"]
 
163
 
164
+ # Custom CSS for styling
165
  css = """
166
+ .gradio-container { font-family: Arial, sans-serif; }
 
 
167
  .main-title {
168
  text-align: center;
169
  color: #2563eb;
 
183
  background: white;
184
  padding: 2rem;
185
  border-radius: 1rem;
186
+ box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
187
  }
188
  .result-section {
189
  background: #f8fafc;
 
228
  }
229
  """
230
 
231
+ # Example papers
232
+ example_papers = [
233
+ {
234
+ "title": "Attention Is All You Need",
235
+ "abstract": (
236
+ "The dominant sequence transduction models are based on complex recurrent or "
237
+ "convolutional neural networks that include an encoder and a decoder. The best performing "
238
+ "models also connect the encoder and decoder through an attention mechanism. We propose a "
239
+ "new simple network architecture, the Transformer, based solely on attention mechanisms, "
240
+ "dispensing with recurrence and convolutions entirely. Experiments on two machine "
241
+ "translation tasks show these models to be superior in quality while being more "
242
+ "parallelizable and requiring significantly less time to train."
243
+ ),
244
+ "score": 0.982,
245
+ "note": "💫 Revolutionary paper that introduced the Transformer architecture."
246
+ },
247
+ {
248
+ "title": "Language Models are Few-Shot Learners",
249
+ "abstract": (
250
+ "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by "
251
+ "pre-training on a large corpus of text followed by fine-tuning on a specific task. While "
252
+ "typically task-agnostic in architecture, this method still requires task-specific "
253
+ "fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans "
254
+ "can generally perform a new language task from only a few examples or from simple "
255
+ "instructions - something which current NLP systems still largely struggle to do. Here we "
256
+ "show that scaling up language models greatly improves task-agnostic, few-shot "
257
+ "performance, sometimes even reaching competitiveness with prior state-of-the-art "
258
+ "fine-tuning approaches."
259
+ ),
260
+ "score": 0.956,
261
+ "note": "🚀 Groundbreaking GPT-3 paper that demonstrated the power of large language models."
262
+ },
263
+ {
264
+ "title": "An Empirical Study of Neural Network Training Protocols",
265
+ "abstract": (
266
+ "This paper presents a comparative analysis of different training protocols for neural "
267
+ "networks across various architectures. We examine the effects of learning rate schedules, "
268
+ "batch size selection, and optimization algorithms on model convergence and final "
269
+ "performance. Our experiments span multiple datasets and model sizes, providing practical "
270
+ "insights for deep learning practitioners."
271
+ ),
272
+ "score": 0.623,
273
+ "note": "📚 Solid research paper with useful findings but more limited scope and impact."
274
+ }
275
+ ]
276
+
277
+ # Build Gradio interface
278
  with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
279
  gr.Markdown(
280
  """
 
288
 
289
  with gr.Row():
290
  with gr.Column(elem_classes="input-section"):
291
+ # arXiv import group
292
  with gr.Group(elem_classes="arxiv-input"):
293
  gr.Markdown("### 📑 Import from arXiv")
294
  arxiv_input = gr.Textbox(
 
297
  label="arXiv Paper URL/ID",
298
  value="2504.11651"
299
  )
300
+ gr.Markdown(
301
+ """
302
+ <p class="arxiv-note">
303
+ Click input field to use example paper or browse papers at
304
+ <a href="https://arxiv.org" target="_blank" class="arxiv-link">arxiv.org</a>
305
+ </p>
306
+ """
307
+ )
308
  fetch_button = gr.Button("🔍 Fetch Paper Details", variant="secondary")
309
+
310
  gr.Markdown("### 📝 Or Enter Paper Details Manually")
 
311
  title_input = gr.Textbox(
312
  lines=2,
313
  placeholder="Enter Paper Title (minimum 3 words)...",
 
320
  )
321
  validation_status = gr.Textbox(label="✔️ Validation Status", interactive=False)
322
  submit_button = gr.Button("🎯 Predict Impact", interactive=False, variant="primary")
323
+
324
  with gr.Column(elem_classes="result-section"):
325
  with gr.Group():
326
  score_output = gr.Number(label="🎯 Impact Score")
 
361
  for paper in example_papers:
362
  gr.Markdown(
363
  f"""
364
+ #### {paper['title']}
365
  **Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}
366
  {paper['abstract']}
367
  *{paper['note']}*
 
369
  """
370
  )
371
 
372
+ # Validate button status on input changes
373
  title_input.change(
374
  update_button_status,
375
  inputs=[title_input, abstract_input],
 
380
  inputs=[title_input, abstract_input],
381
  outputs=[validation_status, submit_button]
382
  )
383
+
384
+ # Fetch from arXiv
385
  fetch_button.click(
386
  process_arxiv_input,
387
  inputs=[arxiv_input],
388
  outputs=[title_input, abstract_input, validation_status]
389
  )
390
 
391
+ # Predict callback
392
  def process_prediction(title, abstract):
393
  score = predict(title, abstract)
394
  grade = get_grade_and_emoji(score)