Anupam251272 commited on
Commit
41cb4a0
·
verified ·
1 Parent(s): 263a24e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -123
app.py CHANGED
@@ -1,129 +1,282 @@
1
- import random
2
-
3
  import gradio as gr
4
  import networkx as nx
 
 
5
 
6
- from lib.graph_extract import triplextract, parse_triples
7
- from lib.visualize import create_graph, create_bokeh_plot, create_plotly_plot
8
- from lib.samples import snippets
9
 
10
- WORD_LIMIT = 300
11
 
12
- def process_text(text, entity_types, predicates, layout_type, visualization_type):
13
- if not text:
14
- return None, None, "Please enter some text."
 
15
 
16
- words = text.split()
17
- if len(words) > WORD_LIMIT:
18
- return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
19
-
20
- entity_types = [et.strip() for et in entity_types.split(",") if et.strip()]
21
- predicates = [p.strip() for p in predicates.split(",") if p.strip()]
22
-
23
- if not entity_types:
24
- return None, None, "Please enter at least one entity type."
25
- if not predicates:
26
- return None, None, "Please enter at least one predicate."
27
 
 
 
 
 
28
  try:
29
- prediction = triplextract(text, entity_types, predicates)
30
- if prediction.startswith("Error"):
31
- return None, None, prediction
32
-
33
- entities, relationships = parse_triples(prediction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- if not entities and not relationships:
36
- return None, None, "No entities or relationships found. Try different text or check your input."
 
 
37
 
38
- G = create_graph(entities, relationships)
39
 
40
- if visualization_type == 'Bokeh':
41
- fig = create_bokeh_plot(G, layout_type)
42
- else:
43
- fig = create_plotly_plot(G, layout_type)
44
 
45
- output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
46
- return G, fig, output_text
47
- except Exception as e:
48
- print(f"Error in process_text: {str(e)}")
49
- return None, None, f"An error occurred: {str(e)}"
50
 
51
- def update_graph(G, layout_type, visualization_type):
52
- if G is None:
53
- return None, "Please process text first."
54
-
55
  try:
56
- if visualization_type == 'Bokeh':
57
- fig = create_bokeh_plot(G, layout_type)
58
- else:
59
- fig = create_plotly_plot(G, layout_type)
60
- return fig, ""
 
 
 
 
 
 
 
61
  except Exception as e:
62
- print(f"Error in update_graph: {e}")
63
- return None, f"An error occurred while updating the graph: {str(e)}"
 
64
 
65
- def update_inputs(sample_name):
66
- sample = snippets[sample_name]
67
- return sample.text_input, sample.entity_types, sample.predicates
68
 
69
- with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
70
- gr.Markdown("# Knowledge Graph Extractor")
71
-
72
- default_sample_name = random.choice(list(snippets.keys()))
73
- default_sample = snippets[default_sample_name]
74
-
75
- with gr.Row():
76
- with gr.Column(scale=1):
77
- sample_dropdown = gr.Dropdown(choices=list(snippets.keys()), label="Select Sample", value=default_sample_name)
78
- input_text = gr.Textbox(label="Input Text", lines=5, value=default_sample.text_input)
79
- entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
80
- predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
81
- layout_type = gr.Dropdown(choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
82
- label="Layout Type", value='spring')
83
- visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
84
- process_btn = gr.Button("Process Text")
85
- with gr.Column(scale=2):
86
- output_graph = gr.Plot(label="Knowledge Graph")
87
- error_message = gr.Textbox(label="Textual Output")
88
 
89
- graph_state = gr.State(None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- def process_and_update(text, entity_types, predicates, layout_type, visualization_type):
92
- G, fig, output = process_text(text, entity_types, predicates, layout_type, visualization_type)
93
- return G, fig, output
94
 
95
- def update_graph_wrapper(G, layout_type, visualization_type):
96
- if G is not None:
97
- fig, _ = update_graph(G, layout_type, visualization_type)
98
- return fig
99
 
100
- sample_dropdown.change(update_inputs, inputs=[sample_dropdown], outputs=[input_text, entity_types, predicates])
 
 
 
 
 
 
101
 
102
- process_btn.click(process_and_update,
103
- inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
104
- outputs=[graph_state, output_graph, error_message])
105
-
106
- layout_type.change(update_graph_wrapper,
107
- inputs=[graph_state, layout_type, visualization_type],
108
- outputs=[output_graph])
109
-
110
- visualization_type.change(update_graph_wrapper,
111
- inputs=[graph_state, layout_type, visualization_type],
112
- outputs=[output_graph])
113
 
114
- if __name__ == "__main__":
115
- demo.launch(share=True)import random
116
 
117
- import gradio as gr
118
- import networkx as nx
 
119
 
120
- from lib.graph_extract import triplextract, parse_triples
121
- from lib.visualize import create_graph, create_bokeh_plot, create_plotly_plot
122
- from lib.samples import snippets
 
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  WORD_LIMIT = 300
125
 
126
- def process_text(text, entity_types, predicates, layout_type, visualization_type):
127
  if not text:
128
  return None, None, "Please enter some text."
129
 
@@ -131,17 +284,17 @@ def process_text(text, entity_types, predicates, layout_type, visualization_type
131
  if len(words) > WORD_LIMIT:
132
  return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
133
 
134
- entity_types = [et.strip() for et in entity_types.split(",") if et.strip()]
135
- predicates = [p.strip() for p in predicates.split(",") if p.strip()]
136
 
137
- if not entity_types:
138
  return None, None, "Please enter at least one entity type."
139
- if not predicates:
140
  return None, None, "Please enter at least one predicate."
141
 
142
  try:
143
- prediction = triplextract(text, entity_types, predicates)
144
- if prediction.startswith("Error"):
145
  return None, None, prediction
146
 
147
  entities, relationships = parse_triples(prediction)
@@ -159,13 +312,14 @@ def process_text(text, entity_types, predicates, layout_type, visualization_type
159
  output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
160
  return G, fig, output_text
161
  except Exception as e:
162
- print(f"Error in process_text: {str(e)}")
 
163
  return None, None, f"An error occurred: {str(e)}"
164
 
165
- def update_graph(G, layout_type, visualization_type):
166
  if G is None:
167
  return None, "Please process text first."
168
-
169
  try:
170
  if visualization_type == 'Bokeh':
171
  fig = create_bokeh_plot(G, layout_type)
@@ -173,27 +327,31 @@ def update_graph(G, layout_type, visualization_type):
173
  fig = create_plotly_plot(G, layout_type)
174
  return fig, ""
175
  except Exception as e:
176
- print(f"Error in update_graph: {e}")
 
177
  return None, f"An error occurred while updating the graph: {str(e)}"
178
 
179
- def update_inputs(sample_name):
180
  sample = snippets[sample_name]
181
  return sample.text_input, sample.entity_types, sample.predicates
182
 
183
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
184
  gr.Markdown("# Knowledge Graph Extractor")
185
-
186
- default_sample_name = random.choice(list(snippets.keys()))
187
- default_sample = snippets[default_sample_name]
188
-
 
 
189
  with gr.Row():
190
  with gr.Column(scale=1):
191
- sample_dropdown = gr.Dropdown(choices=list(snippets.keys()), label="Select Sample", value=default_sample_name)
192
- input_text = gr.Textbox(label="Input Text", lines=5, value=default_sample.text_input)
193
- entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
194
- predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
195
- layout_type = gr.Dropdown(choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
196
- label="Layout Type", value='spring')
 
197
  visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
198
  process_btn = gr.Button("Process Text")
199
  with gr.Column(scale=2):
@@ -202,11 +360,11 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
202
 
203
  graph_state = gr.State(None)
204
 
205
- def process_and_update(text, entity_types, predicates, layout_type, visualization_type):
206
  G, fig, output = process_text(text, entity_types, predicates, layout_type, visualization_type)
207
  return G, fig, output
208
 
209
- def update_graph_wrapper(G, layout_type, visualization_type):
210
  if G is not None:
211
  fig, _ = update_graph(G, layout_type, visualization_type)
212
  return fig
@@ -216,11 +374,11 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
216
  process_btn.click(process_and_update,
217
  inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
218
  outputs=[graph_state, output_graph, error_message])
219
-
220
  layout_type.change(update_graph_wrapper,
221
  inputs=[graph_state, layout_type, visualization_type],
222
  outputs=[output_graph])
223
-
224
  visualization_type.change(update_graph_wrapper,
225
  inputs=[graph_state, layout_type, visualization_type],
226
  outputs=[output_graph])
 
 
 
1
  import gradio as gr
2
  import networkx as nx
3
+ import random
4
+ import logging
5
 
6
+ # Configure logging
7
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
8
 
9
+ # --- lib directory code (graph_extract.py, visualize.py, samples.py would go here) ---
10
 
11
+ # Placeholder for your NLP pipeline (replace with your actual implementation)
12
+ def triplextract(text: str, entity_types: list[str], predicates: list[str]) -> str:
13
+ """
14
+ Extracts triples (subject, predicate, object) from the given text.
15
 
16
+ Args:
17
+ text: The input text.
18
+ entity_types: A list of entity types to consider.
19
+ predicates: A list of predicates to consider.
 
 
 
 
 
 
 
20
 
21
+ Returns:
22
+ A string representation of the extracted triples, or an error message if extraction fails.
23
+ """
24
+ logging.debug(f"triplextract called with text: {text}, entity_types: {entity_types}, predicates: {predicates}")
25
  try:
26
+ # Replace this with your actual NLP pipeline logic
27
+ # This is a placeholder for demonstration purposes
28
+ # Example: "Alice knows Bob" -> ("Alice", "knows", "Bob")
29
+ if "Alice knows Bob" in text:
30
+ return "[('Alice', 'knows', 'Bob')]" # Example triple
31
+ elif "The cat sat on the mat" in text:
32
+ return "[('cat', 'sat on', 'mat')]"
33
+ else:
34
+ return "[]" # No triples found (important to return an empty list as a string)
35
+ except Exception as e:
36
+ error_message = f"Error in triplextract: {str(e)}"
37
+ logging.exception(error_message) # Log the full exception with traceback
38
+ return f"Error: {error_message}" # Return an error message as a string
39
+
40
+ def parse_triples(triples_str: str) -> tuple[list[str], list[tuple[str, str, str]]]:
41
+ """
42
+ Parses the string representation of triples into lists of entities and relationships.
43
+
44
+ Args:
45
+ triples_str: A string representation of the triples (e.g., "[('Alice', 'knows', 'Bob')]").
46
+
47
+ Returns:
48
+ A tuple containing:
49
+ - A list of unique entities (strings).
50
+ - A list of relationships (tuples of (subject, predicate, object)).
51
+ """
52
+ logging.debug(f"parse_triples called with triples_str: {triples_str}")
53
+ try:
54
+ # Replace this with your actual parsing logic based on triplextract's output
55
+ # This is a placeholder for demonstration purposes
56
+ import ast
57
+ triples_list = ast.literal_eval(triples_str) # Safely evaluate the string as a list
58
+
59
+ entities = set()
60
+ relationships = []
61
+ for triple in triples_list:
62
+ subject, predicate, object_ = triple # Unpack the triple
63
+ entities.add(subject)
64
+ entities.add(object_)
65
+ relationships.append((subject, predicate, object_))
66
+
67
+ return list(entities), relationships
68
+ except (SyntaxError, ValueError) as e:
69
+ error_message = f"Error in parse_triples: Invalid triples string format: {str(e)}"
70
+ logging.error(error_message)
71
+ return [], [] # Return empty lists to prevent further errors
72
 
73
+ except Exception as e:
74
+ error_message = f"Error in parse_triples: {str(e)}"
75
+ logging.exception(error_message)
76
+ return [], [] # Return empty lists in case of any error
77
 
78
+ import networkx as nx
79
 
80
+ def create_graph(entities: list[str], relationships: list[tuple[str, str, str]]) -> nx.Graph:
81
+ """
82
+ Creates a networkx graph from the given entities and relationships.
 
83
 
84
+ Args:
85
+ entities: A list of entity names (strings).
86
+ relationships: A list of tuples representing relationships (subject, predicate, object).
 
 
87
 
88
+ Returns:
89
+ A networkx Graph object.
90
+ """
91
+ logging.debug(f"create_graph called with entities: {entities}, relationships: {relationships}")
92
  try:
93
+ G = nx.Graph()
94
+ G.add_nodes_from(entities)
95
+ G.add_edges_from([(subject, object_) for subject, _, object_ in relationships]) # Add edges
96
+
97
+ # Add edge attributes to store predicates
98
+ for subject, predicate, object_ in relationships:
99
+ if G.has_edge(subject, object_):
100
+ G[subject][object_]['predicate'] = predicate # Store the predicate as an attribute
101
+ else:
102
+ logging.warning(f"Edge ({subject}, {object_}) not found in the graph.")
103
+
104
+ return G
105
  except Exception as e:
106
+ error_message = f"Error in create_graph: {str(e)}"
107
+ logging.exception(error_message)
108
+ return nx.Graph() # Return an empty graph
109
 
 
 
 
110
 
111
+ # visualize.py (Implement your Bokeh and Plotly visualizations here)
112
+ from bokeh.plotting import figure, show
113
+ from bokeh.models import Circle, MultiLine, EdgesAndLinkedNodes, NodesAndLinkedEdges, StaticLayoutProvider
114
+ from bokeh.palettes import Category20
115
+ from bokeh.io import output_notebook
116
+ import networkx as nx
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ def create_bokeh_plot(graph: nx.Graph, layout_type: str):
119
+ """Creates a Bokeh plot of the given networkx graph."""
120
+ try:
121
+ if layout_type == 'spring':
122
+ pos = nx.spring_layout(graph, seed=42)
123
+ elif layout_type == 'fruchterman_reingold':
124
+ pos = nx.fruchterman_reingold_layout(graph, seed=42)
125
+ elif layout_type == 'circular':
126
+ pos = nx.circular_layout(graph)
127
+ elif layout_type == 'random':
128
+ pos = nx.random_layout(graph, seed=42)
129
+ elif layout_type == 'spectral':
130
+ pos = nx.spectral_layout(graph)
131
+ elif layout_type == 'shell':
132
+ pos = nx.shell_layout(graph)
133
+ else:
134
+ pos = nx.spring_layout(graph, seed=42) # Default layout
135
 
136
+ node_indices = list(graph.nodes())
 
 
137
 
138
+ x, y = zip(*[pos[i] for i in node_indices])
 
 
 
139
 
140
+ node_data = dict(index=node_indices, x=x, y=y, name=node_indices, size=[15]*len(node_indices))
141
+ edge_data = dict(start=[pos[u][0] for u, v in graph.edges()],
142
+ end=[pos[v][0] for u, v in graph.edges()],
143
+ xstart=[pos[u][0] for u, v in graph.edges()],
144
+ ystart=[pos[u][1] for u, v in graph.edges()],
145
+ xend=[pos[v][0] for u, v in graph.edges()],
146
+ yend=[pos[v][1] for u, v in graph.edges()])
147
 
148
+ plot = figure(title="Knowledge Graph", width=600, height=600,
149
+ tools="pan,wheel_zoom,box_zoom,reset,save",
150
+ x_range=(min(x)-0.1, max(x)+0.1), y_range=(min(y)-0.1, max(y)+0.1))
 
 
 
 
 
 
 
 
151
 
152
+ plot.scatter("x", "y", size="size", source=node_data, name="nodes")
 
153
 
154
+ plot.multi_line(xs=[[edge_data['xstart'][i], edge_data['xend'][i]] for i in range(len(graph.edges()))],
155
+ ys=[[edge_data['ystart'][i], edge_data['yend'][i]] for i in range(len(graph.edges()))],
156
+ color="navy", alpha=0.5)
157
 
158
+ return plot
159
+ except Exception as e:
160
+ error_message = f"Error creating Bokeh plot: {str(e)}"
161
+ logging.exception(error_message)
162
+ return None # Or return a placeholder plot
163
 
164
+ def create_plotly_plot(graph: nx.Graph, layout_type: str):
165
+ """Creates a Plotly plot of the given networkx graph."""
166
+ try:
167
+ import plotly.graph_objects as go
168
+
169
+ if layout_type == 'spring':
170
+ pos = nx.spring_layout(graph, seed=42)
171
+ elif layout_type == 'fruchterman_reingold':
172
+ pos = nx.fruchterman_reingold_layout(graph, seed=42)
173
+ elif layout_type == 'circular':
174
+ pos = nx.circular_layout(graph)
175
+ elif layout_type == 'random':
176
+ pos = nx.random_layout(graph, seed=42)
177
+ elif layout_type == 'spectral':
178
+ pos = nx.spectral_layout(graph)
179
+ elif layout_type == 'shell':
180
+ pos = nx.shell_layout(graph)
181
+ else:
182
+ pos = nx.spring_layout(graph, seed=42) # Default layout
183
+
184
+ edge_x = []
185
+ edge_y = []
186
+ for edge in graph.edges():
187
+ x0, y0 = pos[edge[0]]
188
+ x1, y1 = pos[edge[1]]
189
+ edge_x.append(x0)
190
+ edge_y.append(y0)
191
+ edge_x.append(x1)
192
+ edge_y.append(y1)
193
+ edge_x.append(None)
194
+ edge_y.append(None)
195
+
196
+ edge_trace = go.Scatter(
197
+ x=edge_x, y=edge_y,
198
+ line=dict(width=0.5, color='#888'),
199
+ hoverinfo='none',
200
+ mode='lines')
201
+
202
+ node_x = []
203
+ node_y = []
204
+ for node in graph.nodes():
205
+ x, y = pos[node]
206
+ node_x.append(x)
207
+ node_y.append(y)
208
+
209
+ node_trace = go.Scatter(
210
+ x=node_x, y=node_y,
211
+ mode='markers',
212
+ hoverinfo='text',
213
+ marker=dict(
214
+ showscale=False,
215
+ colorscale='YlGnBu',
216
+ reversescale=True,
217
+ color=[],
218
+ size=10,
219
+ line_width=2))
220
+
221
+ node_adjacencies = []
222
+ node_text = []
223
+ for node, adjacencies in enumerate(graph.adjacency()):
224
+ node_adjacencies.append(len(adjacencies[1]))
225
+ node_text.append(f"{adjacencies[0]} (# of connections: {len(adjacencies[1])})")
226
+
227
+ node_trace.marker.color = node_adjacencies
228
+ node_trace.text = node_text
229
+
230
+ fig = go.Figure(data=[edge_trace, node_trace],
231
+ layout=go.Layout(
232
+ title='Knowledge Graph',
233
+ titlefont_size=16,
234
+ showlegend=False,
235
+ hovermode='closest',
236
+ margin=dict(b=20,l=5,r=5,t=40),
237
+ annotations=[dict(
238
+ text="Replace with your attribution text",
239
+ showarrow=False,
240
+ xref="paper", yref="paper",
241
+ x=0.005, y=-0.002)],
242
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
243
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
244
+ )
245
+
246
+ return fig
247
+
248
+ except Exception as e:
249
+ error_message = f"Error creating Plotly plot: {str(e)}"
250
+ logging.exception(error_message)
251
+ return None # Or return a placeholder plot
252
+
253
+ # samples.py
254
+ from dataclasses import dataclass
255
+
256
+ @dataclass
257
+ class Sample:
258
+ text_input: str
259
+ entity_types: str
260
+ predicates: str
261
+
262
+ snippets = {
263
+ "Sample 1": Sample(
264
+ text_input="Alice knows Bob.",
265
+ entity_types="Person",
266
+ predicates="knows"
267
+ ),
268
+ "Sample 2": Sample(
269
+ text_input="The cat sat on the mat.",
270
+ entity_types="Animal, Object",
271
+ predicates="sat on"
272
+ )
273
+ }
274
+
275
+
276
+ # --- Gradio Interface Code ---
277
  WORD_LIMIT = 300
278
 
279
+ def process_text(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
280
  if not text:
281
  return None, None, "Please enter some text."
282
 
 
284
  if len(words) > WORD_LIMIT:
285
  return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
286
 
287
+ entity_types_list = [et.strip() for et in entity_types.split(",") if et.strip()]
288
+ predicates_list = [p.strip() for p in predicates.split(",") if p.strip()]
289
 
290
+ if not entity_types_list:
291
  return None, None, "Please enter at least one entity type."
292
+ if not predicates_list:
293
  return None, None, "Please enter at least one predicate."
294
 
295
  try:
296
+ prediction = triplextract(text, entity_types_list, predicates_list) # Pass lists, not strings
297
+ if prediction and prediction.startswith("Error"): # Check for errors
298
  return None, None, prediction
299
 
300
  entities, relationships = parse_triples(prediction)
 
312
  output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
313
  return G, fig, output_text
314
  except Exception as e:
315
+ error_message = f"Error in process_text: {str(e)}"
316
+ logging.exception(error_message)
317
  return None, None, f"An error occurred: {str(e)}"
318
 
319
+ def update_graph(G: nx.Graph, layout_type: str, visualization_type: str):
320
  if G is None:
321
  return None, "Please process text first."
322
+
323
  try:
324
  if visualization_type == 'Bokeh':
325
  fig = create_bokeh_plot(G, layout_type)
 
327
  fig = create_plotly_plot(G, layout_type)
328
  return fig, ""
329
  except Exception as e:
330
+ error_message = f"Error in update_graph: {str(e)}"
331
+ logging.exception(error_message)
332
  return None, f"An error occurred while updating the graph: {str(e)}"
333
 
334
+ def update_inputs(sample_name: str):
335
  sample = snippets[sample_name]
336
  return sample.text_input, sample.entity_types, sample.predicates
337
 
338
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
339
  gr.Markdown("# Knowledge Graph Extractor")
340
+
341
+ # Provide a fallback in case snippets is empty
342
+ sample_keys = list(snippets.keys())
343
+ default_sample_name = random.choice(sample_keys) if sample_keys else ""
344
+ default_sample = snippets.get(default_sample_name) if default_sample_name else None # Safely get the sample
345
+
346
  with gr.Row():
347
  with gr.Column(scale=1):
348
+ sample_dropdown = gr.Dropdown(choices=sample_keys, label="Select Sample", value=default_sample_name)
349
+ input_text = gr.Textbox(label="Input Text", lines=5, value=default_sample.text_input if default_sample else "")
350
+ entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types if default_sample else "")
351
+ predicates = gr.Textbox(label="Predicates", value=default_sample.predicates if default_sample else "")
352
+ layout_type = gr.Dropdown(
353
+ choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
354
+ label="Layout Type", value='spring')
355
  visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
356
  process_btn = gr.Button("Process Text")
357
  with gr.Column(scale=2):
 
360
 
361
  graph_state = gr.State(None)
362
 
363
+ def process_and_update(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
364
  G, fig, output = process_text(text, entity_types, predicates, layout_type, visualization_type)
365
  return G, fig, output
366
 
367
+ def update_graph_wrapper(G: nx.Graph, layout_type: str, visualization_type: str):
368
  if G is not None:
369
  fig, _ = update_graph(G, layout_type, visualization_type)
370
  return fig
 
374
  process_btn.click(process_and_update,
375
  inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
376
  outputs=[graph_state, output_graph, error_message])
377
+
378
  layout_type.change(update_graph_wrapper,
379
  inputs=[graph_state, layout_type, visualization_type],
380
  outputs=[output_graph])
381
+
382
  visualization_type.change(update_graph_wrapper,
383
  inputs=[graph_state, layout_type, visualization_type],
384
  outputs=[output_graph])