m7n committed on
Commit 62c25af · 1 Parent(s): 68c9731

added edge-bundling

Files changed (4)
  1. app.py +82 -30
  2. edgebundling.py +498 -0
  3. network_utils.py +67 -0
  4. requirements.txt +2 -2
app.py CHANGED
@@ -19,6 +19,8 @@ import matplotlib.pyplot as plt
 import tqdm
 import colormaps
 import matplotlib.colors as mcolors
+from matplotlib.colors import Normalize
+
 
 
 import opinionated # for fonts
@@ -52,6 +54,11 @@ from data_setup import (
 
 )
 
+from network_utils import create_citation_graph, draw_citation_graph
+
+
+
+
 # Configure OpenAlex
 pyalex.config.email = "[email protected]"
 
@@ -109,7 +116,7 @@ def create_embeddings(texts_to_embedd):
 
 def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduction_method,
             plot_time_checkbox, locally_approximate_publication_date_checkbox,
-            download_csv_checkbox, download_png_checkbox, progress=gr.Progress()):
+            download_csv_checkbox, download_png_checkbox, citation_graph_checkbox, progress=gr.Progress()):
     """
     Main prediction pipeline that processes OpenAlex queries and creates visualizations.
 
@@ -146,33 +153,41 @@ def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduc
     print('Starting data projection pipeline')
     progress(0.1, desc="Starting...")
 
-    # Query OpenAlex
-    query_start = time.time()
-    query, params = openalex_url_to_pyalex_query(text_input)
-
-    filename = openalex_url_to_filename(text_input)
-    print(f"Filename: {filename}")
-
-    query_length = query.count()
-    print(f'Requesting {query_length} entries...')
-
+    # Split input into multiple URLs if present
+    urls = [url.strip() for url in text_input.split(';')]
     records = []
-    target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
-
+    total_query_length = 0
 
-    should_break = False
-    for page in query.paginate(per_page=200,n_max=None):
-        for record in page:
-            records.append(record)
-            progress(0.1 + (0.2 * len(records) / target_size), desc="Getting queried data...")
-            # print(len(records))
-            if reduce_sample_checkbox and sample_reduction_method == "First n samples" and len(records) >= target_size:
-                should_break = True
+    # Use first URL for filename
+    first_query, first_params = openalex_url_to_pyalex_query(urls[0])
+    filename = openalex_url_to_filename(urls[0])
+    print(f"Filename: {filename}")
+
+    # Process each URL
+    for i, url in enumerate(urls):
+        query, params = openalex_url_to_pyalex_query(url)
+        query_length = query.count()
+        total_query_length += query_length
+        print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')
+
+        target_size = sample_size_slider if reduce_sample_checkbox and sample_reduction_method == "First n samples" else query_length
+        records_per_query = 0
+
+        should_break = False
+        for page in query.paginate(per_page=200, n_max=None):
+            for record in page:
+                records.append(record)
+                records_per_query += 1
+                progress(0.1 + (0.2 * len(records) / (total_query_length)),
+                         desc=f"Getting data from query {i+1}/{len(urls)}...")
+
+                if reduce_sample_checkbox and sample_reduction_method == "First n samples" and records_per_query >= target_size:
+                    should_break = True
+                    break
+            if should_break:
                 break
-        if should_break:
-            break
-
-    print(f"Query completed in {time.time() - query_start:.2f} seconds")
+
+    print(f"Query completed in {time.time() - start_time:.2f} seconds")
 
     # Process records
     processing_start = time.time()
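The loop added above is a standard capped pagination pattern on pyalex queries: count each query, page through it 200 records at a time, and break out of both loops once the per-query cap is reached. A minimal standalone sketch of that pattern (toy queries; the `Works()` filters are illustrative, not the app's actual URL helpers):

```python
# Minimal sketch of the capped multi-query pagination used in predict();
# assumes pyalex is installed. The example filters are hypothetical.
from pyalex import Works

def fetch_records(queries, cap=100):
    records = []
    for query in queries:
        fetched = 0                 # per-query counter, like records_per_query
        done = False
        for page in query.paginate(per_page=200, n_max=None):
            for record in page:
                records.append(record)
                fetched += 1
                if fetched >= cap:  # stop this query once the cap is hit
                    done = True
                    break
            if done:
                break
    return records

# e.g. a search query plus a citation query, 50 records each:
# records = fetch_records([Works().search("edge bundling"),
#                          Works().filter(cites="W2741809807")], cap=50)
```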
@@ -239,6 +254,17 @@ def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduc
     extra_data = pd.DataFrame(stacked_df['doi'])
     print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")
 
+    if citation_graph_checkbox:
+        citation_graph = create_citation_graph(records_df)
+        graph_file_name = f"{filename}_citation_graph.jpg"
+        graph_file_path = static_dir / graph_file_name
+        draw_citation_graph(citation_graph, path=graph_file_path, bundle_edges=True,
+                            min_max_coordinates=[np.min(stacked_df['x']), np.max(stacked_df['x']), np.min(stacked_df['y']), np.max(stacked_df['y'])])
+
+
+
+
+
     # Create and save plot
     plot_start = time.time()
     progress(0.7, desc="Creating plot...")
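`draw_citation_graph` writes the bundled graph to a transparent image clipped to the data bounds, and those same bounds reappear later as the `extent` of an `ax.imshow` overlay, so both layers share one coordinate frame. A self-contained sketch of that render-then-overlay pattern (toy data, hypothetical file name):

```python
# Sketch: render a figure clipped to the data bounds, then overlay it on a
# scatter plot via imshow(extent=...) so both share data coordinates.
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
xs, ys = rng.random(50), rng.random(50)
xmin, xmax, ymin, ymax = xs.min(), xs.max(), ys.min(), ys.max()

# Render the "overlay" figure, clamped to the data bounds
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot([xmin, xmax], [ymin, ymax], color='#f98e31', alpha=0.5)
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.axis('off')
fig.savefig("overlay.png", bbox_inches='tight', pad_inches=0, transparent=True)
plt.close(fig)

# Draw points, then place the image in the same coordinate frame
fig, ax = plt.subplots()
ax.scatter(xs, ys, s=5, color='black')
ax.imshow(plt.imread("overlay.png"),
          extent=[xmin, xmax, ymin, ymax], alpha=0.9, aspect='auto')
```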
@@ -261,6 +287,7 @@ def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduc
         point_hover_color='#5e2784',
         point_radius_max_pixels=7,
         cmap=black_cmap,
+        background_image=graph_file_name if citation_graph_checkbox else None,
         #color_label_text=False,
         font_family="Roboto Condensed",
         font_weight=600,
@@ -287,8 +314,10 @@ def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduc
     png_file_path = static_dir / f"{filename}.png"
 
     if download_csv_checkbox:
-        # Export relevant columns
-        export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y']]
+        # Export relevant columns
+        export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y', 'id', 'primary_topic']]
+        export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
+        export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
         export_df.to_csv(csv_file_path, index=False)
 
     if download_png_checkbox:
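The new `referenced_works` export column holds Python lists of OpenAlex work IDs, which pandas would otherwise serialize as a string repr; the `', '.join(x)` keeps each CSV cell a plain comma-separated string. A small round-trip sketch (toy IDs):

```python
# Sketch: flatten a list-valued column for CSV export, then recover it.
import pandas as pd

df = pd.DataFrame({"id": ["W1", "W2"],
                   "referenced_works": [["W10", "W11"], []]})
df["referenced_works"] = df["referenced_works"].apply(", ".join)
df.to_csv("export.csv", index=False)

back = pd.read_csv("export.csv", keep_default_na=False)
refs = [s.split(", ") if s else [] for s in back["referenced_works"]]
# -> [['W10', 'W11'], []]
```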
@@ -307,15 +336,18 @@ def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduc
 
     # Get the 30 most common labels
     unique_labels, counts = np.unique(combined_labels, return_counts=True)
-    top_30_labels = set(unique_labels[np.argsort(counts)[-50:]])
+    top_30_labels = set(unique_labels[np.argsort(counts)[-70:]])
 
     # Replace less common labels with 'Unlabelled'
     combined_labels = np.array(['Unlabelled' if label not in top_30_labels else label for label in combined_labels])
-
+    #combined_labels = np.array(['Unlabelled' for label in combined_labels])
+    #if label not in top_30_labels else label
    colors_base = ['#536878' for _ in range(len(labels1))]
     print(f"Sample preparation completed in {time.time() - sample_prep_start:.2f} seconds")
 
     # Create main plot
+    print(labels1)
+    print(labels2)
     print(sample_to_plot[['x','y']].values)
     print(combined_labels)
 
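The label-thinning step above is a compact numpy idiom: count label frequencies with `np.unique`, keep the N most common via `argsort`, and rename everything else 'Unlabelled' (the surrounding comments still say 30, while the constant has moved from 50 to 70). In miniature:

```python
# Keep only the N most frequent labels; rename the rest 'Unlabelled'.
import numpy as np

labels = np.array(["a", "b", "a", "c", "a", "b", "d"])
N = 2
uniq, counts = np.unique(labels, return_counts=True)
top = set(uniq[np.argsort(counts)[-N:]])   # the N most common labels
labels = np.array([l if l in top else "Unlabelled" for l in labels])
# -> ['a' 'b' 'a' 'Unlabelled' 'a' 'b' 'Unlabelled']
```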
@@ -342,6 +374,16 @@ def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduc
     )
     print(f"Main plot creation completed in {time.time() - main_plot_start:.2f} seconds")
 
+
+    if citation_graph_checkbox:
+
+        # Read and add the graph image
+        graph_img = plt.imread(graph_file_path)
+        ax.imshow(graph_img, extent=[np.min(stacked_df['x']), np.max(stacked_df['x']), np.min(stacked_df['y']), np.max(stacked_df['y'])],
+                  alpha=0.9, aspect='auto')
+
+
+
     # Time-based visualization
     scatter_start = time.time()
     if plot_time_checkbox:
@@ -506,6 +548,12 @@ with gr.Blocks(theme=theme, css="""
                 info="Export a static PNG visualization. This will make things slower!"
             )
 
+            gr.Markdown("### Citation graph")
+            citation_graph_checkbox = gr.Checkbox(
+                label="Add Citation Graph",
+                value=True,
+                info="Adds a citation graph of the sample to the plot."
+            )
 
 
 
@@ -529,6 +577,10 @@ with gr.Blocks(theme=theme, css="""
 ## How does it work?
 
 The base map for this project is developed by randomly downloading 250,000 articles from OpenAlex, then embedding their abstracts using our [fine-tuned](https://huggingface.co/m7n/discipline-tuned_specter_2_024) version of the [specter-2](https://huggingface.co/allenai/specter2_aug2023refresh_base) language model, running these embeddings through [UMAP](https://umap-learn.readthedocs.io/en/latest/) to give us a two-dimensional representation, and displaying that in an interactive window using [datamapplot](https://datamapplot.readthedocs.io/en/latest/index.html). After the data for your query is downloaded from OpenAlex, it then undergoes the exact same process, but the pre-trained UMAP model from earlier is used to project your new data points onto this original map, showing where they would show up if they were included in the original sample. For more details, you can take a look at the method section of this paper: **...**
+
+## I want to add multiple queries at once!
+
+That can be a good idea, e.g. if you're interested in a specific paper as well as all the papers that cite it. Just add the queries to the query box and separate them with a ";" without any spaces in between!
 
 ## I think I found a mistake in the map.
 
@@ -568,7 +620,7 @@ with gr.Blocks(theme=theme, css="""
         inputs=[text_input, sample_size_slider, reduce_sample_checkbox,
                 sample_reduction_method, plot_time_checkbox,
                 locally_approximate_publication_date_checkbox,
-                download_csv_checkbox, download_png_checkbox],
+                download_csv_checkbox, download_png_checkbox, citation_graph_checkbox],
         outputs=[html, html_download, csv_download, png_download, cancel_btn]
     )
 
 
edgebundling.py ADDED
@@ -0,0 +1,498 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+from matplotlib.collections import LineCollection
+from itertools import count
+from heapq import heappush, heappop
+from collections import defaultdict
+import time
+import pandas as pd
+from datashader.bundling import hammer_bundle  # New import for hammer bundling
+
+###############################################################################
+# Minimal AbstractBundling base class (refactored from .abstractBundling import)
+###############################################################################
+class AbstractBundling:
+    def __init__(self, G: nx.Graph):
+        self.G = G
+
+    def bundle(self):
+        raise NotImplementedError("Subclasses should implement 'bundle'.")
+
+###############################################################################
+# Simple SplineC placeholder (refactoring out the nx2ipe dependency)
+###############################################################################
+class SplineC:
+    def __init__(self, points):
+        self.points = points
+
+###############################################################################
+# A base SpannerBundling class that SpannerBundlingNoSP depends on
+###############################################################################
+class SpannerBundling(AbstractBundling):
+    """
+    S-EPB implementation.
+
+    weightFactor: kappa value that sets the bundling strength
+    distortion: t value that sets the maximum allowed stretch/distortion
+    numWorkers: number of workers that process biconnected components
+    """
+    def __init__(self, G: nx.Graph, weightFactor=2, distortion=2, numWorkers=1):
+        super().__init__(G)
+        self.distortion = distortion
+        self.weightFactor = weightFactor
+        self.mode = "greedy"
+        self.name = None
+        self.numWorkers = numWorkers
+
+    @property
+    def name(self):
+        return f"SEPB_d_{self.distortion}_w_{self.weightFactor}_{self.mode}"
+
+    @name.setter
+    def name(self, value):
+        self._name = value
+
+    def bundle(self):
+        # Default does nothing
+        return 0.0
+
+    def process(self, component):
+        # Default does nothing
+        pass
+
+    def spanner(self, g, k):
+        # Default does nothing
+        return None
+
+###############################################################################
+# The requested SpannerBundlingNoSP class
+###############################################################################
+class SpannerBundlingNoSP(SpannerBundling):
+    """
+    S-EPB where instead of computing single source shortest paths we reuse
+    shortest paths during the spanner construction.
+    """
+    def __init__(self, G: nx.Graph, weightFactor=2, distortion=2):
+        super().__init__(G)
+        self.distortion = distortion
+        self.weightFactor = weightFactor
+        self.mode = "reuse"
+
+    def bundle(self):
+        """
+        Executes the bundling process on all biconnected components.
+        Returns the total time for bundling.
+        """
+        t_start = time.process_time()
+
+        if nx.is_directed(self.G):
+            # Convert to undirected for the biconnected components
+            GG = self.G.to_undirected(as_view=True)
+            components = nx.biconnected_components(GG)
+        else:
+            components = nx.biconnected_components(self.G)
+
+        to_process = []
+        for nodes in components:
+            if len(nodes) > 2:
+                subg = self.G.subgraph(nodes).copy()
+                to_process.append(subg)
+
+        # Sort the components from largest to smallest
+        to_process = sorted(to_process, key=lambda x: len(x.nodes()), reverse=True)
+
+        # Process each component
+        for comp in to_process:
+            self.process(comp)
+
+        return time.process_time() - t_start
+
+    def process(self, component):
+        """
+        Process a component: build a spanner, then for each edge not in
+        the spanner, store a 'path' and create a Spline if possible.
+        """
+        T = self.spanner(component, self.distortion)
+
+        # Mark edges in T as 'Spanning'
+        for u, v, data in T.edges(data=True):
+            data["weight"] = np.power(data["dist"], self.weightFactor)
+
+        for u, v in T.edges():
+            self.G[u][v]["Layer"] = "Spanning"
+            self.G[u][v]["Stroke"] = "blue"
+
+        # For edges not in T, build a spline from the stored path
+        for u, v, data in component.edges(data=True):
+            if T.has_edge(u, v):
+                continue
+
+            path = data.get("path", [])
+            if len(path) < 1:
+                continue
+
+            spline_points = []
+            current = path[0]
+            for nxt in path[1:-1]:
+                x = component.nodes[nxt].get("X", component.nodes[nxt].get("x", 0))
+                y = component.nodes[nxt].get("Y", component.nodes[nxt].get("y", 0))
+                spline_points.append((x, y))
+                current = nxt
+
+            self.G[u][v]["Spline"] = SplineC(spline_points)
+            self.G[u][v]["Layer"] = "Bundled"
+            self.G[u][v]["Stroke"] = "purple"
+
+        return
+
+    def spanner(self, g, k):
+        """
+        Create a spanner and store the shortest path in edge['path'] when the
+        edge is not added to the spanner.
+        """
+        if nx.is_directed(g):
+            spanner = nx.DiGraph()
+        else:
+            spanner = nx.Graph()
+
+        edges = sorted(g.edges(data=True), key=lambda t: t[2].get("dist", 1))
+
+        for u, v, data in edges:
+            if u not in spanner.nodes:
+                spanner.add_edge(u, v, dist=data["dist"])
+                continue
+            if v not in spanner.nodes:
+                spanner.add_edge(u, v, dist=data["dist"])
+                continue
+
+            pred, pathLength = nx.dijkstra_predecessor_and_distance(
+                spanner, u, weight="dist", cutoff=k * data["dist"]
+            )
+
+            # If v is in pathLength, we store the path in data['path']
+            if v in pathLength:
+                # reconstruct path from v back to u
+                path = []
+                nxt = v
+                while nxt != u:
+                    path.append(nxt)
+                    nxt = pred[nxt][0]
+                # remove the first node (==v) because we typically want just intermediate
+                path = path[1:]
+                path.reverse()
+
+                data["path"] = path
+            else:
+                spanner.add_edge(u, v, dist=data["dist"])
+
+        return spanner
+
+ ###############################################################################
192
+ # Function to plot only the bundled edges (with optional color gradient)
193
+ ###############################################################################
194
+ def plot_bundled_edges_only(G, edge_gradient=False, node_colors=None, ax=None, **plot_kwargs):
195
+ """
196
+ Plots only the edges whose 'Layer' is 'Bundled' (or user-defined).
197
+ Nodes are plotted for reference in black.
198
+
199
+ Parameters:
200
+ G: NetworkX graph
201
+ title: Plot title
202
+ edge_gradient: If True, color edges with gradient
203
+ node_colors: Dictionary of node colors
204
+ ax: Optional matplotlib axis to plot on. If None, creates new figure.
205
+ **plot_kwargs: Additional keyword arguments passed to LineCollection
206
+ """
207
+ # Use provided axis or create new one
208
+ if ax is None:
209
+ plt.figure(figsize=(8, 8))
210
+ ax = plt.gca()
211
+
212
+ # 1. Extract positions
213
+ pos = {}
214
+ for node, data in G.nodes(data=True):
215
+ x = data.get('X', data.get('x', 0))
216
+ y = data.get('Y', data.get('y', 0))
217
+ pos[node] = (x, y)
218
+
219
+ # 2. Assign or retrieve node colors. If your graph doesn't already have
220
+ # some color-coded attribute, you can define them here.
221
+ # For example, let's just fix them to green for demonstration:
222
+ # node_colors = {}
223
+ # for node in G.nodes():
224
+ # node_colors[node] = (0.0, 0.5, 0.0, 1.0) # RGBA
225
+
226
+ # 3. Build up segments (and possibly per-segment colors) for the edges
227
+ def binomial(n, k):
228
+ """Compute the binomial coefficient (n choose k)."""
229
+ coeff = 1
230
+ for i in range(1, k + 1):
231
+ coeff *= (n - i + 1) / i
232
+ return coeff
233
+
234
+ def approxBezier(points, n=50):
235
+ """
236
+ Compute and return n points along a Bezier curve defined by control points.
237
+ """
238
+ X, Y = [], []
239
+ m = len(points) - 1
240
+ binom_vals = [binomial(m, i) for i in range(m + 1)]
241
+ t_values = np.linspace(0, 1, n)
242
+ for t in t_values:
243
+ pX, pY = 0.0, 0.0
244
+ for i, p in enumerate(points):
245
+ coeff = binom_vals[i] * ((1 - t) ** (m - i)) * (t ** i)
246
+ pX += coeff * p[0]
247
+ pY += coeff * p[1]
248
+ X.append(pX)
249
+ Y.append(pY)
250
+ return np.column_stack([X, Y])
251
+
252
+ edge_segments = []
253
+ edge_colors = []
254
+
255
+ for u, v, data in G.edges(data=True):
256
+ if data.get("Layer", None) != "Bundled":
257
+ # Skip edges not marked as bundled
258
+ continue
259
+
260
+ # (a) Gather the control points
261
+ if "Spline" in data and data["Spline"] is not None:
262
+ spline_obj = data["Spline"]
263
+ control_points = list(spline_obj.points)
264
+ # Add the start/end for completeness
265
+ control_points = [pos[u]] + control_points + [pos[v]]
266
+ else:
267
+ # fallback to a straight line
268
+ control_points = [pos[u], pos[v]]
269
+
270
+ # (b) Approximate a curve from these control points
271
+ # We always subdivide if edge_gradient is True.
272
+ # If not gradient-based, only subdivide for an actual curve.
273
+ do_subdivide = edge_gradient or (len(control_points) > 2)
274
+ if do_subdivide:
275
+ curve_points = approxBezier(control_points, n=50)
276
+ else:
277
+ curve_points = np.array(control_points)
278
+
279
+ # (c) If we're using gradient, we break it into small segments, each with a color
280
+ if edge_gradient:
281
+ c_u = np.array(node_colors[u]) # RGBA for source node
282
+ c_v = np.array(node_colors[v]) # RGBA for target node
283
+ num_pts = len(curve_points)
284
+ for i in range(num_pts - 1):
285
+ p0 = curve_points[i]
286
+ p1 = curve_points[i + 1]
287
+ # fraction along the curve
288
+ t = i / max(1, (num_pts - 2))
289
+ seg_color = (1 - t) * c_u + t * c_v # linear interpolation in RGBA
290
+ edge_segments.append([p0, p1])
291
+ edge_colors.append(seg_color)
292
+ else:
293
+ # Single color for the entire edge
294
+ if len(curve_points) > 1:
295
+ edge_segments.append([curve_points[0], curve_points[-1]])
296
+ edge_colors.append((0.5, 0.0, 0.5, 0.9)) # purple RGBA
297
+
298
+ # 4. Plot
299
+ # Remove the plt.figure() call since we're using the provided axis
300
+
301
+ # Set default values for LineCollection
302
+ lc_kwargs = {
303
+ 'linewidths': 1,
304
+ 'alpha': 0.9
305
+ }
306
+
307
+ # If colors weren't explicitly passed and we calculated edge_colors, use them
308
+ if 'colors' not in plot_kwargs and edge_colors:
309
+ lc_kwargs['colors'] = edge_colors
310
+
311
+ # Update with user-provided kwargs
312
+ lc_kwargs.update(plot_kwargs)
313
+
314
+ # Create the LineCollection with all parameters
315
+ lc = LineCollection(edge_segments, **lc_kwargs)
316
+ ax.add_collection(lc)
317
+
318
+ # The nodes in black
319
+ # node_positions = np.array([pos[n] for n in G.nodes()])
320
+ # ax.scatter(node_positions[:, 0], node_positions[:, 1], color="black", s=20, alpha=0.8)
321
+
322
+ # ax.set_aspect('equal')
323
+ # Remove plt.show() since we want to allow further additions to the plot
324
+
325
+ ###############################################################################
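The gradient branch above is worth isolating: a curve is chopped into short segments, each segment gets an RGBA color linearly interpolated between the two endpoint colors, and the whole bundle is drawn as a single `LineCollection`. A standalone sketch:

```python
# Per-segment gradient coloring with a single LineCollection (toy curve).
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

t = np.linspace(0, 1, 50)
curve = np.column_stack([t, 0.2 * np.sin(2 * np.pi * t)])
c_u = np.array([0.0, 0.4, 0.8, 1.0])   # source RGBA
c_v = np.array([0.9, 0.3, 0.1, 1.0])   # target RGBA

segments, colors = [], []
for i in range(len(curve) - 1):
    frac = i / max(1, len(curve) - 2)  # fraction along the curve
    segments.append([curve[i], curve[i + 1]])
    colors.append((1 - frac) * c_u + frac * c_v)

fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=colors, linewidths=2))
ax.autoscale()
```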
326
+ # Convenience function to run SpannerBundlingNoSP on a graph and plot results
327
+ ###############################################################################
328
+ def run_and_plot_spanner_bundling_no_sp(G, weightFactor=2, distortion=2, edge_gradient=False, node_colors=None, ax=None, **plot_kwargs):
329
+ """
330
+ Create an instance of SpannerBundlingNoSP, run .bundle(), and
331
+ plot only the bundled edges. Pass edge_gradient=True to see
332
+ color-gradient edges.
333
+
334
+ Additional keyword arguments are passed to the LineCollection for edge styling.
335
+ """
336
+ bundler = SpannerBundlingNoSP(G, weightFactor=weightFactor, distortion=distortion)
337
+ bundler.bundle()
338
+ plot_bundled_edges_only(G,
339
+ edge_gradient=edge_gradient,
340
+ node_colors=node_colors,
341
+ ax=ax,
342
+ **plot_kwargs)
343
+
+def run_hammer_bundling(G, accuracy=500, advect_iterations=50, batch_size=20000,
+                        decay=0.01, initial_bandwidth=1.1, iterations=4,
+                        max_segment_length=0.016, min_segment_length=0.008,
+                        tension=1.2):
+    """
+    Run hammer bundling on a NetworkX graph and return the bundled paths.
+    """
+    # Create nodes DataFrame
+    nodes = []
+    node_to_index = {}
+    for i, (node, attr) in enumerate(G.nodes(data=True)):
+        x = attr.get('X', attr.get('x', 0))
+        y = attr.get('Y', attr.get('y', 0))
+        nodes.append({'node': node, 'x': x, 'y': y})
+        node_to_index[node] = i
+    nodes_df = pd.DataFrame(nodes)
+
+    # Create edges DataFrame
+    edges = []
+    for u, v in G.edges():
+        edges.append({'source': node_to_index[u], 'target': node_to_index[v]})
+    edges_df = pd.DataFrame(edges)
+
+    # Apply hammer bundling
+    bundled_paths = hammer_bundle(nodes_df, edges_df,
+                                  accuracy=accuracy,
+                                  advect_iterations=advect_iterations,
+                                  batch_size=batch_size,
+                                  decay=decay,
+                                  initial_bandwidth=initial_bandwidth,
+                                  iterations=iterations,
+                                  max_segment_length=max_segment_length,
+                                  min_segment_length=min_segment_length,
+                                  tension=tension)
+
+    # Convert bundled paths to a format compatible with our plotting function
+    paths = []
+    current_path = []
+    edge_index = 0
+
+    for _, row in bundled_paths.iterrows():
+        if pd.isna(row['x']) or pd.isna(row['y']):
+            if current_path:
+                # Get source and target nodes for this edge
+                source_idx = edges_df.iloc[edge_index]['source']
+                target_idx = edges_df.iloc[edge_index]['target']
+                source_node = nodes_df.iloc[source_idx]['node']
+                target_node = nodes_df.iloc[target_idx]['node']
+
+                paths.append((source_node, target_node, current_path))
+                current_path = []
+                edge_index += 1
+        else:
+            current_path.append((row['x'], row['y']))
+
+    if current_path:  # Handle the last path
+        source_idx = edges_df.iloc[edge_index]['source']
+        target_idx = edges_df.iloc[edge_index]['target']
+        source_node = nodes_df.iloc[source_idx]['node']
+        target_node = nodes_df.iloc[target_idx]['node']
+        paths.append((source_node, target_node, current_path))
+
+    return paths
+
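datashader's `hammer_bundle` takes a nodes frame with `x`/`y` columns and an edges frame with integer `source`/`target` indices, and returns one long frame of path points in which NaN rows separate consecutive edges; that separator convention is what the parsing loop above walks through. A tiny sketch (the NaN-split shortcut is an assumption about that output format):

```python
# Minimal hammer_bundle call; NaN rows in the result delimit per-edge paths.
import numpy as np
import pandas as pd
from datashader.bundling import hammer_bundle

nodes = pd.DataFrame({"x": [0.0, 1.0, 0.0, 1.0], "y": [0.0, 0.0, 1.0, 1.0]})
edges = pd.DataFrame({"source": [0, 0, 0], "target": [1, 2, 3]})

bundled = hammer_bundle(nodes, edges, initial_bandwidth=1.1, iterations=4)
# split the flat point list back into one array per edge at the NaN rows
paths = np.split(bundled[["x", "y"]].values,
                 np.where(bundled["x"].isna())[0])
```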
+def plot_bundled_edges(G, bundled_paths, edge_gradient=False, node_colors=None, ax=None, **plot_kwargs):
+    """
+    Generic plotting function that works with both bundling methods.
+
+    Parameters:
+        G: NetworkX graph
+        bundled_paths: List of (source, target, path_points) tuples
+        edge_gradient: If True, color edges with gradient
+        node_colors: Dictionary of node colors
+        ax: Optional matplotlib axis
+        **plot_kwargs: Additional styling arguments
+    """
+    if ax is None:
+        plt.figure(figsize=(8, 8))
+        ax = plt.gca()
+
+    def approxBezier(points, n=50):
+        """Compute points along a Bezier curve."""
+        points = np.array(points)
+        t = np.linspace(0, 1, n)
+        return np.array([(1-t)*points[:-1] + t*points[1:] for t in t]).reshape(-1, 2)
+
+    edge_segments = []
+    edge_colors = []
+
+    for source, target, path_points in bundled_paths:
+        points = np.array(path_points)
+
+        if edge_gradient:
+            # Create segments with gradient colors
+            c_u = np.array(node_colors[source])
+            c_v = np.array(node_colors[target])
+            num_pts = len(points)
+
+            for i in range(num_pts - 1):
+                p0, p1 = points[i], points[i + 1]
+                t = i / max(1, (num_pts - 2))
+                seg_color = (1 - t) * c_u + t * c_v
+                edge_segments.append([p0, p1])
+                edge_colors.append(seg_color)
+        else:
+            # Single color for the entire path
+            for i in range(len(points) - 1):
+                edge_segments.append([points[i], points[i + 1]])
+                edge_colors.append((0.5, 0.0, 0.5, 0.9))
+
+    # Plot edges
+    lc_kwargs = {'linewidths': 1, 'alpha': 0.9}
+    if edge_colors:
+        lc_kwargs['colors'] = edge_colors
+    lc_kwargs.update(plot_kwargs)
+
+    lc = LineCollection(edge_segments, **lc_kwargs)
+    ax.add_collection(lc)
+    ax.autoscale()
+
+def run_and_plot_bundling(G, method='hammer', edge_gradient=False, node_colors=None, ax=None,
+                          bundling_params=None, **plot_kwargs):
+    """
+    Unified function to run and plot different bundling methods.
+
+    Parameters:
+        G: NetworkX graph
+        method: 'spanner' or 'hammer'
+        bundling_params: dict of parameters specific to the bundling method
+        Other parameters same as plot_bundled_edges
+    """
+    bundling_params = bundling_params or {}
+
+    if method == 'spanner':
+        bundler = SpannerBundlingNoSP(G, **bundling_params)
+        bundler.bundle()
+
+        # Extract bundled paths from SpannerBundling format
+        bundled_paths = []
+        for u, v, data in G.edges(data=True):
+            if data.get("Layer") == "Bundled" and "Spline" in data:
+                spline_points = data["Spline"].points
+                pos_u = (G.nodes[u].get('X', G.nodes[u].get('x', 0)),
+                         G.nodes[u].get('Y', G.nodes[u].get('y', 0)))
+                pos_v = (G.nodes[v].get('X', G.nodes[v].get('x', 0)),
+                         G.nodes[v].get('Y', G.nodes[v].get('y', 0)))
+                path = [pos_u] + list(spline_points) + [pos_v]
+                bundled_paths.append((u, v, path))
+
+    elif method == 'hammer':
+        bundled_paths = run_hammer_bundling(G, **bundling_params)
+    else:
+        raise ValueError(f"Unknown bundling method: {method}")
+
+    plot_bundled_edges(G, bundled_paths, edge_gradient, node_colors, ax, **plot_kwargs)
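End to end, the module is driven through `run_and_plot_bundling`; `draw_citation_graph` in network_utils.py below calls it with `method='hammer'`. A toy example:

```python
# Toy end-to-end run: positioned graph in, bundled edge plot out.
import networkx as nx
import matplotlib.pyplot as plt

G = nx.random_geometric_graph(40, 0.3, seed=2)
for _, data in G.nodes(data=True):
    data["X"], data["Y"] = data["pos"]

node_colors = {n: (0.2, 0.4, 0.8, 1.0) for n in G.nodes()}  # flat blue RGBA
fig, ax = plt.subplots(figsize=(6, 6))
run_and_plot_bundling(G, method="hammer", edge_gradient=True,
                      node_colors=node_colors, ax=ax,
                      linewidths=0.8, alpha=0.5)
ax.set_axis_off()
```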
network_utils.py ADDED
@@ -0,0 +1,67 @@
+import networkx as nx
+import pandas as pd
+import matplotlib.pyplot as plt
+from edgebundling import run_and_plot_bundling
+from matplotlib.colors import Normalize
+
+def create_citation_graph(df):
+    # Create a directed graph
+    G = nx.DiGraph()
+
+    # Add nodes (papers) to the graph with their positions
+    pos = {}  # Dictionary to store positions
+    for idx, row in df.iterrows():
+        G.add_node(
+            row['id'],
+            X=row['x'],
+            Y=row['y'],
+            publication_year=row['publication_year'],
+            color=row['color']
+        )
+        pos[row['id']] = (row['x'], row['y'])
+
+    # Add edges based on references
+    for idx, row in df.iterrows():
+        source_id = row['id']
+        refs = row['referenced_works']
+        if isinstance(refs, list):
+            references = refs
+        elif isinstance(refs, str):
+            references = refs.split(', ')
+        else:
+            references = []
+        for ref in references:
+            if ref in df['id'].values:
+                G.add_edge(source_id, ref)
+
+    G = G.to_undirected()
+    return G
+
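`create_citation_graph` expects the records frame to carry `id`, `x`, `y`, `publication_year`, `color`, and `referenced_works` (either a list of work IDs or the `', '`-joined string written by the CSV export in app.py); only references whose target id is present in the frame become edges. A toy call:

```python
# Toy input for create_citation_graph; W99 is dropped since it's not in 'id'.
import pandas as pd

df = pd.DataFrame({
    "id": ["W1", "W2", "W3"],
    "x": [0.0, 1.0, 0.5],
    "y": [0.0, 0.0, 1.0],
    "publication_year": [2019, 2021, 2023],
    "color": ["#f98e31", "#5e2784", "#536878"],
    "referenced_works": [[], ["W1", "W99"], "W1, W2"],   # list or joined string
})
G = create_citation_graph(df)
# G is undirected with edges W2-W1, W3-W1, W3-W2 and X/Y/color node attributes.
```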
+def draw_citation_graph(G, bundle_edges=False, path=None, min_max_coordinates=None, node_colors=None):
+    pos = {}
+    for node in G.nodes():
+        pos[node] = (G.nodes[node]['X'], G.nodes[node]['Y'])
+    fig, ax = plt.subplots(figsize=(20, 20))
+    plt.margins(0, 0)  # Remove margins
+    if bundle_edges:
+        # Turning color into rgb
+        node_colors = {node: tuple(int(G.nodes[node]['color'].lstrip('#')[i:i+2], 16)/255 for i in (0, 2, 4)) + (1.0,) for node in G.nodes()}
+
+        for u, v in G.edges():
+            x1, y1 = G.nodes[u]['X'], G.nodes[u]['Y']
+            x2, y2 = G.nodes[v]['X'], G.nodes[v]['Y']
+            G[u][v]['dist'] = ((x1 - x2)**2 + (y1 - y2)**2)**0.5
+
+        run_and_plot_bundling(G, method="hammer", ax=ax, edge_gradient=True,
+                              node_colors=node_colors, linewidths=.8, alpha=.5)
+    else:
+        nx.draw(G, pos=pos, node_size=0, with_labels=False, edge_color='#f98e31', alpha=0.3)
+
+    plt.axis('off')
+    plt.gca().set_aspect('equal')
+    if min_max_coordinates is not None:
+        plt.xlim(min_max_coordinates[0], min_max_coordinates[1])
+        plt.ylim(min_max_coordinates[2], min_max_coordinates[3])
+
+    if path is not None:
+        plt.savefig(path, bbox_inches='tight', pad_inches=0, dpi=800, transparent=True)
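The dict comprehension above decodes `'#rrggbb'` strings by hand; `matplotlib.colors.to_rgba` performs the same conversion (and also accepts named colors), so an equivalent sketch, using the same node `color` attribute convention:

```python
# Equivalent hex -> RGBA conversion using matplotlib's own parser.
import matplotlib.colors as mcolors
import networkx as nx

G = nx.Graph()
G.add_node("W1", color="#f98e31")
node_colors = {n: mcolors.to_rgba(G.nodes[n]["color"]) for n in G.nodes()}
# {'W1': (0.976..., 0.557..., 0.192..., 1.0)}
```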
requirements.txt CHANGED
@@ -10,12 +10,12 @@ adapters
 torch
 tqdm
 pyarrow
-datamapplot==0.5.0
+datamapplot==0.5.1
 numba==0.58.1
 umap-learn==0.5.7
 pynndescent==0.5.12
 sentence-transformers==3.3.1
-dask[complete]==2023.3.0
+dask[complete]==2024.4.1
 datashader>=0.16
 opinionated
 IPython