Anupam251272 commited on
Commit
263a24e
·
verified ·
1 Parent(s): 4781f3e

Upload lib%2Fvisualize.py

Browse files
Files changed (1) hide show
  1. lib%2Fvisualize.py +146 -0
lib%2Fvisualize.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import networkx as nx
3
+ import numpy as np
4
+ from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
5
+ Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
6
+ from bokeh.palettes import Spectral4
7
+ from bokeh.plotting import from_networkx
8
+
9
+ def create_graph(entities, relationships):
10
+ G = nx.Graph()
11
+ for entity_id, entity_data in entities.items():
12
+ G.add_node(entity_id, label=f"{entity_data.get('value', 'Unknown')} ({entity_data.get('type', 'Unknown')})")
13
+
14
+ for source, relation, target in relationships:
15
+ G.add_edge(source, target, label=relation)
16
+
17
+ return G
18
+
19
+ def improved_spectral_layout(G, scale=1):
20
+ pos = nx.spectral_layout(G)
21
+ # Add some random noise to prevent overlapping
22
+ pos = {node: (x + np.random.normal(0, 0.1), y + np.random.normal(0, 0.1)) for node, (x, y) in pos.items()}
23
+ # Scale the layout
24
+ pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
25
+ return pos
26
+
27
+ def create_bokeh_plot(G, layout_type='spring'):
28
+ plot = Plot(width=600, height=600,
29
+ x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
30
+ plot.title.text = "Knowledge Graph Interaction"
31
+
32
+ node_hover = HoverTool(tooltips=[("Entity", "@label")])
33
+ edge_hover = HoverTool(tooltips=[("Relation", "@label")])
34
+ plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())
35
+
36
+ # Create layout based on layout_type
37
+ if layout_type == 'spring':
38
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
39
+ elif layout_type == 'fruchterman_reingold':
40
+ pos = nx.fruchterman_reingold_layout(G, k=0.5, iterations=50)
41
+ elif layout_type == 'circular':
42
+ pos = nx.circular_layout(G)
43
+ elif layout_type == 'random':
44
+ pos = nx.random_layout(G)
45
+ elif layout_type == 'spectral':
46
+ pos = improved_spectral_layout(G)
47
+ elif layout_type == 'shell':
48
+ pos = nx.shell_layout(G)
49
+ else:
50
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
51
+
52
+ graph_renderer = from_networkx(G, pos, scale=1, center=(0, 0))
53
+
54
+ graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
55
+ graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
56
+ graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])
57
+
58
+ graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
59
+ graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
60
+ graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)
61
+
62
+ graph_renderer.selection_policy = NodesAndLinkedEdges()
63
+ graph_renderer.inspection_policy = NodesAndLinkedEdges()
64
+
65
+ plot.renderers.append(graph_renderer)
66
+
67
+ # Add node labels
68
+ x, y = zip(*graph_renderer.layout_provider.graph_layout.values())
69
+ node_labels = nx.get_node_attributes(G, 'label')
70
+ source = ColumnDataSource({'x': x, 'y': y, 'label': [node_labels[node] for node in G.nodes()]})
71
+ labels = LabelSet(x='x', y='y', text='label', source=source, background_fill_color='white',
72
+ text_font_size='8pt', background_fill_alpha=0.7)
73
+ plot.renderers.append(labels)
74
+
75
+ # Add edge labels
76
+ edge_x, edge_y, edge_labels = [], [], []
77
+ for (start_node, end_node, label) in G.edges(data='label'):
78
+ start_x, start_y = graph_renderer.layout_provider.graph_layout[start_node]
79
+ end_x, end_y = graph_renderer.layout_provider.graph_layout[end_node]
80
+ edge_x.append((start_x + end_x) / 2)
81
+ edge_y.append((start_y + end_y) / 2)
82
+ edge_labels.append(label)
83
+
84
+ edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_labels})
85
+ edge_labels = LabelSet(x='x', y='y', text='label', source=edge_label_source,
86
+ background_fill_color='white', text_font_size='8pt',
87
+ background_fill_alpha=0.7)
88
+ plot.renderers.append(edge_labels)
89
+
90
+ return plot
91
+
92
+ def create_plotly_plot(G, layout_type='spring'):
93
+ # Create layout based on layout_type
94
+ if layout_type == 'spring':
95
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
96
+ elif layout_type == 'fruchterman_reingold':
97
+ pos = nx.fruchterman_reingold_layout(G, k=0.5, iterations=50)
98
+ elif layout_type == 'circular':
99
+ pos = nx.circular_layout(G)
100
+ elif layout_type == 'random':
101
+ pos = nx.random_layout(G)
102
+ elif layout_type == 'spectral':
103
+ pos = improved_spectral_layout(G)
104
+ elif layout_type == 'shell':
105
+ pos = nx.shell_layout(G)
106
+ else:
107
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
108
+
109
+ edge_trace = go.Scatter(x=[], y=[], line=dict(width=1, color="#888"), hoverinfo="text", mode="lines", text=[])
110
+ node_trace = go.Scatter(x=[], y=[], mode="markers+text", hoverinfo="text",
111
+ marker=dict(showscale=True, colorscale="Viridis", reversescale=True, color=[], size=15,
112
+ colorbar=dict(thickness=15, title="Node Connections", xanchor="left", titleside="right"),
113
+ line_width=2),
114
+ text=[], textposition="top center")
115
+
116
+ edge_labels = []
117
+
118
+ for edge in G.edges():
119
+ x0, y0 = pos[edge[0]]
120
+ x1, y1 = pos[edge[1]]
121
+ edge_trace["x"] += (x0, x1, None)
122
+ edge_trace["y"] += (y0, y1, None)
123
+
124
+ mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
125
+ edge_labels.append(go.Scatter(x=[mid_x], y=[mid_y], mode="text", text=[G.edges[edge]["label"]],
126
+ textposition="middle center", hoverinfo="none", showlegend=False, textfont=dict(size=8)))
127
+
128
+ for node in G.nodes():
129
+ x, y = pos[node]
130
+ node_trace["x"] += (x,)
131
+ node_trace["y"] += (y,)
132
+ node_trace["text"] += (G.nodes[node]["label"],)
133
+ node_trace["marker"]["color"] += (len(list(G.neighbors(node))),)
134
+
135
+ fig = go.Figure(data=[edge_trace, node_trace] + edge_labels,
136
+ layout=go.Layout(title="Knowledge Graph", titlefont_size=16, showlegend=False, hovermode="closest",
137
+ margin=dict(b=20, l=5, r=5, t=40), annotations=[],
138
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
139
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
140
+ width=800, height=600))
141
+
142
+ fig.update_layout(newshape=dict(line_color="#009900"),
143
+ xaxis=dict(scaleanchor="y", scaleratio=1),
144
+ yaxis=dict(scaleanchor="x", scaleratio=1))
145
+
146
+ return fig