|
import plotly.graph_objects as go |
|
import networkx as nx |
|
import numpy as np |
|
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges, |
|
Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource) |
|
from bokeh.palettes import Spectral4 |
|
from bokeh.plotting import from_networkx |
|
|
|
def create_graph(entities, relationships):
    """Build an undirected NetworkX graph from entity and relationship records.

    Args:
        entities: Mapping of entity id -> dict. The optional ``'value'`` and
            ``'type'`` keys are combined into the node's ``label`` attribute as
            ``"value (type)"``; each falls back to ``'Unknown'`` when absent.
        relationships: Iterable of ``(source, relation, target)`` triples. Each
            becomes an edge whose ``label`` attribute is ``relation``.

    Returns:
        An ``nx.Graph``. Because the graph is simple and undirected, a later
        relationship between the same pair of entities (in either direction)
        overwrites the label of an earlier one.
    """
    G = nx.Graph()

    for entity_id, entity_data in entities.items():
        value = entity_data.get('value', 'Unknown')
        entity_type = entity_data.get('type', 'Unknown')
        G.add_node(entity_id, label=f"{value} ({entity_type})")

    for source, relation, target in relationships:
        # add_edge would silently create endpoints that were not declared in
        # ``entities`` without a ``label`` attribute, which makes label lookups
        # in the plotting code fail. Give them a readable fallback label.
        for endpoint in (source, target):
            if endpoint not in G:
                G.add_node(endpoint, label=f"{endpoint} (Unknown)")
        G.add_edge(source, target, label=relation)

    return G
|
|
|
def improved_spectral_layout(G, scale=1, jitter=0.1, seed=None):
    """Compute a spectral layout with Gaussian jitter added to each coordinate.

    The jitter presumably exists to pull apart nodes that the spectral
    embedding places on (or very near) the same point.

    Args:
        G: NetworkX graph to lay out.
        scale: Factor applied to every coordinate *after* jittering, so the
            jitter is scaled as well.
        jitter: Standard deviation of the noise added to each coordinate.
            Defaults to ``0.1`` (the previously hard-coded value); ``0``
            disables jitter.
        seed: Optional seed (or ``numpy.random.Generator``) for reproducible
            jitter, passed to ``np.random.default_rng``. When ``None``, NumPy's
            global legacy RNG is used as before, so ``np.random.seed`` still
            controls the result.

    Returns:
        dict mapping each node to an ``(x, y)`` tuple.
    """
    pos = nx.spectral_layout(G)
    rng = np.random if seed is None else np.random.default_rng(seed)

    jittered = {}
    for node, (x, y) in pos.items():
        # Draw the x-noise before the y-noise for each node, matching the
        # original draw order so seeded global-RNG results are unchanged.
        dx = rng.normal(0, jitter)
        dy = rng.normal(0, jitter)
        jittered[node] = ((x + dx) * scale, (y + dy) * scale)
    return jittered
|
|
|
def create_bokeh_plot(G, layout_type='spring'):
    """Render ``G`` as an interactive Bokeh plot with node and edge labels.

    Args:
        G: NetworkX graph. Nodes and edges are expected to carry a ``label``
            attribute; a missing node label falls back to ``str(node)`` and a
            missing edge label to an empty string.
        layout_type: One of ``'spring'``, ``'fruchterman_reingold'``,
            ``'circular'``, ``'random'``, ``'spectral'`` or ``'shell'``. Any
            other value falls back to the spring layout.

    Returns:
        A ``bokeh.models.Plot`` containing the graph renderer plus two
        ``LabelSet`` renderers (node labels and edge-midpoint labels).
    """
    layouts = {
        'spring': lambda: nx.spring_layout(G, k=0.5, iterations=50),
        'fruchterman_reingold': lambda: nx.fruchterman_reingold_layout(G, k=0.5, iterations=50),
        'circular': lambda: nx.circular_layout(G),
        'random': lambda: nx.random_layout(G),
        'spectral': lambda: improved_spectral_layout(G),
        'shell': lambda: nx.shell_layout(G),
    }
    pos = layouts.get(layout_type, layouts['spring'])()

    plot = Plot(width=600, height=600,
                x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
    plot.title.text = "Knowledge Graph Interaction"

    # ``pos`` is a precomputed dict, so from_networkx uses it verbatim; extra
    # keyword arguments are only forwarded when a layout *function* is given.
    graph_renderer = from_networkx(G, pos)

    graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
    graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()

    plot.renderers.append(graph_renderer)

    # Scope each hover tool to its own sub-renderer. With the default
    # renderers='auto', both tools respond to inspections of the whole graph
    # renderer, so node hovers also produced a "Relation" tooltip (and vice
    # versa).
    node_hover = HoverTool(tooltips=[("Entity", "@label")],
                           renderers=[graph_renderer.node_renderer])
    edge_hover = HoverTool(tooltips=[("Relation", "@label")],
                           renderers=[graph_renderer.edge_renderer])
    plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())

    # Node labels: iterate G.nodes() for x, y and label together so they
    # always line up, and so an empty graph yields empty columns instead of
    # failing to unpack.
    nodes = list(G.nodes())
    node_label_source = ColumnDataSource({
        'x': [float(pos[node][0]) for node in nodes],
        'y': [float(pos[node][1]) for node in nodes],
        'label': [G.nodes[node].get('label', str(node)) for node in nodes],
    })
    node_labels = LabelSet(x='x', y='y', text='label', source=node_label_source,
                           background_fill_color='white', text_font_size='8pt',
                           background_fill_alpha=0.7)
    plot.renderers.append(node_labels)

    # Edge labels are placed at the midpoint of each edge.
    edge_x, edge_y, edge_text = [], [], []
    for start_node, end_node, label in G.edges(data='label', default=''):
        start_x, start_y = pos[start_node]
        end_x, end_y = pos[end_node]
        edge_x.append(float(start_x + end_x) / 2)
        edge_y.append(float(start_y + end_y) / 2)
        edge_text.append(label)

    edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_text})
    edge_labels = LabelSet(x='x', y='y', text='label', source=edge_label_source,
                           background_fill_color='white', text_font_size='8pt',
                           background_fill_alpha=0.7)
    plot.renderers.append(edge_labels)

    return plot
