Upload lib%2Fvisualize.py
Browse files- lib%2Fvisualize.py +146 -0
lib%2Fvisualize.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
import networkx as nx
|
3 |
+
import numpy as np
|
4 |
+
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
|
5 |
+
Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
|
6 |
+
from bokeh.palettes import Spectral4
|
7 |
+
from bokeh.plotting import from_networkx
|
8 |
+
|
9 |
+
def create_graph(entities, relationships):
|
10 |
+
G = nx.Graph()
|
11 |
+
for entity_id, entity_data in entities.items():
|
12 |
+
G.add_node(entity_id, label=f"{entity_data.get('value', 'Unknown')} ({entity_data.get('type', 'Unknown')})")
|
13 |
+
|
14 |
+
for source, relation, target in relationships:
|
15 |
+
G.add_edge(source, target, label=relation)
|
16 |
+
|
17 |
+
return G
|
18 |
+
|
19 |
+
def improved_spectral_layout(G, scale=1):
|
20 |
+
pos = nx.spectral_layout(G)
|
21 |
+
# Add some random noise to prevent overlapping
|
22 |
+
pos = {node: (x + np.random.normal(0, 0.1), y + np.random.normal(0, 0.1)) for node, (x, y) in pos.items()}
|
23 |
+
# Scale the layout
|
24 |
+
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
|
25 |
+
return pos
|
26 |
+
|
27 |
+
def create_bokeh_plot(G, layout_type='spring'):
|
28 |
+
plot = Plot(width=600, height=600,
|
29 |
+
x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
|
30 |
+
plot.title.text = "Knowledge Graph Interaction"
|
31 |
+
|
32 |
+
node_hover = HoverTool(tooltips=[("Entity", "@label")])
|
33 |
+
edge_hover = HoverTool(tooltips=[("Relation", "@label")])
|
34 |
+
plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())
|
35 |
+
|
36 |
+
# Create layout based on layout_type
|
37 |
+
if layout_type == 'spring':
|
38 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50)
|
39 |
+
elif layout_type == 'fruchterman_reingold':
|
40 |
+
pos = nx.fruchterman_reingold_layout(G, k=0.5, iterations=50)
|
41 |
+
elif layout_type == 'circular':
|
42 |
+
pos = nx.circular_layout(G)
|
43 |
+
elif layout_type == 'random':
|
44 |
+
pos = nx.random_layout(G)
|
45 |
+
elif layout_type == 'spectral':
|
46 |
+
pos = improved_spectral_layout(G)
|
47 |
+
elif layout_type == 'shell':
|
48 |
+
pos = nx.shell_layout(G)
|
49 |
+
else:
|
50 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50)
|
51 |
+
|
52 |
+
graph_renderer = from_networkx(G, pos, scale=1, center=(0, 0))
|
53 |
+
|
54 |
+
graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
|
55 |
+
graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
|
56 |
+
graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])
|
57 |
+
|
58 |
+
graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
|
59 |
+
graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
|
60 |
+
graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)
|
61 |
+
|
62 |
+
graph_renderer.selection_policy = NodesAndLinkedEdges()
|
63 |
+
graph_renderer.inspection_policy = NodesAndLinkedEdges()
|
64 |
+
|
65 |
+
plot.renderers.append(graph_renderer)
|
66 |
+
|
67 |
+
# Add node labels
|
68 |
+
x, y = zip(*graph_renderer.layout_provider.graph_layout.values())
|
69 |
+
node_labels = nx.get_node_attributes(G, 'label')
|
70 |
+
source = ColumnDataSource({'x': x, 'y': y, 'label': [node_labels[node] for node in G.nodes()]})
|
71 |
+
labels = LabelSet(x='x', y='y', text='label', source=source, background_fill_color='white',
|
72 |
+
text_font_size='8pt', background_fill_alpha=0.7)
|
73 |
+
plot.renderers.append(labels)
|
74 |
+
|
75 |
+
# Add edge labels
|
76 |
+
edge_x, edge_y, edge_labels = [], [], []
|
77 |
+
for (start_node, end_node, label) in G.edges(data='label'):
|
78 |
+
start_x, start_y = graph_renderer.layout_provider.graph_layout[start_node]
|
79 |
+
end_x, end_y = graph_renderer.layout_provider.graph_layout[end_node]
|
80 |
+
edge_x.append((start_x + end_x) / 2)
|
81 |
+
edge_y.append((start_y + end_y) / 2)
|
82 |
+
edge_labels.append(label)
|
83 |
+
|
84 |
+
edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_labels})
|
85 |
+
edge_labels = LabelSet(x='x', y='y', text='label', source=edge_label_source,
|
86 |
+
background_fill_color='white', text_font_size='8pt',
|
87 |
+
background_fill_alpha=0.7)
|
88 |
+
plot.renderers.append(edge_labels)
|
89 |
+
|
90 |
+
return plot
|
91 |
+
|
92 |
+
def create_plotly_plot(G, layout_type='spring'):
|
93 |
+
# Create layout based on layout_type
|
94 |
+
if layout_type == 'spring':
|
95 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50)
|
96 |
+
elif layout_type == 'fruchterman_reingold':
|
97 |
+
pos = nx.fruchterman_reingold_layout(G, k=0.5, iterations=50)
|
98 |
+
elif layout_type == 'circular':
|
99 |
+
pos = nx.circular_layout(G)
|
100 |
+
elif layout_type == 'random':
|
101 |
+
pos = nx.random_layout(G)
|
102 |
+
elif layout_type == 'spectral':
|
103 |
+
pos = improved_spectral_layout(G)
|
104 |
+
elif layout_type == 'shell':
|
105 |
+
pos = nx.shell_layout(G)
|
106 |
+
else:
|
107 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50)
|
108 |
+
|
109 |
+
edge_trace = go.Scatter(x=[], y=[], line=dict(width=1, color="#888"), hoverinfo="text", mode="lines", text=[])
|
110 |
+
node_trace = go.Scatter(x=[], y=[], mode="markers+text", hoverinfo="text",
|
111 |
+
marker=dict(showscale=True, colorscale="Viridis", reversescale=True, color=[], size=15,
|
112 |
+
colorbar=dict(thickness=15, title="Node Connections", xanchor="left", titleside="right"),
|
113 |
+
line_width=2),
|
114 |
+
text=[], textposition="top center")
|
115 |
+
|
116 |
+
edge_labels = []
|
117 |
+
|
118 |
+
for edge in G.edges():
|
119 |
+
x0, y0 = pos[edge[0]]
|
120 |
+
x1, y1 = pos[edge[1]]
|
121 |
+
edge_trace["x"] += (x0, x1, None)
|
122 |
+
edge_trace["y"] += (y0, y1, None)
|
123 |
+
|
124 |
+
mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
|
125 |
+
edge_labels.append(go.Scatter(x=[mid_x], y=[mid_y], mode="text", text=[G.edges[edge]["label"]],
|
126 |
+
textposition="middle center", hoverinfo="none", showlegend=False, textfont=dict(size=8)))
|
127 |
+
|
128 |
+
for node in G.nodes():
|
129 |
+
x, y = pos[node]
|
130 |
+
node_trace["x"] += (x,)
|
131 |
+
node_trace["y"] += (y,)
|
132 |
+
node_trace["text"] += (G.nodes[node]["label"],)
|
133 |
+
node_trace["marker"]["color"] += (len(list(G.neighbors(node))),)
|
134 |
+
|
135 |
+
fig = go.Figure(data=[edge_trace, node_trace] + edge_labels,
|
136 |
+
layout=go.Layout(title="Knowledge Graph", titlefont_size=16, showlegend=False, hovermode="closest",
|
137 |
+
margin=dict(b=20, l=5, r=5, t=40), annotations=[],
|
138 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
139 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
140 |
+
width=800, height=600))
|
141 |
+
|
142 |
+
fig.update_layout(newshape=dict(line_color="#009900"),
|
143 |
+
xaxis=dict(scaleanchor="y", scaleratio=1),
|
144 |
+
yaxis=dict(scaleanchor="x", scaleratio=1))
|
145 |
+
|
146 |
+
return fig
|