akhaliq HF staff commited on
Commit
22d5f09
Β·
verified Β·
1 Parent(s): 33ea647

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -8
app.py CHANGED
@@ -8,6 +8,7 @@ from huggingface_hub import HfApi
8
 
9
  import gradio as gr
10
  import datasets # Ensure the datasets library is imported
 
11
 
12
  from datetime import timezone
13
  import atexit # To gracefully shut down the scheduler
@@ -21,10 +22,37 @@ logger = logging.getLogger(__name__)
21
 
22
  api = HfApi()
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def get_df() -> pd.DataFrame:
25
  """
26
  Loads and merges the papers and stats datasets, preprocesses the data by removing unnecessary columns,
27
- and adds a 'paper_page' link for each paper.
28
  """
29
  try:
30
  # Load datasets
@@ -52,6 +80,17 @@ def get_df() -> pd.DataFrame:
52
  info = row.copy()
53
  if "abstract" in info:
54
  del info["abstract"]
 
 
 
 
 
 
 
 
 
 
 
55
  paper_info.append(info)
56
  df_prepared = pd.DataFrame(paper_info)
57
 
@@ -70,7 +109,11 @@ class Prettifier:
70
  """
71
  Converts raw DataFrame rows into a prettified format suitable for display.
72
  """
73
- REQUIRED_COLUMNS = ["arxiv_id", "date_display", "date", "paper_page", "title", "github", "πŸ‘", "πŸ’¬"]
 
 
 
 
74
 
75
  @staticmethod
76
  def get_github_link(link: str) -> str:
@@ -97,6 +140,9 @@ class Prettifier:
97
  "github": Prettifier.get_github_link(row.get("github", "")),
98
  "πŸ‘": row.get("upvotes", 0),
99
  "πŸ’¬": row.get("num_comments", 0),
 
 
 
100
  }
101
  new_rows.append(new_row)
102
 
@@ -120,6 +166,9 @@ class PaperList:
120
  ["github", "markdown"],
121
  ["πŸ‘", "number"],
122
  ["πŸ’¬", "number"],
 
 
 
123
  ]
124
 
125
  def __init__(self, df: pd.DataFrame):
@@ -212,6 +261,12 @@ class PaperManager:
212
  df_sorted['date_parsed'] = pd.to_datetime(df_sorted['date'], errors='coerce').dt.tz_localize(timezone.utc, ambiguous='NaT', nonexistent='NaT')
213
  df_sorted = df_sorted[df_sorted['date_parsed'] >= time_threshold]
214
  df_sorted = df_sorted.sort_values(by='upvotes', ascending=False).drop(columns=['date_parsed'])
 
 
 
 
 
 
215
  else:
216
  df_sorted = df
217
 
@@ -219,13 +274,15 @@ class PaperManager:
219
  self.paper_list.df_prettified = self.paper_list._prettifier(self.paper_list.df_raw).loc[:, self.paper_list.column_names]
