MaxNoichl commited on
Commit
e5dee3b
·
1 Parent(s): f895c88

Updated sizing, small aesthetic changes & fixes to sampling

Browse files
Files changed (2) hide show
  1. app.py +222 -85
  2. legend_builders.py +221 -0
app.py CHANGED
@@ -146,6 +146,14 @@ from network_utils import create_citation_graph, draw_citation_graph
146
  # Add colormap chooser imports
147
  from colormap_chooser import ColormapChooser, setup_colormaps
148
 
 
 
 
 
 
 
 
 
149
 
150
  # Configure OpenAlex
151
  pyalex.config.email = "[email protected]"
@@ -265,52 +273,6 @@ def create_embeddings(texts_to_embedd):
265
  """Create embeddings for the input texts using the loaded model."""
266
  return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
267
 
268
-
269
def highlight_queries(text: str) -> str:
    """Render each semicolon-separated OpenAlex URL as a colored HTML pill.

    Returns a static HTML snippet. URLs that cannot be resolved to a
    readable name fall back to a numbered "Query N" label.
    """
    pill_backgrounds = (
        "#e8f4fd", "#fff2e8", "#f0f9e8", "#fdf2f8",
        "#f3e8ff", "#e8f8f5", "#fef7e8", "#f8f0e8",
    )

    # Guard: nothing entered yet.
    if not (text and text.strip()):
        return "<div style='padding: 10px; color: #666; font-style: italic;'>Enter OpenAlex URLs separated by semicolons to see query descriptions</div>"

    # Split on semicolons, dropping empty/whitespace-only fragments.
    urls = [part.strip() for part in text.split(";") if part.strip()]
    if not urls:
        return "<div style='padding: 10px; color: #666; font-style: italic;'>No valid URLs found</div>"

    pills = []
    for position, url in enumerate(urls):
        background = pill_backgrounds[position % len(pill_backgrounds)]
        try:
            # Resolve the URL to a human-readable query description.
            label = openalex_url_to_readable_name(url)
        except Exception as exc:
            print(f"Error processing URL {url}: {exc}")
            label = f"Query {position + 1}"

        pills.append(
            f'<span style="background:{background};'
            'padding: 8px 12px; margin: 4px; '
            'border-radius: 12px; font-weight: 500;'
            "display: inline-block; font-family: 'Roboto Condensed', sans-serif;"
            'border: 1px solid rgba(0,0,0,0.1); font-size: 14px;'
            'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            f'{label}</span>'
        )

    heading_word = "Query" if len(urls) == 1 else "Queries"
    return (
        "<div style='padding: 8px 0;'>"
        "<div style='font-size: 12px; color: #666; margin-bottom: 6px; font-weight: 500;'>"
        f"{heading_word} ({len(urls)}):</div>"
        "<div style='display: flex; flex-wrap: wrap; gap: 4px;'>"
        + "".join(pills)
        + "</div></div>"
    )
314
 
315
 
316
  def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
@@ -451,6 +413,7 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
451
  records = []
452
  query_indices = [] # Track which query each record comes from
453
  total_query_length = 0
 
454
 
455
  # Use first URL for filename
456
  first_query, first_params = openalex_url_to_pyalex_query(urls[0])
@@ -462,7 +425,17 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
462
  query, params = openalex_url_to_pyalex_query(url)
463
  query_length = query.count()
464
  total_query_length += query_length
465
- print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
 
 
 
 
 
 
 
 
 
 
466
 
467
  # Use PyAlex sampling for random samples - much more efficient!
468
  if reduce_sample_checkbox and sample_reduction_method == "n random samples":
@@ -524,15 +497,23 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
524
  for idx, record in enumerate(sampled_records):
525
  records.append(record)
526
  query_indices.append(i)
527
- progress(0.1 + (0.2 * len(records) / total_query_length),
528
- desc=f"Processing sampled data from query {i+1}/{len(urls)}...")
 
 
 
 
529
  else:
530
  # Keep existing logic for "First n samples" and "All"
531
  target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
532
  records_per_query = 0
533
 
 
 
534
  should_break_current_query = False
535
- for page in query.paginate(per_page=200, n_max=None):
 
 
536
  # Add retry mechanism for processing each page
537
  max_retries = 5
538
  base_wait_time = 1 # Starting wait time in seconds
@@ -541,13 +522,24 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
541
  for retry_attempt in range(max_retries):
542
  try:
543
  for record in page:
 
 
 
 
 
 
544
  records.append(record)
545
  query_indices.append(i) # Track which query this record comes from
546
  records_per_query += 1
547
- progress(0.1 + (0.2 * len(records) / (total_query_length)),
548
- desc=f"Getting data from query {i+1}/{len(urls)}...")
 
 
 
 
549
 
