MaxNoichl committed on
Commit
2dc529c
·
1 Parent(s): f1adc40

Enhance sampling logic in app.py to support batched sampling for target sizes over 10k, ensuring unique records are collected. Update network_utils.py to set a solid white background for citation graphs and disable transparency in saved figures.

Browse files
Files changed (2) hide show
  1. app.py +59 -11
  2. network_utils.py +4 -1
app.py CHANGED
@@ -310,6 +310,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
310
  plot_time_checkbox = plot_type_dropdown == "Time-based coloring"
311
  treat_as_categorical_checkbox = plot_type_dropdown == "Categorical coloring"
312
 
 
 
 
 
313
  # Helper function to generate error responses
314
  def create_error_response(error_message):
315
  return [
@@ -452,19 +456,63 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
452
  try:
453
  # Check if PyAlex sample method exists and works
454
  if hasattr(query, 'sample'):
455
- sampled_query = query.sample(target_size, seed=seed_int)
456
-
457
- # IMPORTANT: When using sample(), must use method='page' for pagination!
458
  sampled_records = []
459
- records_count = 0
460
- for page in sampled_query.paginate(per_page=200, method='page', n_max=None):
461
- for record in page:
462
- sampled_records.append(record)
463
- records_count += 1
464
- progress(0.1 + (0.15 * records_count / target_size),
465
- desc=f"Getting sampled data from query {i+1}/{len(urls)}... ({records_count}/{target_size})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
- print(f'PyAlex sampling successful: got {len(sampled_records)} records')
468
  else:
469
  raise AttributeError("sample method not available")
470
 
 
310
  plot_time_checkbox = plot_type_dropdown == "Time-based coloring"
311
  treat_as_categorical_checkbox = plot_type_dropdown == "Categorical coloring"
312
 
313
+ # Initialize variables used later across branches
314
+ urls = []
315
+ query_indices = []
316
+
317
  # Helper function to generate error responses
318
  def create_error_response(error_message):
319
  return [
 
456
  try:
457
  # Check if PyAlex sample method exists and works
458
  if hasattr(query, 'sample'):
 
 
 
459
  sampled_records = []
460
+ seen_ids = set() # Track IDs to avoid duplicates
461
+
462
+ # If target_size > 10k, do batched sampling
463
+ if target_size > 10000:
464
+ batch_size = 9998 # Use 9998 to stay safely under 10k limit
465
+ remaining = target_size
466
+ batch_num = 1
467
+
468
+ print(f'Target size {target_size} > 10k, using batched sampling with batch size {batch_size}')
469
+
470
+ while remaining > 0 and len(sampled_records) < target_size:
471
+ current_batch_size = min(batch_size, remaining)
472
+ batch_seed = seed_int + batch_num # Different seed for each batch
473
+
474
+ print(f'Batch {batch_num}: requesting {current_batch_size} samples (seed={batch_seed})')
475
+
476
+ # Sample this batch
477
+ batch_query = query.sample(current_batch_size, seed=batch_seed)
478
+
479
+ batch_records = []
480
+ batch_count = 0
481
+ for page in batch_query.paginate(per_page=200, method='page', n_max=None):
482
+ for record in page:
483
+ # Check for duplicates using OpenAlex ID
484
+ record_id = record.get('id', '')
485
+ if record_id not in seen_ids:
486
+ seen_ids.add(record_id)
487
+ batch_records.append(record)
488
+ batch_count += 1
489
+
490
+ sampled_records.extend(batch_records)
491
+ remaining -= len(batch_records)
492
+ batch_num += 1
493
+
494
+ print(f'Batch {batch_num-1} complete: got {len(batch_records)} unique records ({len(sampled_records)}/{target_size} total)')
495
+
496
+ progress(0.1 + (0.15 * len(sampled_records) / target_size),
497
+ desc=f"Batched sampling from query {i+1}/{len(urls)}... ({len(sampled_records)}/{target_size})")
498
+
499
+ # Safety check to avoid infinite loops
500
+ if batch_num > 20: # Max 20 batches (should handle up to ~200k samples)
501
+ print("Warning: Maximum batch limit reached, stopping sampling")
502
+ break
503
+ else:
504
+ # Single batch sampling for <= 10k
505
+ sampled_query = query.sample(target_size, seed=seed_int)
506
+
507
+ records_count = 0
508
+ for page in sampled_query.paginate(per_page=200, method='page', n_max=None):
509
+ for record in page:
510
+ sampled_records.append(record)
511
+ records_count += 1
512
+ progress(0.1 + (0.15 * records_count / target_size),
513
+ desc=f"Getting sampled data from query {i+1}/{len(urls)}... ({records_count}/{target_size})")
514
 
515
+ print(f'PyAlex sampling successful: got {len(sampled_records)} records (requested {target_size})')
516
  else:
517
  raise AttributeError("sample method not available")
518
 
network_utils.py CHANGED
@@ -42,6 +42,9 @@ def draw_citation_graph(G, bundle_edges=False, path=None, min_max_coordinates=No
42
  for node in G.nodes():
43
  pos[node] = (G.nodes[node]['X'], G.nodes[node]['Y'])
44
  fig, ax = plt.subplots(figsize=(20, 20))
 
 
 
45
  plt.margins(0, 0) # Remove margins
46
  if bundle_edges:
47
  # Turning color into rgb
@@ -64,4 +67,4 @@ def draw_citation_graph(G, bundle_edges=False, path=None, min_max_coordinates=No
64
  plt.ylim(min_max_coordinates[2], min_max_coordinates[3])
65
 
66
  if path is not None:
67
- plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=800, transparent=True)
 
42
  for node in G.nodes():
43
  pos[node] = (G.nodes[node]['X'], G.nodes[node]['Y'])
44
  fig, ax = plt.subplots(figsize=(20, 20))
45
+ # Ensure a solid white background (no transparency)
46
+ fig.patch.set_facecolor('green')
47
+ ax.set_facecolor('green')
48
  plt.margins(0, 0) # Remove margins
49
  if bundle_edges:
50
  # Turning color into rgb
 
67
  plt.ylim(min_max_coordinates[2], min_max_coordinates[3])
68
 
69
  if path is not None:
70
+ plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=800, transparent=False)