Spaces:
Runtime error
Runtime error
File size: 6,812 Bytes
4289090 b26b502 4289090 b26b502 4289090 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import plotly.graph_objects as go
import networkx as nx
import networkx as nx
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx
def create_bokeh_plot(entities, relationships):
    """Build an interactive Bokeh plot of a knowledge graph.

    Args:
        entities: Mapping of entity id to a dict providing ``'value'`` and
            ``'type'`` keys. Each entity becomes a node labelled
            ``"<value> (<type>)"``.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Each triple becomes an undirected edge labelled ``relation``.
            Endpoints that are not present in ``entities`` are still drawn
            and are labelled with ``str(node_id)``.

    Returns:
        A ``bokeh.models.Plot`` holding the graph renderer plus two
        ``LabelSet`` layers: one for nodes and one for edge midpoints.
    """
    # Create a NetworkX graph
    graph = nx.Graph()
    for entity_id, entity_data in entities.items():
        graph.add_node(
            entity_id,
            label=f"{entity_data['value']} ({entity_data['type']})",
        )
    for src, relation, dst in relationships:
        graph.add_edge(src, dst, label=relation)

    # add_edge() silently creates endpoints that were never declared in
    # `entities`; give those nodes a fallback label so the hover tooltip and
    # the label layer below never hit a missing attribute.
    for node_id, attrs in graph.nodes(data=True):
        attrs.setdefault("label", str(node_id))

    plot = Plot(
        width=600,
        height=600,  # Increased size for better visibility
        x_range=Range1d(-1.2, 1.2),
        y_range=Range1d(-1.2, 1.2),
    )
    plot.title.text = "Knowledge Graph Interaction"

    # Use tooltips to show labels on hover.
    # NOTE(review): both tools read the same "@label" column and neither is
    # restricted to a renderer, so with a node-based inspection policy the
    # "Relation" tooltip presumably shows the node label as well - confirm
    # in the browser before relying on it.
    node_hover = HoverTool(tooltips=[("Entity", "@label")])
    edge_hover = HoverTool(tooltips=[("Relation", "@label")])
    plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())

    graph_renderer = from_networkx(
        graph, nx.spring_layout, scale=1, k=0.5, iterations=50, center=(0, 0)
    )

    graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    graph_renderer.edge_renderer.glyph = MultiLine(
        line_color="#000", line_alpha=0.9, line_width=3
    )
    graph_renderer.edge_renderer.selection_glyph = MultiLine(
        line_color=Spectral4[2], line_width=4
    )
    graph_renderer.edge_renderer.hover_glyph = MultiLine(
        line_color=Spectral4[1], line_width=3
    )

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()
    plot.renderers.append(graph_renderer)

    layout = graph_renderer.layout_provider.graph_layout

    # Add node labels. Coordinates and text are looked up per node id so the
    # pairing does not depend on two containers sharing an iteration order,
    # and an empty graph simply yields empty columns instead of failing to
    # unpack.
    node_x, node_y, node_text = [], [], []
    for node_id, label in graph.nodes(data="label"):
        x, y = layout[node_id]
        node_x.append(float(x))
        node_y.append(float(y))
        node_text.append(label)
    node_label_source = ColumnDataSource({"x": node_x, "y": node_y, "label": node_text})
    node_label_set = LabelSet(
        x="x",
        y="y",
        text="label",
        source=node_label_source,
        background_fill_color="white",
        text_font_size="8pt",
        background_fill_alpha=0.7,
    )
    plot.renderers.append(node_label_set)

    # Add edge labels at the midpoint of each edge
    edge_x, edge_y, edge_text = [], [], []
    for start_node, end_node, label in graph.edges(data="label"):
        start_x, start_y = layout[start_node]
        end_x, end_y = layout[end_node]
        edge_x.append(float(start_x + end_x) / 2)
        edge_y.append(float(start_y + end_y) / 2)
        edge_text.append(label)
    edge_label_source = ColumnDataSource({"x": edge_x, "y": edge_y, "label": edge_text})
    edge_label_set = LabelSet(
        x="x",
        y="y",
        text="label",
        source=edge_label_source,
        background_fill_color="white",
        text_font_size="8pt",
        background_fill_alpha=0.7,
    )
    plot.renderers.append(edge_label_set)

    return plot