550
  if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
 
551
  should_break_current_query = True
552
  break
553
  # If we get here without an exception, break the retry loop
@@ -560,13 +552,19 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
560
  time.sleep(wait_time)
561
  else:
562
  print(f"Maximum retries reached. Continuing with next page.")
 
 
 
 
563
 
564
  if should_break_current_query:
 
565
  break
566
  # Continue to next query - don't break out of the main query loop
567
  print(f"Query completed in {time.time() - start_time:.2f} seconds")
568
  print(f"Total records collected: {len(records)}")
569
- print(f"Expected from all queries: {total_query_length}")
 
570
  print(f"Sample method used: {sample_reduction_method}")
571
  print(f"Reduce sample enabled: {reduce_sample_checkbox}")
572
  if sample_reduction_method == "n random samples":
@@ -664,29 +662,62 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
664
  # Use categorical coloring for multiple queries
665
  print("Using categorical coloring for multiple queries")
666
 
667
- # Define a categorical colormap - using distinct colors
668
- categorical_colors = [
669
- '#e41a1c', # Red
670
- '#377eb8', # Blue
671
- '#4daf4a', # Green
672
- '#984ea3', # Purple
673
- '#ff7f00', # Orange
674
- '#ffff33', # Yellow
675
- '#a65628', # Brown
676
- '#f781bf', # Pink
677
- '#999999', # Gray
678
- '#66c2a5', # Teal
679
- '#fc8d62', # Light Orange
680
- '#8da0cb', # Light Blue
681
- '#e78ac3', # Light Pink
682
- '#a6d854', # Light Green
683
- '#ffd92f', # Light Yellow
684
- '#e5c494', # Beige
685
- '#b3b3b3', # Light Gray
686
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
 
688
  # Assign colors based on query_index
689
- unique_queries = sorted(records_df['query_index'].unique())
690
  query_color_map = {query_idx: categorical_colors[i % len(categorical_colors)]
691
  for i, query_idx in enumerate(unique_queries)}
692
 
@@ -699,18 +730,22 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
699
  # Use selected colormap if provided, otherwise default to haline
700
  if selected_colormap_name and selected_colormap_name.strip():
701
  try:
702
- cmap = plt.get_cmap(selected_colormap_name)
703
  except Exception as e:
704
  print(f"Warning: Could not load colormap '{selected_colormap_name}': {e}")
705
- cmap = colormaps.haline
706
  else:
707
- cmap = colormaps.haline
708
 
709
  if not locally_approximate_publication_date_checkbox:
710
  # Create color mapping based on publication years
711
  years = pd.to_numeric(records_df['publication_year'])
712
  norm = mcolors.Normalize(vmin=years.min(), vmax=years.max())
713
- records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in years]
 
 
 
 
714
 
715
  else:
716
  n_neighbors = 10 # Adjust this value to control smoothing
@@ -724,7 +759,11 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
724
  for idx in indices
725
  ])
726
  norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
727
- records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in local_years]
 
 
 
 
728
  else:
729
  # No special coloring - use highlight color
730
  records_df['color'] = highlight_color
@@ -732,6 +771,13 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
732
  stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
733
  stacked_df = stacked_df.fillna("Unlabelled")
734
  stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()]
 
 
 
 
 
 
 
735
  extra_data = pd.DataFrame(stacked_df['doi'])
736
  print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")
737
 
@@ -756,6 +802,94 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
756
  # Create a solid black colormap
757
  black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000'])
758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
759
  plot = datamapplot.create_interactive_plot(
760
  stacked_df[['x','y']].values,
761
  np.array(stacked_df['cluster_2_labels']),
@@ -763,13 +897,15 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
763
 
764
  hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()],
765
  marker_color_array=stacked_df['color'],
 
766
  use_medoids=True, # Switch back once efficient mediod caclulation comes out!
767
  width=1000,
768
  height=1000,
 
769
  point_radius_min_pixels=1,
770
  text_outline_width=5,
771
  point_hover_color=highlight_color,
772
- point_radius_max_pixels=7,
773
  cmap=black_cmap,
774
  background_image=graph_file_name if citation_graph_checkbox else None,
775
  #color_label_text=False,
@@ -779,7 +915,8 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
779
  tooltip_font_family="Roboto Condensed",
780
  extra_point_data=extra_data,
781
  on_click="window.open(`{doi}`)",
782
- custom_css=DATAMAP_CUSTOM_CSS,
 
783
  initial_zoom_fraction=.8,
784
  enable_search=False,
785
  offline_mode=False
@@ -801,8 +938,8 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
801
  export_df['query_index'] = records_df['query_index']
802
  export_df['query_label'] = records_df['query_label']
