Spaces:
Running
on
Zero
Running
on
Zero
Enhance sampling logic in app.py to support batched sampling for target sizes over 10k, ensuring unique records are collected. Update network_utils.py to set a solid white background for citation graphs and disable transparency in saved figures.
Browse files
- app.py +59 -11
- network_utils.py +4 -1
app.py
CHANGED
@@ -310,6 +310,10 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
310 |
plot_time_checkbox = plot_type_dropdown == "Time-based coloring"
|
311 |
treat_as_categorical_checkbox = plot_type_dropdown == "Categorical coloring"
|
312 |
|
|
|
|
|
|
|
|
|
313 |
# Helper function to generate error responses
|
314 |
def create_error_response(error_message):
|
315 |
return [
|
@@ -452,19 +456,63 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
|
|
452 |
try:
|
453 |
# Check if PyAlex sample method exists and works
|
454 |
if hasattr(query, 'sample'):
|
455 |
-
sampled_query = query.sample(target_size, seed=seed_int)
|
456 |
-
|
457 |
-
# IMPORTANT: When using sample(), must use method='page' for pagination!
|
458 |
sampled_records = []
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
|
467 |
-
print(f'PyAlex sampling successful: got {len(sampled_records)} records')
|
468 |
else:
|
469 |
raise AttributeError("sample method not available")
|
470 |
|
|
|
310 |
plot_time_checkbox = plot_type_dropdown == "Time-based coloring"
|
311 |
treat_as_categorical_checkbox = plot_type_dropdown == "Categorical coloring"
|
312 |
|
313 |
+
# Initialize variables used later across branches
|
314 |
+
urls = []
|
315 |
+
query_indices = []
|
316 |
+
|
317 |
# Helper function to generate error responses
|
318 |
def create_error_response(error_message):
|
319 |
return [
|
|
|
456 |
try:
|
457 |
# Check if PyAlex sample method exists and works
|
458 |
if hasattr(query, 'sample'):
|
|
|
|
|
|
|
459 |
sampled_records = []
|
460 |
+
seen_ids = set() # Track IDs to avoid duplicates
|
461 |
+
|
462 |
+
# If target_size > 10k, do batched sampling
|
463 |
+
if target_size > 10000:
|
464 |
+
batch_size = 9998 # Use 9998 to stay safely under 10k limit
|
465 |
+
remaining = target_size
|
466 |
+
batch_num = 1
|
467 |
+
|
468 |
+
print(f'Target size {target_size} > 10k, using batched sampling with batch size {batch_size}')
|
469 |
+
|
470 |
+
while remaining > 0 and len(sampled_records) < target_size:
|
471 |
+
current_batch_size = min(batch_size, remaining)
|
472 |
+
batch_seed = seed_int + batch_num # Different seed for each batch
|
473 |
+
|
474 |
+
print(f'Batch {batch_num}: requesting {current_batch_size} samples (seed={batch_seed})')
|
475 |
+
|
476 |
+
# Sample this batch
|
477 |
+
batch_query = query.sample(current_batch_size, seed=batch_seed)
|
478 |
+
|
479 |
+
batch_records = []
|
480 |
+
batch_count = 0
|
481 |
+
for page in batch_query.paginate(per_page=200, method='page', n_max=None):
|
482 |
+
for record in page:
|
483 |
+
# Check for duplicates using OpenAlex ID
|
484 |
+
record_id = record.get('id', '')
|
485 |
+
if record_id not in seen_ids:
|
486 |
+
seen_ids.add(record_id)
|
487 |
+
batch_records.append(record)
|
488 |
+
batch_count += 1
|
489 |
+
|
490 |
+
sampled_records.extend(batch_records)
|
491 |
+
remaining -= len(batch_records)
|
492 |
+
batch_num += 1
|
493 |
+
|
494 |
+
print(f'Batch {batch_num-1} complete: got {len(batch_records)} unique records ({len(sampled_records)}/{target_size} total)')
|
495 |
+
|
496 |
+
progress(0.1 + (0.15 * len(sampled_records) / target_size),
|
497 |
+
desc=f"Batched sampling from query {i+1}/{len(urls)}... ({len(sampled_records)}/{target_size})")
|
498 |
+
|
499 |
+
# Safety check to avoid infinite loops
|
500 |
+
if batch_num > 20: # Max 20 batches (should handle up to ~200k samples)
|
501 |
+
print("Warning: Maximum batch limit reached, stopping sampling")
|
502 |
+
break
|
503 |
+
else:
|
504 |
+
# Single batch sampling for <= 10k
|
505 |
+
sampled_query = query.sample(target_size, seed=seed_int)
|
506 |
+
|
507 |
+
records_count = 0
|
508 |
+
for page in sampled_query.paginate(per_page=200, method='page', n_max=None):
|
509 |
+
for record in page:
|
510 |
+
sampled_records.append(record)
|
511 |
+
records_count += 1
|
512 |
+
progress(0.1 + (0.15 * records_count / target_size),
|
513 |
+
desc=f"Getting sampled data from query {i+1}/{len(urls)}... ({records_count}/{target_size})")
|
514 |
|
515 |
+
print(f'PyAlex sampling successful: got {len(sampled_records)} records (requested {target_size})')
|
516 |
else:
|
517 |
raise AttributeError("sample method not available")
|
518 |
|
network_utils.py
CHANGED
@@ -42,6 +42,9 @@ def draw_citation_graph(G, bundle_edges=False, path=None, min_max_coordinates=No
|
|
42 |
for node in G.nodes():
|
43 |
pos[node] = (G.nodes[node]['X'], G.nodes[node]['Y'])
|
44 |
fig, ax = plt.subplots(figsize=(20, 20))
|
|
|
|
|
|
|
45 |
plt.margins(0, 0) # Remove margins
|
46 |
if bundle_edges:
|
47 |
# Turning color into rgb
|
@@ -64,4 +67,4 @@ def draw_citation_graph(G, bundle_edges=False, path=None, min_max_coordinates=No
|
|
64 |
plt.ylim(min_max_coordinates[2], min_max_coordinates[3])
|
65 |
|
66 |
if path is not None:
|
67 |
-
plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=800, transparent=True)
|
|
|
42 |
for node in G.nodes():
|
43 |
pos[node] = (G.nodes[node]['X'], G.nodes[node]['Y'])
|
44 |
fig, ax = plt.subplots(figsize=(20, 20))
|
45 |
+
# Ensure a solid white background (no transparency)
|
46 |
+
fig.patch.set_facecolor('white')
|
47 |
+
ax.set_facecolor('white')
|
48 |
plt.margins(0, 0) # Remove margins
|
49 |
if bundle_edges:
|
50 |
# Turning color into rgb
|
|
|
67 |
plt.ylim(min_max_coordinates[2], min_max_coordinates[3])
|
68 |
|
69 |
if path is not None:
|
70 |
+
plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=800, transparent=False)
|