|
|
|
def create_plotly_plot(G, layout_type='spring'):
    """Render ``G`` as a static Plotly figure with node and edge labels.

    Args:
        G: NetworkX graph. Nodes and edges are expected to carry a ``label``
            attribute; a missing node label falls back to ``str(node)`` and a
            missing edge label to an empty string.
        layout_type: One of ``'spring'``, ``'fruchterman_reingold'``,
            ``'circular'``, ``'random'``, ``'spectral'`` or ``'shell'``. Any
            other value falls back to the spring layout.

    Returns:
        A ``plotly.graph_objects.Figure`` with three traces: edge lines,
        nodes (coloured by neighbour count), and edge-midpoint text labels.
    """
    layouts = {
        'spring': lambda: nx.spring_layout(G, k=0.5, iterations=50),
        'fruchterman_reingold': lambda: nx.fruchterman_reingold_layout(G, k=0.5, iterations=50),
        'circular': lambda: nx.circular_layout(G),
        'random': lambda: nx.random_layout(G),
        'spectral': lambda: improved_spectral_layout(G),
        'shell': lambda: nx.shell_layout(G),
    }
    pos = layouts.get(layout_type, layouts['spring'])()

    # Collect coordinates in plain lists and build each trace once, instead
    # of growing trace tuples with += (quadratic) or one trace per edge label.
    edge_x, edge_y = [], []
    mid_x, mid_y, edge_text = [], [], []
    for u, v, label in G.edges(data='label', default=''):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # A None entry breaks the line so each edge is drawn as its own segment.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        mid_x.append((x0 + x1) / 2)
        mid_y.append((y0 + y1) / 2)
        edge_text.append(label)

    nodes = list(G.nodes())
    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    node_text = [G.nodes[node].get('label', str(node)) for node in nodes]
    # Number of distinct neighbours (same as the previous
    # len(list(G.neighbors(node))); G.degree would count self-loops twice).
    node_color = [len(G[node]) for node in nodes]

    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=1, color="#888"),
                            hoverinfo="text", mode="lines", text=[])

    # colorbar.title is an object: the flat ``titleside`` attribute is
    # deprecated and no longer accepted by current Plotly releases.
    node_trace = go.Scatter(
        x=node_x, y=node_y, mode="markers+text", hoverinfo="text",
        text=node_text, textposition="top center",
        marker=dict(showscale=True, colorscale="Viridis", reversescale=True,
                    color=node_color, size=15, line=dict(width=2),
                    colorbar=dict(thickness=15, xanchor="left",
                                  title=dict(text="Node Connections", side="right"))),
    )

    edge_label_trace = go.Scatter(x=mid_x, y=mid_y, mode="text", text=edge_text,
                                  textposition="middle center", hoverinfo="none",
                                  showlegend=False, textfont=dict(size=8))

    fig = go.Figure(
        data=[edge_trace, node_trace, edge_label_trace],
        layout=go.Layout(
            # title.font replaces the removed ``titlefont`` attribute.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False, hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40), annotations=[],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # A single scaleanchor keeps a 1:1 aspect ratio; anchoring both
            # axes to each other is a redundant loop that Plotly ignores.
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                       scaleanchor="x", scaleratio=1),
            width=800, height=600,
            newshape=dict(line=dict(color="#009900")),
        ),
    )

    return fig