803
 
804
- if locally_approximate_publication_date_checkbox and plot_type_dropdown == "Time-based coloring":
805
- export_df['approximate_publication_year'] = local_years
806
  export_df.to_csv(csv_file_path, index=False)
807
 
808
  if download_png_checkbox:
@@ -878,11 +1015,11 @@ def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_c
878
  else:
879
  static_cmap = colormaps.haline
880
 
881
- if locally_approximate_publication_date_checkbox:
882
  scatter = plt.scatter(
883
  umap_embeddings[:,0],
884
  umap_embeddings[:,1],
885
- c=local_years,
886
  cmap=static_cmap,
887
  alpha=0.8,
888
  s=point_size
 
146
  # Add colormap chooser imports
147
  from colormap_chooser import ColormapChooser, setup_colormaps
148
 
149
+ # Add legend builder imports
150
+ try:
151
+ from legend_builders import continuous_legend_html_css, categorical_legend_html_css
152
+ HAS_LEGEND_BUILDERS = True
153
+ except ImportError:
154
+ print("Warning: legend_builders.py not found. Legends will be disabled.")
155
+ HAS_LEGEND_BUILDERS = False
156
+
157
 
158
  # Configure OpenAlex
159
  pyalex.config.email = "[email protected]"
 
273
  """Create embeddings for the input texts using the loaded model."""
274
  return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
 
278
  def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
 
413
  records = []
414
  query_indices = [] # Track which query each record comes from
415
  total_query_length = 0
416
+ expected_download_count = 0 # Track expected number of records to download for progress
417
 
418
  # Use first URL for filename
419
  first_query, first_params = openalex_url_to_pyalex_query(urls[0])
 
425
  query, params = openalex_url_to_pyalex_query(url)
426
  query_length = query.count()
427
  total_query_length += query_length
428
+
429
+ # Calculate expected download count for this query
430
+ if reduce_sample_checkbox and sample_reduction_method == "First n samples":
431
+ expected_for_this_query = min(sample_size_slider, query_length)
432
+ elif reduce_sample_checkbox and sample_reduction_method == "n random samples":
433
+ expected_for_this_query = min(sample_size_slider, query_length)
434
+ else: # "All"
435
+ expected_for_this_query = query_length
436
+
437
+ expected_download_count += expected_for_this_query
438
+ print(f'Requesting {query_length} entries from query {i+1}/{len(urls)} (expecting to download {expected_for_this_query})...')
439
 
440
  # Use PyAlex sampling for random samples - much more efficient!
441
  if reduce_sample_checkbox and sample_reduction_method == "n random samples":
 
497
  for idx, record in enumerate(sampled_records):
498
  records.append(record)
499
  query_indices.append(i)
500
+ # Safe progress calculation
501
+ if expected_download_count > 0:
502
+ progress_val = 0.1 + (0.2 * len(records) / expected_download_count)
503
+ else:
504
+ progress_val = 0.1
505
+ progress(progress_val, desc=f"Processing sampled data from query {i+1}/{len(urls)}...")
506
  else:
507
  # Keep existing logic for "First n samples" and "All"
508
  target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
509
  records_per_query = 0
510
 
511
+ print(f"Query {i+1}: target_size={target_size}, query_length={query_length}, method={sample_reduction_method}")
512
+
513
  should_break_current_query = False
514
+ # For "First n samples", limit the maximum records fetched to avoid over-downloading
515
+ max_records_to_fetch = target_size if reduce_sample_checkbox and sample_reduction_method == "First n samples" else None
516
+ for page in query.paginate(per_page=200, n_max=max_records_to_fetch):
517
  # Add retry mechanism for processing each page
518
  max_retries = 5
519
  base_wait_time = 1 # Starting wait time in seconds
 
522
  for retry_attempt in range(max_retries):
523
  try:
524
  for record in page:
525
+ # Safety check: don't process if we've already reached target
526
+ if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
527
+ print(f"Reached target size before processing: {records_per_query}/{target_size}, breaking from download")
528
+ should_break_current_query = True
529
+ break
530
+
531
  records.append(record)
532
  query_indices.append(i) # Track which query this record comes from
533
  records_per_query += 1
534
+ # Safe progress calculation
535
+ if expected_download_count > 0:
536
+ progress_val = 0.1 + (0.2 * len(records) / expected_download_count)
537
+ else:
538
+ progress_val = 0.1
539
+ progress(progress_val, desc=f"Getting data from query {i+1}/{len(urls)}...")
540
 
541
  if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
542
+ print(f"Reached target size: {records_per_query}/{target_size}, breaking from download")
543
  should_break_current_query = True
544
  break
545
  # If we get here without an exception, break the retry loop
 
552
  time.sleep(wait_time)
553
  else:
554
  print(f"Maximum retries reached. Continuing with next page.")
555
+
556
+ # Break out of retry loop if we've reached target
557
+ if should_break_current_query:
558
+ break
559
 
560
  if should_break_current_query:
561
+ print(f"Successfully broke from page loop for query {i+1}")
562
  break
563
  # Continue to next query - don't break out of the main query loop
564
  print(f"Query completed in {time.time() - start_time:.2f} seconds")
565
  print(f"Total records collected: {len(records)}")
566
+ print(f"Expected to download: {expected_download_count}")
567
+ print(f"Available from all queries: {total_query_length}")
568
  print(f"Sample method used: {sample_reduction_method}")
569
  print(f"Reduce sample enabled: {reduce_sample_checkbox}")
570
  if sample_reduction_method == "n random samples":
 
662
  # Use categorical coloring for multiple queries
663
  print("Using categorical coloring for multiple queries")
664
 
665
+ # Get colors from selected colormap or use default categorical colors
666
+ unique_queries = sorted(records_df['query_index'].unique())
667
+ num_queries = len(unique_queries)
668
+
669
+ if selected_colormap_name and selected_colormap_name.strip():
670
+ try:
671
+ # Use selected colormap to generate distinct colors
672
+ categorical_cmap = plt.get_cmap(selected_colormap_name)
673
+ # Sample colors evenly spaced across the colormap
674
+ categorical_colors = [mcolors.to_hex(categorical_cmap(i / max(1, num_queries - 1)))
675
+ for i in range(num_queries)]
676
+ except Exception as e:
677
+ print(f"Warning: Could not load colormap '{selected_colormap_name}' for categorical coloring: {e}")
678
+ # Fallback to default categorical colors
679
+ categorical_colors = [
680
+ '#e41a1c', # Red
681
+ '#377eb8', # Blue
682
+ '#4daf4a', # Green
683
+ '#984ea3', # Purple
684
+ '#ff7f00', # Orange
685
+ '#ffff33', # Yellow
686
+ '#a65628', # Brown
687
+ '#f781bf', # Pink
688
+ '#999999', # Gray
689
+ '#66c2a5', # Teal
690
+ '#fc8d62', # Light Orange
691
+ '#8da0cb', # Light Blue
692
+ '#e78ac3', # Light Pink
693
+ '#a6d854', # Light Green
694
+ '#ffd92f', # Light Yellow
695
+ '#e5c494', # Beige
696
+ '#b3b3b3', # Light Gray
697
+ ]
698
+ else:
699
+ # Use default categorical colors
700
+ categorical_colors = [
701
+ '#e41a1c', # Red
702
+ '#377eb8', # Blue
703
+ '#4daf4a', # Green
704
+ '#984ea3', # Purple
705
+ '#ff7f00', # Orange
706
+ '#ffff33', # Yellow
707
+ '#a65628', # Brown
708
+ '#f781bf', # Pink
709
+ '#999999', # Gray
710
+ '#66c2a5', # Teal
711
+ '#fc8d62', # Light Orange
712
+ '#8da0cb', # Light Blue
713
+ '#e78ac3', # Light Pink
714
+ '#a6d854', # Light Green
715
+ '#ffd92f', # Light Yellow
716
+ '#e5c494', # Beige
717
+ '#b3b3b3', # Light Gray
718
+ ]
719
 
720
  # Assign colors based on query_index
 
721
  query_color_map = {query_idx: categorical_colors[i % len(categorical_colors)]
722
  for i, query_idx in enumerate(unique_queries)}
723
 
 
730
  # Use selected colormap if provided, otherwise default to haline
731
  if selected_colormap_name and selected_colormap_name.strip():
732
  try:
733
+ time_cmap = plt.get_cmap(selected_colormap_name)
734
  except Exception as e:
735
  print(f"Warning: Could not load colormap '{selected_colormap_name}': {e}")
736
+ time_cmap = colormaps.haline
737
  else:
738
+ time_cmap = colormaps.haline
739
 
740
  if not locally_approximate_publication_date_checkbox:
741
  # Create color mapping based on publication years
742
  years = pd.to_numeric(records_df['publication_year'])
743
  norm = mcolors.Normalize(vmin=years.min(), vmax=years.max())
744
+ records_df['color'] = [mcolors.to_hex(time_cmap(norm(year))) for year in years]
745
+ # Store for legend generation
746
+ years_for_legend = years
747
+ legend_label = "Publication Year"
748
+ legend_cmap = time_cmap
749
 
750
  else:
751
  n_neighbors = 10 # Adjust this value to control smoothing
 
759
  for idx in indices
760
  ])