# def create_bokeh_plot(entities, relationships):
# # Create a NetworkX graph
# G = nx.Graph()
# for entity_id, entity_data in entities.items():
# G.add_node(entity_id, **entity_data)
# for source, relation, target in relationships:
# G.add_edge(source, target)
# # Create a Bokeh plot
# plot = figure(title="Knowledge Graph", x_range=(-1.1,1.1), y_range=(-1.1,1.1),
# width=400, height=400, tools="pan,wheel_zoom,box_zoom,reset")
# # Create graph renderer
# graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0,0))
# # Add graph renderer to plot
# plot.renderers.append(graph_renderer)
# return plot
def create_plotly_plot(entities, relationships):
    """Build a Plotly figure of a knowledge graph.

    Args:
        entities: Mapping of entity id to a dict providing ``'value'`` and
            ``'type'`` keys. Each entity becomes a node whose text is
            ``"<value> (<type>)"``.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Each triple becomes a directed edge whose midpoint is annotated
            with ``relation``. Endpoints that are not present in
            ``entities`` are still drawn and are labelled ``str(node_id)``.

    Returns:
        A ``plotly.graph_objects.Figure`` with three traces, in order: edge
        lines, node markers (coloured by node degree), and edge-label text.
    """
    graph = nx.DiGraph()  # Use DiGraph for directed edges
    for entity_id, entity_data in entities.items():
        graph.add_node(entity_id, **entity_data)
    for source, relation, target in relationships:
        graph.add_edge(source, target, relation=relation)

    pos = nx.spring_layout(graph, k=0.5, iterations=50)  # Adjust layout parameters

    # Collect coordinates in plain lists and build each trace once, instead
    # of growing the trace properties with repeated tuple concatenation.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for start, end, relation in graph.edges(data="relation"):
        x0, y0 = pos[start]
        x1, y1 = pos[end]
        # A None entry breaks the line so consecutive edges are not joined.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        # Midpoint for the edge label
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_color = [], [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # add_edge() creates endpoints that may be absent from `entities`;
        # fall back to the raw id rather than raising KeyError.
        entity = entities.get(node)
        if entity is None:
            node_text.append(str(node))
        else:
            node_text.append(f"{entity['value']} ({entity['type']})")
        # On a DiGraph, neighbors() yields successors only, so a node with
        # only incoming edges would be coloured as having no connections.
        # degree counts incoming and outgoing edges.
        node_color.append(graph.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="text",
        mode="lines",
        text=[],
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_color,
            size=15,
            colorbar=dict(
                thickness=15,
                # Nested title dict replaces the deprecated `titleside`.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line=dict(width=2),
        ),
        text=node_text,
        textposition="top center",
    )

    # All edge labels share the same styling, so a single text trace renders
    # them; one trace per edge only adds overhead.
    edge_label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    fig = go.Figure(
        data=[edge_trace, node_trace, edge_label_trace],
        layout=go.Layout(
            # Nested title dict replaces the deprecated `titlefont_size`.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(
                showgrid=False,
                zeroline=False,
                showticklabels=False,
                # Lock the aspect ratio so zooming keeps the layout
                # undistorted. Anchoring y to x is enough; also anchoring x
                # to y would form a redundant constraint loop.
                scaleanchor="x",
                scaleratio=1,
            ),
            width=800,
            height=600,
            # Styling for shapes drawn with the modebar drawing tools; this
            # does not make graph nodes draggable.
            newshape=dict(line=dict(color="#009900")),
        ),
    )
    return fig