Remsky's picture
Swapped to bokeh, changed up samples
b26b502
raw
history blame
6.81 kB
import plotly.graph_objects as go
import networkx as nx
import networkx as nx
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx
def create_bokeh_plot(entities, relationships, *, seed=None):
    """Render a knowledge graph as an interactive Bokeh plot.

    Each entity becomes a node labelled ``"<value> (<type>)"`` and each
    relationship an undirected edge labelled with its relation name. Both
    kinds of label are drawn as text on the plot and shown in hover tooltips.

    Args:
        entities: Mapping of entity id -> entity record; every record must
            provide ``'value'`` and ``'type'`` keys.
        relationships: Iterable of ``(source_id, relation, target_id)``
            triples. Ids that are missing from ``entities`` still get a node,
            labelled with the id itself.
        seed: Optional seed forwarded to ``nx.spring_layout`` for a
            reproducible layout. ``None`` (the default) keeps the layout
            random, as before.

    Returns:
        A ``bokeh.models.Plot`` containing the graph plus node/edge labels.
    """
    graph = nx.Graph()
    for entity_id, entity_data in entities.items():
        graph.add_node(entity_id, label=f"{entity_data['value']} ({entity_data['type']})")
    # NOTE: nx.Graph is undirected and keeps a single edge per node pair, so a
    # later relationship between the same two entities replaces the earlier label.
    for src, relation, dst in relationships:
        graph.add_edge(src, dst, label=relation)
    # Endpoints that only appear in ``relationships`` have no entity record;
    # fall back to their id so tooltips and text labels are never missing.
    for node, attrs in graph.nodes(data=True):
        attrs.setdefault("label", str(node))

    plot = Plot(width=600, height=600,  # Increased size for better visibility
                x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
    plot.title.text = "Knowledge Graph Interaction"

    # from_networkx forwards the extra keyword arguments to nx.spring_layout.
    graph_renderer = from_networkx(graph, nx.spring_layout, scale=1, k=0.5,
                                   iterations=50, center=(0, 0), seed=seed)

    graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
    graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()
    plot.renderers.append(graph_renderer)

    # Give each hover tool its own target. Left on the default renderers, both
    # tools attach to the whole graph renderer, whose inspection policy
    # hit-tests nodes, so "Relation" would display the node's label.
    node_hover = HoverTool(tooltips=[("Entity", "@label")],
                           renderers=[graph_renderer])
    edge_hover = HoverTool(tooltips=[("Relation", "@label")],
                           renderers=[graph_renderer.edge_renderer],
                           line_policy="interp")
    plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())

    layout = graph_renderer.layout_provider.graph_layout

    def text_labels(label_source):
        # Shared styling for node and edge text labels.
        return LabelSet(x='x', y='y', text='label', source=label_source,
                        background_fill_color='white', text_font_size='8pt',
                        background_fill_alpha=0.7)

    # Node labels: look each position up by node id rather than zipping
    # ``layout.values()`` against ``graph.nodes()``. The zip relied on both
    # sharing one order and raised ValueError for an empty graph.
    node_items = list(graph.nodes(data="label"))
    node_label_source = ColumnDataSource({
        'x': [layout[node][0] for node, _ in node_items],
        'y': [layout[node][1] for node, _ in node_items],
        'label': [label for _, label in node_items],
    })
    plot.renderers.append(text_labels(node_label_source))

    # Edge labels sit at the midpoint of the straight segment between endpoints.
    edge_items = list(graph.edges(data="label"))
    edge_label_source = ColumnDataSource({
        'x': [(layout[u][0] + layout[v][0]) / 2 for u, v, _ in edge_items],
        'y': [(layout[u][1] + layout[v][1]) / 2 for u, v, _ in edge_items],
        'label': [label for _, _, label in edge_items],
    })
    plot.renderers.append(text_labels(edge_label_source))

    return plot
# def create_bokeh_plot(entities, relationships):
# # Create a NetworkX graph
# G = nx.Graph()
# for entity_id, entity_data in entities.items():
# G.add_node(entity_id, **entity_data)
# for source, relation, target in relationships:
# G.add_edge(source, target)
# # Create a Bokeh plot
# plot = figure(title="Knowledge Graph", x_range=(-1.1,1.1), y_range=(-1.1,1.1),
# width=400, height=400, tools="pan,wheel_zoom,box_zoom,reset")
# # Create graph renderer
# graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0,0))
# # Add graph renderer to plot
# plot.renderers.append(graph_renderer)
# return plot
def create_plotly_plot(entities, relationships, *, seed=None):
    """Render a knowledge graph as a Plotly figure.

    Entities become markers labelled ``"<value> (<type>)"`` and coloured by
    the number of edges (incoming plus outgoing) that touch them.
    Relationships become grey line segments with the relation name written
    at each segment's midpoint.

    Args:
        entities: Mapping of entity id -> entity record; every record must
            provide ``'value'`` and ``'type'`` keys (all keys are also stored
            as node attributes).
        relationships: Iterable of ``(source_id, relation, target_id)``
            triples. Ids missing from ``entities`` are still drawn, labelled
            with the id itself.
        seed: Optional seed forwarded to ``nx.spring_layout`` for a
            reproducible layout. ``None`` (the default) keeps the layout
            random, as before.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    graph = nx.DiGraph()  # directed, so (a, r, b) and (b, r2, a) stay distinct edges
    for entity_id, entity_data in entities.items():
        graph.add_node(entity_id, **entity_data)
    for src, relation, dst in relationships:
        graph.add_edge(src, dst, relation=relation)

    pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=seed)  # Adjust layout parameters

    # Collect coordinates in plain lists and build every trace once. Growing a
    # trace property with tuple ``+=`` copies the whole array on each step,
    # which is quadratic in the number of edges/nodes.
    edge_x, edge_y, edge_label_traces = [], [], []
    for u, v, relation in graph.edges(data="relation"):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # The trailing None breaks the polyline so each edge is its own segment.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        edge_label_traces.append(
            go.Scatter(
                x=[(x0 + x1) / 2],
                y=[(y0 + y1) / 2],
                mode="text",
                text=[relation],
                textposition="middle center",
                hoverinfo="none",
                showlegend=False,
                textfont=dict(size=8),
            )
        )

    node_x, node_y, node_text, node_degree = [], [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        if node in entities:
            entity = entities[node]
            node_text.append(f"{entity['value']} ({entity['type']})")
        else:
            # Endpoint referenced only by a relationship: no entity record.
            node_text.append(str(node))
        # degree() counts incoming and outgoing edges. On a DiGraph,
        # neighbors() yields successors only, so pure targets looked unconnected.
        node_degree.append(graph.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="text",
        mode="lines",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_degree,
            size=15,
            colorbar=dict(
                thickness=15,
                # Nested title form; the old ``titleside`` shorthand is
                # deprecated and no longer accepted by recent Plotly.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line_width=2,
        ),
        text=node_text,
        textposition="top center",
    )

    fig = go.Figure(
        data=[edge_trace, node_trace, *edge_label_traces],
        layout=go.Layout(
            # Nested title form replaces the deprecated ``titlefont_size``.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            # Lock a 1:1 aspect ratio. Anchor only one axis: anchoring x to y
            # and y to x forms a constraint loop that Plotly rejects.
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                       scaleanchor="y", scaleratio=1),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            width=800,
            height=600,
            # Line colour for shapes drawn with Plotly's shape-drawing tools.
            newshape=dict(line_color="#009900"),
        ),
    )
    return fig