761
  norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
762
+ records_df['color'] = [mcolors.to_hex(time_cmap(norm(year))) for year in local_years]
763
+ # Store for legend generation
764
+ years_for_legend = local_years
765
+ legend_label = "Approx. Year"
766
+ legend_cmap = time_cmap
767
  else:
768
  # No special coloring - use highlight color
769
  records_df['color'] = highlight_color
 
771
  stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
772
  stacked_df = stacked_df.fillna("Unlabelled")
773
  stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()]
774
+
775
+ # Create marker size array: basemap points = 2, query result points = 4
776
+ marker_sizes = np.concatenate([
777
+ np.full(len(basedata_df), 1.), # Basemap points
778
+ np.full(len(records_df), 2.5) # Query result points
779
+ ])
780
+
781
  extra_data = pd.DataFrame(stacked_df['doi'])
782
  print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")
783
 
 
802
  # Create a solid black colormap
803
  black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000'])
804
 
805
+ # Generate legends based on plot type
806
+ custom_html = ""
807
+ legend_css = ""
808
+
809
+ if HAS_LEGEND_BUILDERS:
810
+ if treat_as_categorical_checkbox and has_multiple_queries:
811
+ # Create categorical legend for multiple queries
812
+ unique_queries = sorted(records_df['query_index'].unique())
813
+ color_mapping = {}
814
+
815
+ # Get readable names for each query URL
816
+ for i, query_idx in enumerate(unique_queries):
817
+ try:
818
+ if query_idx < len(urls):
819
+ readable_name = openalex_url_to_readable_name(urls[query_idx])
820
+ # Truncate long names for legend display
821
+ if len(readable_name) > 25:
822
+ readable_name = readable_name[:22] + "..."
823
+ else:
824
+ readable_name = f"Query {query_idx + 1}"
825
+ except Exception:
826
+ readable_name = f"Query {query_idx + 1}"
827
+
828
+ color_mapping[readable_name] = query_color_map[query_idx]
829
+
830
+ legend_html, legend_css = categorical_legend_html_css(
831
+ color_mapping,
832
+ title="Queries" if len(color_mapping) > 1 else "Query",
833
+ anchor="top-left",
834
+ container_id="dmp-query-legend"
835
+ )
836
+ custom_html += legend_html
837
+
838
+ elif plot_time_checkbox and 'years_for_legend' in locals():
839
+ # Create continuous legend for time-based coloring using the stored variables
840
+ # Create ticks every 5 years within the range, ignoring endpoints
841
+ year_min, year_max = int(years_for_legend.min()), int(years_for_legend.max())
842
+ year_range = year_max - year_min
843
+
844
+ # Find the first multiple of 5 that's greater than year_min
845
+ first_tick = ((year_min // 5) + 1) * 5
846
+
847
+ # Generate ticks every 5 years until we reach year_max
848
+ ticks = []
849
+ current_tick = first_tick
850
+ while current_tick < year_max:
851
+ ticks.append(current_tick)
852
+ current_tick += 5
853
+
854
+ # For ranges under 15 years, include both endpoints
855
+ if year_range < 15:
856
+ if not ticks:
857
+ # No 5-year ticks, just show endpoints
858
+ ticks = [year_min, year_max]
859
+ else:
860
+ # Add endpoints to existing 5-year ticks
861
+ if year_min not in ticks:
862
+ ticks.insert(0, year_min)
863
+ if year_max not in ticks:
864
+ ticks.append(year_max)
865
+
866
+ legend_html, legend_css = continuous_legend_html_css(
867
+ legend_cmap,
868
+ year_min,
869
+ year_max,
870
+ ticks=ticks,
871
+ label=legend_label,
872
+ anchor="top-right",
873
+ container_id="dmp-year-legend"
874
+ )
875
+ custom_html += legend_html
876
+
877
+ # Add custom CSS to make legend titles equally large and bold
878
+ legend_title_css = """
879
+ /* Make all legend titles equally large and bold */
880
+ #dmp-query-legend .legend-title,
881
+ #dmp-year-legend .colorbar-label {
882
+ font-size: 16px !important;
883
+ font-weight: bold !important;
884
+ font-family: 'Roboto Condensed', sans-serif !important;
885
+ }
886
+ """
887
+
888
+ # Combine legend CSS with existing custom CSS
889
+ combined_css = DATAMAP_CUSTOM_CSS + "\n" + legend_css + "\n" + legend_title_css
890
+
891
+
892
+
893
  plot = datamapplot.create_interactive_plot(
894
  stacked_df[['x','y']].values,
895
  np.array(stacked_df['cluster_2_labels']),
 
897
 
898
  hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()],
899
  marker_color_array=stacked_df['color'],
900
+ marker_size_array=marker_sizes,
901
  use_medoids=True, # Switch back once efficient mediod caclulation comes out!
902
  width=1000,
903
  height=1000,
904
+ # point_size_scale=1.5,
905
  point_radius_min_pixels=1,
906
  text_outline_width=5,
907
  point_hover_color=highlight_color,
908
+ point_radius_max_pixels=5,
909
  cmap=black_cmap,
910
  background_image=graph_file_name if citation_graph_checkbox else None,
911
  #color_label_text=False,
 
915
  tooltip_font_family="Roboto Condensed",
916
  extra_point_data=extra_data,
917
  on_click="window.open(`{doi}`)",
918
+ custom_html=custom_html,
919
+ custom_css=combined_css,
920
  initial_zoom_fraction=.8,
921
  enable_search=False,
922
  offline_mode=False
 
938
  export_df['query_index'] = records_df['query_index']
939
  export_df['query_label'] = records_df['query_label']
940
 
941
+ if locally_approximate_publication_date_checkbox and plot_type_dropdown == "Time-based coloring" and 'years_for_legend' in locals():
942
+ export_df['approximate_publication_year'] = years_for_legend
943
  export_df.to_csv(csv_file_path, index=False)
944
 
945
  if download_png_checkbox:
 
1015
  else:
1016
  static_cmap = colormaps.haline
1017
 
1018
+ if locally_approximate_publication_date_checkbox and 'years_for_legend' in locals():
1019
  scatter = plt.scatter(
1020
  umap_embeddings[:,0],
1021
  umap_embeddings[:,1],
1022
+ c=years_for_legend,
1023
  cmap=static_cmap,
1024
  alpha=0.8,
1025
  s=point_size
legend_builders.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """legend_builders.py
2
+ ====================
3
+ Minimal‑dependency helpers that generate **static** legend HTML + CSS matching
4
+ DataMapPlot’s own class names. Drop the returned strings straight into
5
+ ``create_interactive_plot(custom_html=…, custom_css=…)``.
6
+
7
+ Highlights
8
+ ----------
9
+ * **continuous_legend_html_css** – full control over ticks, label, size &
10
+ absolute position (via an *anchor* keyword).
11
+ * **categorical_legend_html_css** – swatch legend with optional title, flexible
12
+ anchor, row/column layout and custom swatch size.
13
+
14
+ Both helpers return ``(html, css)`` so you can concatenate multiple legends.
15
+ No JavaScript is injected – they render statically but look native. If you
16
+ later add JS (e.g. DMP’s `ColorLegend` behaviour), the class names already fit.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from typing import Dict, List, Sequence, Tuple, Union
22
+ from datetime import datetime, date
23
+ import matplotlib.cm as _cm
24
+ from matplotlib.colors import to_hex, to_rgb
25
+
26
# Any Matplotlib-parsable colour: a hex/named string or an RGB(A) tuple.
Colour = Union[str, tuple]
# Explicit public API of this module.
__all__ = ["continuous_legend_html_css", "categorical_legend_html_css"]
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # helpers
31
+ # ---------------------------------------------------------------------------
32
+
33
+ def _hex(c: Colour) -> str:
34
+ """Convert *c* to #RRGGBB hex (handles any Matplotlib‑parsable colour)."""
35
+ return c if isinstance(c, str) else to_hex(to_rgb(c))
36
+
37
+
38
def _gradient(cmap: Union[str, _cm.Colormap, Sequence[str]], *, vertical: bool = True) -> str:
    """Return a CSS ``linear-gradient(...)`` string for *cmap*.

    Parameters
    ----------
    cmap:
        A Matplotlib colormap object, a registered colormap name, or an
        explicit sequence of colours used verbatim as gradient stops.
    vertical:
        When true the gradient runs bottom-to-top (``to top``), matching a
        vertical colour bar; otherwise left-to-right.
    """
    if isinstance(cmap, (list, tuple)):
        # Explicit colour list: use the given colours directly as stops.
        stops = [_hex(c) for c in cmap]
    else:
        if isinstance(cmap, str):
            # ``cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
            # 3.9 — prefer the ``matplotlib.colormaps`` registry and fall back
            # to the legacy accessor on older Matplotlib versions.
            try:
                from matplotlib import colormaps as _registry
                cmap = _registry[cmap]
            except ImportError:  # Matplotlib < 3.5
                cmap = _cm.get_cmap(cmap)
        # Sample all 256 entries so the CSS gradient reproduces the full ramp.
        stops = [to_hex(cmap(i / 255)) for i in range(256)]
    direction = "to top" if vertical else "to right"
    return f"linear-gradient({direction}, {', '.join(stops)})"
47
+
48
+
49
# Maps an anchor keyword to the CSS position declarations that pin a legend
# to that corner/edge of the plot container (10px inset; the "middle-*"
# variants centre themselves with a transform).
_ANCHOR_CSS: Dict[str, str] = {
    "top-left": "top:10px; left:10px;",
    "top-right": "top:10px; right:10px;",
    "bottom-left": "bottom:10px; left:10px;",
    "bottom-right": "bottom:10px; right:10px;",
    "middle-left": "top:50%; left:10px; transform:translateY(-50%);",
    "middle-right": "top:50%; right:10px; transform:translateY(-50%);",
    "middle-center": "top:50%; left:50%; transform:translate(-50%,-50%);",
}
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # continuous legend
61
+ # ---------------------------------------------------------------------------
62
+
63
def continuous_legend_html_css(
    cmap: Union[str, _cm.Colormap, Sequence[str]],
    vmin: Union[int, float, datetime, date],
    vmax: Union[int, float, datetime, date],
    *,
    ticks: Sequence[Union[int, float, datetime, date]] | None = None,
    label: str | None = None,
    bar_size: tuple[int, int] = (10, 200),
    anchor: str = "top-right",
    container_id: str = "dmp-colorbar",
) -> Tuple[str, str]:
    """Return *(html, css)* snippet for a static colour‑bar legend.

    Parameters
    ----------
    cmap:
        Colormap (object or registered name) or explicit colour list,
        forwarded to the CSS gradient builder.
    vmin, vmax:
        Data range endpoints; numbers or ``datetime``/``date`` values.
    ticks:
        Tick values to annotate. Defaults to five evenly spaced values
        between *vmin* and *vmax*.
    label:
        Optional label rendered vertically beside the bar.
    bar_size:
        ``(width, height)`` of the colour bar in pixels.
    anchor:
        Placement keyword looked up in ``_ANCHOR_CSS`` (unknown anchors
        fall back to ``"top-right"``).
    container_id:
        DOM id used to scope the generated CSS to this legend.

    NOTE(review): tick positions are computed relative to the first/last
    *tick* values rather than vmin/vmax — if the supplied ticks do not
    span the full range, labels will not align with the gradient ends;
    confirm this is intended.
    """

    # ---------- ticks -----------------------------------------------------
    # Default: five evenly spaced ticks across [vmin, vmax] (also works for
    # datetimes, since timedelta supports scaling/division).
    if ticks is None:
        ticks = [vmin + (vmax - vmin) * i / 4 for i in range(5)]  # type: ignore

    def _fmt(val):
        # Dates/datetimes are shown as the year only.
        if isinstance(val, (datetime, date)):
            return val.strftime("%Y")
        # Switch to scientific notation for very large or very small ranges;
        # the decision uses vmin/vmax magnitudes, not the tick value itself.
        sci = max(abs(float(vmin)), abs(float(vmax))) >= 1e5 or 0 < abs(float(vmin)) <= 1e-4
        if sci:
            return f"{val:.1e}"
        return f"{val:.0f}" if float(val).is_integer() else f"{val:.2f}"

    tick_labels = [_fmt(t) for t in ticks]

    # relative positions (0% top, 100% bottom) -----------------------------
    def _rel(val):
        # Percentage offset from the top of the bar; ``or 1`` guards against
        # division by zero when all ticks are equal.
        if isinstance(val, (datetime, date)):
            rng = (ticks[-1] - ticks[0]).total_seconds() or 1
            return (ticks[-1] - val).total_seconds() / rng * 100
        rng = float(ticks[-1] - ticks[0]) or 1
        return (ticks[-1] - val) / rng * 100

    # ---------- HTML ------------------------------------------------------
    w, h = bar_size
    html: List[str] = [f'<div id="{container_id}" class="colorbar-container">']

    # Optional vertical label to the left of the bar.
    if label:
        html.append(
            f' <div class="colorbar-label" style="writing-mode:vertical-rl; transform:rotate(180deg); margin-right:8px;">{label}</div>'
        )

    html.append(f' <div class="colorbar" style="width:{w}px; height:{h}px; background:{_gradient(cmap)};"></div>')
    html.append(' <div class="colorbar-tick-container">')

    # One absolutely positioned tick (line + label) per tick value.
    for pos, lab in zip([_rel(t) for t in ticks], tick_labels):
        html.append(
            f' <div class="colorbar-tick" style="top:{pos:.2f}%;">'
            ' <div class="colorbar-tick-line"></div>'
            f' <div class="colorbar-tick-label">{lab}</div>'
            ' </div>'
        )

    html.extend([' </div>', '</div>'])

    # ---------- CSS -------------------------------------------------------
    pos_css = _ANCHOR_CSS.get(anchor, _ANCHOR_CSS["top-right"])
    css = f"""
#{container_id} {{position:absolute; {pos_css} z-index:100; display:flex; align-items:center; gap:4px; padding:10px;}}
#{container_id} .colorbar-tick-container {{position:relative; width:40px; height:{h}px;}}
#{container_id} .colorbar-tick {{position:absolute; display:flex; align-items:center; gap:4px; transform:translateY(-50%); font-size:12px;}}
#{container_id} .colorbar-tick-line {{width:8px; height:1px; background:#333;}}
#{container_id} .colorbar-label {{font-size:12px;}}
"""

    return "\n".join(html), css
131
+
132
+ # ---------------------------------------------------------------------------
133
+ # categorical legend
134
+ # ---------------------------------------------------------------------------
135
+
136
def categorical_legend_html_css(
    color_mapping: Dict[str, Colour],
    *,
    title: str | None = None,
    swatch: int = 12,
    anchor: str = "bottom-left",
    container_id: str = "dmp-catlegend",
    rows: bool = True,
) -> Tuple[str, str]:
    """Build *(html, css)* for a static swatch legend.

    One coloured square per ``label -> colour`` entry, with an optional
    title, positioned via the *anchor* keyword. The CSS is scoped to
    *container_id* so multiple legends can coexist.
    """

    parts: List[str] = [f'<div id="{container_id}" class="color-legend-container">']
    if title:
        parts.append(f' <div class="legend-title">{title}</div>')

    # One swatch + label row per mapping entry, in insertion order.
    parts.extend(
        ' <div class="legend-item">'
        f' <div class="color-swatch-box" style="background:{_hex(col)};"></div>'
        f' <div class="legend-label">{lbl}</div>'
        ' </div>'
        for lbl, col in color_mapping.items()
    )
    parts.append('</div>')

    placement = _ANCHOR_CSS.get(anchor, _ANCHOR_CSS["bottom-left"])
    flow = 'column' if rows else 'row'
    css = f"""
#{container_id} {{position:absolute; {placement} z-index:100; display:flex; flex-direction:{flow}; gap:4px; padding:10px;}}
#{container_id} .legend-title {{font-weight:bold; margin-bottom:4px;}}
#{container_id} .legend-item {{display:flex; align-items:center; gap:4px;}}
#{container_id} .color-swatch-box {{width:{swatch}px; height:{swatch}px; border-radius:2px; border:1px solid #555;}}
#{container_id} .legend-label {{font-size:12px;}}
"""

    return "\n".join(parts), css
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # sample script for quick testing
172
+ # ---------------------------------------------------------------------------
173
# ---------------------------------------------------------------------------
# sample script for quick testing
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # pip install datamapplot matplotlib numpy to run this demo
    import numpy as np
    from matplotlib import cm
    import datamapplot as dmp

    # dummy data ----------------------------------------------------------
    n = 400
    rng = np.random.default_rng(0)
    coords = rng.normal(size=(n, 2))
    years = rng.integers(1990, 2025, size=n)

    # quadrant labels -----------------------------------------------------
    quad = np.where(coords[:, 0] >= 0,
                    np.where(coords[:, 1] >= 0, "A", "D"),
                    np.where(coords[:, 1] >= 0, "B", "C"))

    # colours -------------------------------------------------------------
    grey = "#bbbbbb"
    cols = np.full(n, grey, dtype=object)
    mask = rng.random(n) < 0.1
    # ``cm.get_cmap("viridis")`` was removed in Matplotlib 3.9; the module
    # attribute is the equivalent, version-stable accessor.
    vir = cm.viridis
    cols[mask] = [to_hex(vir((y - years.min())/(years.max()-years.min()))) for y in years[mask]]

    # legends -------------------------------------------------------------
    html_bar, css_bar = continuous_legend_html_css(
        vir, years.min(), years.max(), label="Year", anchor="middle-right", ticks=[1990, 2000, 2010, 2020, 2024]
    )
    html_cat, css_cat = categorical_legend_html_css(
        {lbl: col for lbl, col in zip("ABCD", cm.tab10.colors)}, title="Quadrant", anchor="bottom-left"
    )

    custom_html = html_bar + html_cat
    custom_css = css_bar + css_cat

    # plot ---------------------------------------------------------------
    plot = dmp.create_interactive_plot(
        coords, quad,
        hover_text=np.arange(n).astype(str),
        marker_color_array=cols,
        custom_html=custom_html,
        custom_css=custom_css,
    )

    # In Jupyter this shows automatically; otherwise save:
    # with open("demo.html", "w") as f:
    #     f.write(str(plot))

    print("Demo plot generated – view in a notebook or open the saved HTML.")