220
  self.total_pages = max((len(self.paper_list.df_raw) + self.papers_per_page - 1) // self.papers_per_page, 1)
221
  self.current_page = 1
 
222
 
223
  def set_sort_method(self, method, time_frame=None):
224
  """
225
- Sets the sort method ('hot', 'new', 'top') and re-sorts the papers.
226
  If 'top' is selected, also sets the time frame.
227
  """
228
- if method not in ["hot", "new", "top"]:
 
229
  method = "hot"
230
  logger.info(f"Setting sort method to: {method}")
231
  self.sort_method = method
@@ -262,6 +319,9 @@ class PaperManager:
262
  url = f"https://huggingface.co/papers/{paper_id}"
263
  upvotes = row.get('πŸ‘', 0)
264
  comments = row.get('πŸ’¬', 0)
 
 
 
265
  date_str = row.get('date', datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d"))
266
  try:
267
  published_time = datetime.datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
@@ -282,7 +342,8 @@ class PaperManager:
282
  <td colspan="1"></td>
283
  <td class="subtext">
284
  <span class="score">{upvotes} upvotes</span><br>
285
- {time_ago} | <a href="#">{comments} comments</a>
 
286
  </td>
287
  </tr>
288
  <tr style="height:5px"></tr>
@@ -325,14 +386,16 @@ def initialize_paper_manager() -> str:
325
  Initializes the PaperList and PaperManager with the current DataFrame.
326
  """
327
  df = get_df()
 
 
328
  paper_list = PaperList(df)
329
  manager = PaperManager(paper_list)
 
330
  return manager.get_current_page_papers() # Return HTML string instead of the manager object
331
 
332
 
333
  paper_manager = None # Initialize globally
334
 
335
-
336
  def setup_paper_manager():
337
  """
338
  Sets up the global PaperManager instance.
@@ -388,7 +451,9 @@ def change_sort_method_ui(method: str, time_frame: str = "all time") -> str:
388
  Changes the sort method and, if 'top' is selected, sets the time frame.
389
  """
390
  logger.info(f"Changing sort method to: {method} with time frame: {time_frame}")
391
- if method.lower() == "top":
 
 
392
  paper_manager.set_sort_method(method.lower(), time_frame)
393
  else:
394
  paper_manager.set_sort_method(method.lower())
@@ -529,6 +594,7 @@ table {
529
  }
530
  """
531
 
 
532
  # --- Initialize Gradio Blocks ---
533
 
534
  demo = gr.Blocks(css=css)
@@ -562,7 +628,7 @@ with demo:
562
  # Sort Options and Time Frame (conditionally visible)
563
  with gr.Row():
564
  sort_radio = gr.Radio(
565
- choices=["Hot", "New", "Top"],
566
  value="Hot",
567
  label="Sort By",
568
  interactive=True
 
8
 
9
  import gradio as gr
10
  import datasets # Ensure the datasets library is imported
11
+ import requests # For making API calls
12
 
13
  from datetime import timezone
14
  import atexit # To gracefully shut down the scheduler
 
22
 
23
  api = HfApi()
24
 
25
+ def get_repo_counts(arxiv_id: str) -> dict:
26
+ """
27
+ Fetches the number of models, datasets, and Spaces linked to a given arxiv_id using Hugging Face API.
28
+ """
29
+ url = f"https://huggingface.co/api/arxiv/{arxiv_id}/repos"
30
+ try:
31
+ response = requests.get(url, timeout=10)
32
+ response.raise_for_status()
33
+ data = response.json()
34
+
35
+ models = data.get('models', [])
36
+ datasets_list = data.get('datasets', [])
37
+ spaces = data.get('spaces', [])
38
+
39
+ return {
40
+ 'models_count': len(models),
41
+ 'datasets_count': len(datasets_list),
42
+ 'spaces_count': len(spaces)
43
+ }
44
+ except requests.exceptions.RequestException as e:
45
+ logger.error(f"Error fetching repo counts for {arxiv_id}: {e}")
46
+ return {
47
+ 'models_count': 0,
48
+ 'datasets_count': 0,
49
+ 'spaces_count': 0
50
+ }
51
+
52
  def get_df() -> pd.DataFrame:
53
  """
54
  Loads and merges the papers and stats datasets, preprocesses the data by removing unnecessary columns,
55
+ adds a 'paper_page' link for each paper, and fetches counts of models, datasets, and Spaces linked to each paper.
56
  """
57
  try:
58
  # Load datasets
 
80
  info = row.copy()
81
  if "abstract" in info:
82
  del info["abstract"]
83
+ # Fetch repo counts
84
+ arxiv_id = info.get("arxiv_id", "")
85
+ if arxiv_id:
86
+ counts = get_repo_counts(arxiv_id)
87
+ info.update(counts)
88
+ else:
89
+ info.update({
90
+ 'models_count': 0,
91
+ 'datasets_count': 0,
92
+ 'spaces_count': 0
93
+ })
94
  paper_info.append(info)
95
  df_prepared = pd.DataFrame(paper_info)
96
 
 
109
  """
110
  Converts raw DataFrame rows into a prettified format suitable for display.
111
  """
112
+ REQUIRED_COLUMNS = [
113
+ "arxiv_id", "date_display", "date", "paper_page",
114
+ "title", "github", "πŸ‘", "πŸ’¬",
115
+ "models_count", "datasets_count", "spaces_count"
116
+ ]
117
 
118
  @staticmethod
119
  def get_github_link(link: str) -> str:
 
140
  "github": Prettifier.get_github_link(row.get("github", "")),
141
  "πŸ‘": row.get("upvotes", 0),
142
  "πŸ’¬": row.get("num_comments", 0),
143
+ "models_count": row.get("models_count", 0),
144
+ "datasets_count": row.get("datasets_count", 0),
145
+ "spaces_count": row.get("spaces_count", 0),
146
  }
147
  new_rows.append(new_row)
148
 
 
166
  ["github", "markdown"],
167
  ["πŸ‘", "number"],
168
  ["πŸ’¬", "number"],
169
+ ["models_count", "number"],
170
+ ["datasets_count", "number"],
171
+ ["spaces_count", "number"],
172
  ]
173
 
174
  def __init__(self, df: pd.DataFrame):
 
261
  df_sorted['date_parsed'] = pd.to_datetime(df_sorted['date'], errors='coerce').dt.tz_localize(timezone.utc, ambiguous='NaT', nonexistent='NaT')
262
  df_sorted = df_sorted[df_sorted['date_parsed'] >= time_threshold]
263
  df_sorted = df_sorted.sort_values(by='upvotes', ascending=False).drop(columns=['date_parsed'])
264
+ elif self.sort_method == "most_models":
265
+ df_sorted = df.sort_values(by='models_count', ascending=False)
266
+ elif self.sort_method == "most_datasets":
267
+ df_sorted = df.sort_values(by='datasets_count', ascending=False)
268
+ elif self.sort_method == "most_spaces":
269
+ df_sorted = df.sort_values(by='spaces_count', ascending=False)
270
  else:
271
  df_sorted = df
272
 
 
274
  self.paper_list.df_prettified = self.paper_list._prettifier(self.paper_list.df_raw).loc[:, self.paper_list.column_names]
275
  self.total_pages = max((len(self.paper_list.df_raw) + self.papers_per_page - 1) // self.papers_per_page, 1)
276
  self.current_page = 1
277
+ logger.info(f"Papers sorted by {self.sort_method}. Total pages: {self.total_pages}")
278
 
279
  def set_sort_method(self, method, time_frame=None):
280
  """
281
+ Sets the sort method ('hot', 'new', 'top', 'most_models', 'most_datasets', 'most_spaces') and re-sorts the papers.
282
  If 'top' is selected, also sets the time frame.
283
  """
284
+ valid_methods = ["hot", "new", "top", "most_models", "most_datasets", "most_spaces"]
285
+ if method not in valid_methods:
286
  method = "hot"
287
  logger.info(f"Setting sort method to: {method}")
288
  self.sort_method = method
 
319
  url = f"https://huggingface.co/papers/{paper_id}"
320
  upvotes = row.get('πŸ‘', 0)
321
  comments = row.get('πŸ’¬', 0)
322
+ models = row.get('models_count', 0)
323
+ datasets_count = row.get('datasets_count', 0)
324
+ spaces = row.get('spaces_count', 0)
325
  date_str = row.get('date', datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d"))
326
  try:
327
  published_time = datetime.datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
 
342
  <td colspan="1"></td>
343
  <td class="subtext">
344
  <span class="score">{upvotes} upvotes</span><br>
345
+ {time_ago} | <a href="#">{comments} comments</a><br>
346
+ Models: {models} | Datasets: {datasets_count} | Spaces: {spaces}
347
  </td>
348
  </tr>
349
  <tr style="height:5px"></tr>
 
386
  Initializes the PaperList and PaperManager with the current DataFrame.
387
  """
388
  df = get_df()
389
+ if df.empty:
390
+ logger.warning("Initialized with an empty DataFrame.")
391
  paper_list = PaperList(df)
392
  manager = PaperManager(paper_list)
393
+ logger.info("PaperManager initialized.")
394
  return manager.get_current_page_papers() # Return HTML string instead of the manager object
395
 
396
 
397
  paper_manager = None # Initialize globally
398
 
 
399
  def setup_paper_manager():
400
  """
401
  Sets up the global PaperManager instance.
 
451
  Changes the sort method and, if 'top' is selected, sets the time frame.
452
  """
453
  logger.info(f"Changing sort method to: {method} with time frame: {time_frame}")
454
+ if method.lower() in ["most_models", "most_datasets", "most_spaces"]:
455
+ paper_manager.set_sort_method(method.lower())
456
+ elif method.lower() == "top":
457
  paper_manager.set_sort_method(method.lower(), time_frame)
458
  else:
459
  paper_manager.set_sort_method(method.lower())
 
594
  }
595
  """
596
 
597
+
598
  # --- Initialize Gradio Blocks ---
599
 
600
  demo = gr.Blocks(css=css)
 
628
  # Sort Options and Time Frame (conditionally visible)
629
  with gr.Row():
630
  sort_radio = gr.Radio(
631
+ choices=["Hot", "New", "Top", "Most Models", "Most Datasets", "Most Spaces"],
632
  value="Hot",
633
  label="Sort By",
634
  interactive=True