|
import gradio as gr |
|
import networkx as nx |
|
import random |
|
import logging |
|
import ast |
|
|
|
|
|
# Root logger configuration: emit DEBUG and above, prefixing every record
# with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
|
|
|
|
|
|
|
|
|
def triplextract(text: str, entity_types: list[str], predicates: list[str]) -> str:
    """Return (subject, predicate, object) triples found in *text*, as a string.

    This is a canned stand-in for a real extractor: when *text* contains the
    exact substring "Global warming" it returns a fixed list of four triples,
    otherwise it returns "[]". *entity_types* and *predicates* are only logged
    and printed; they do not influence the result.

    Args:
        text: The input text.
        entity_types: Entity types to consider (currently unused).
        predicates: Predicates to consider (currently unused).

    Returns:
        ``str()`` of a list of 3-tuples, or a string starting with "Error:"
        if the check on *text* raises (e.g. *text* is None).
    """
    logging.debug(f"triplextract called with text: {text}, entity_types: {entity_types}, predicates: {predicates}")
    print(f"triplextract input:\ntext: {text}\nentity_types: {entity_types}\npredicates: {predicates}")

    try:
        found = []
        if "Global warming" in text:
            found = [
                ('Global warming', 'causes', 'climate changes'),
                ('temperature', 'increased by', '1C'),
                ('warming', 'caused by', 'human activities'),
                ('Paris Agreement', 'aims to limit', 'temperature increase'),
            ]
        # str([]) == "[]", so both branches share the same serialization.
        serialized = str(found)
        print(f"triplextract output: {serialized}")
        return serialized
    except Exception as exc:
        failure = f"Error in triplextract: {str(exc)}"
        logging.exception(failure)
        return f"Error: {failure}"
|
|
|
def parse_triples(triples_str: str) -> tuple[list[str], list[tuple[str, str, str]]]:
    """Parse the string form of a triples list into entities and relationships.

    Args:
        triples_str: A Python-literal list of 3-item tuples (or 3-item lists),
            e.g. "[('Alice', 'knows', 'Bob')]".

    Returns:
        A tuple ``(entities, relationships)``:
        - ``entities``: unique subjects/objects in first-seen order.
        - ``relationships``: ``(subject, predicate, object)`` tuples in input order.
        Both lists are empty when the string is not a valid literal or is not a
        list. Malformed or unhashable individual triples are skipped with a warning
        instead of discarding the whole result.
    """
    logging.debug(f"parse_triples called with triples_str: {triples_str}")
    print(f"parse_triples input: {triples_str}")

    # literal_eval only evaluates Python literals (no calls, names, or
    # attribute access), so it is safe on untrusted extractor output.
    try:
        triples_list = ast.literal_eval(triples_str)
    except (SyntaxError, ValueError) as e:
        logging.error(f"Error in parse_triples: Invalid triples string format: {str(e)}")
        return [], []
    except Exception as e:  # e.g. RecursionError / MemoryError on pathological input
        logging.exception(f"Unexpected error in parse_triples during literal_eval: {str(e)}")
        return [], []

    if not isinstance(triples_list, list):
        logging.warning(f"Triples list is not a list, but {type(triples_list)}. Returning empty lists.")
        return [], []

    # A dict (insertion-ordered) instead of a set: set iteration order for
    # strings depends on PYTHONHASHSEED, which made the entity list, the node
    # order, and therefore the seeded graph layouts differ between runs.
    entities: dict = {}
    relationships = []
    for triple in triples_list:
        if not isinstance(triple, (tuple, list)) or len(triple) != 3:
            logging.warning(f"Invalid triple format: {triple}. Skipping.")
            continue
        subject, predicate, object_ = triple
        try:
            # Check both endpoints before inserting either, so a half-valid
            # triple does not leave a dangling entity behind.
            hash(subject)
            hash(object_)
        except TypeError:
            logging.warning(f"Unhashable entity in triple: {triple}. Skipping.")
            continue
        entities.setdefault(subject, None)
        entities.setdefault(object_, None)
        relationships.append((subject, predicate, object_))

    return list(entities), relationships
|
|
|
import networkx as nx |
|
|
|
def create_graph(entities: list[str], relationships: list[tuple[str, str, str]]) -> nx.Graph:
    """Build an undirected networkx graph from entities and relationships.

    Args:
        entities: Entity names; each becomes a node (even if it has no edges).
        relationships: ``(subject, predicate, object)`` tuples; each becomes an
            edge between subject and object with a ``predicate`` attribute.
            Endpoints missing from *entities* are added as nodes.

    Returns:
        A ``nx.Graph``. Because ``nx.Graph`` is undirected and keeps one edge
        per node pair, several relationships between the same two entities
        collapse into one edge whose ``predicate`` is the last one seen.
        On any unexpected error an empty ``nx.Graph`` is returned and logged.
    """
    logging.debug(f"create_graph called with entities: {entities}, relationships: {relationships}")
    print(f"create_graph input:\nentities: {entities}\nrelationships: {relationships}")
    try:
        G = nx.Graph()
        G.add_nodes_from(entities)
        # add_edge creates the edge (and any missing endpoint) and updates its
        # attributes in one step, so the old "edge not found" branch was dead.
        # NOTE(review): the previous `nx.is_graph(G)` sanity check was dropped;
        # networkx does not appear to document such a function, and G is an
        # nx.Graph by construction here.
        for subject, predicate, object_ in relationships:
            G.add_edge(subject, object_, predicate=predicate)
        return G
    except Exception as e:
        error_message = f"Error in create_graph: {str(e)}"
        logging.exception(error_message)
        return nx.Graph()
|
|
|
|
|
|
|
from bokeh.plotting import figure, show |
|
from bokeh.models import Circle, MultiLine, EdgesAndLinkedNodes, NodesAndLinkedEdges, StaticLayoutProvider |
|
from bokeh.palettes import Category20 |
|
from bokeh.io import output_notebook |
|
import networkx as nx |
|
|
|
def create_bokeh_plot(graph: nx.Graph, layout_type: str):
    """Create a labelled Bokeh plot of the given networkx graph.

    Args:
        graph: The graph to draw.
        layout_type: One of 'spring', 'fruchterman_reingold', 'circular',
            'random', 'spectral', 'shell'; anything else falls back to 'spring'.
            Seeded layouts use seed=42 so re-renders are reproducible.

    Returns:
        A Bokeh figure, or None when *graph* is not a networkx graph, has no
        nodes, or plotting fails (failures are logged).
    """
    try:
        # isinstance is used rather than `nx.is_graph`, which networkx does not
        # appear to document (NOTE(review): confirm against your networkx version).
        if not isinstance(graph, nx.Graph):
            logging.error("Error: create_bokeh_plot received an invalid networkx graph")
            return None
        if graph.number_of_nodes() == 0:
            # min()/max() below need at least one coordinate.
            logging.warning("create_bokeh_plot received an empty graph; nothing to plot")
            return None

        layouts = {
            'spring': lambda g: nx.spring_layout(g, seed=42),
            'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, seed=42),
            'circular': nx.circular_layout,
            'random': lambda g: nx.random_layout(g, seed=42),
            'spectral': nx.spectral_layout,
            'shell': nx.shell_layout,
        }
        pos = layouts.get(layout_type, layouts['spring'])(graph)

        node_names = list(graph.nodes())
        xs = [float(pos[node][0]) for node in node_names]
        ys = [float(pos[node][1]) for node in node_names]
        node_data = dict(index=node_names, x=xs, y=ys, name=node_names,
                         size=[15] * len(node_names))

        edges = list(graph.edges())
        edge_xs = [[float(pos[u][0]), float(pos[v][0])] for u, v in edges]
        edge_ys = [[float(pos[u][1]), float(pos[v][1])] for u, v in edges]

        plot = figure(title="Knowledge Graph", width=600, height=600,
                      tools="pan,wheel_zoom,box_zoom,reset,save",
                      x_range=(min(xs) - 0.1, max(xs) + 0.1),
                      y_range=(min(ys) - 0.1, max(ys) + 0.1))

        # Edges first so that nodes (and their labels) render on top of them.
        plot.multi_line(xs=edge_xs, ys=edge_ys, color="navy", alpha=0.5)
        plot.scatter("x", "y", size="size", source=node_data, name="nodes")
        # Without labels the Bokeh view gives no way to tell nodes apart.
        plot.text("x", "y", text="name", source=node_data,
                  text_font_size="9pt", text_baseline="middle", x_offset=10)

        return plot
    except Exception as e:
        error_message = f"Error creating Bokeh plot: {str(e)}"
        logging.exception(error_message)
        return None
|
|
|
def create_plotly_plot(graph: nx.Graph, layout_type: str):
    """Create a Plotly figure of the given networkx graph.

    Nodes are coloured by their number of neighbours and show
    "<name> (# of connections: N)" on hover.

    Args:
        graph: The graph to draw.
        layout_type: One of 'spring', 'fruchterman_reingold', 'circular',
            'random', 'spectral', 'shell'; anything else falls back to 'spring'.
            Seeded layouts use seed=42 so re-renders are reproducible.

    Returns:
        A ``plotly.graph_objects.Figure``, or None when *graph* is not a
        networkx graph or plotting fails (failures are logged).
    """
    try:
        # isinstance is used rather than `nx.is_graph`, which networkx does not
        # appear to document (NOTE(review): confirm against your networkx version).
        if not isinstance(graph, nx.Graph):
            logging.error("Error: create_plotly_plot received an invalid networkx graph")
            return None
        import plotly.graph_objects as go

        layouts = {
            'spring': lambda g: nx.spring_layout(g, seed=42),
            'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, seed=42),
            'circular': nx.circular_layout,
            'random': lambda g: nx.random_layout(g, seed=42),
            'spectral': nx.spectral_layout,
            'shell': nx.shell_layout,
        }
        pos = layouts.get(layout_type, layouts['spring'])(graph)

        # One polyline per edge; the trailing None breaks the line between edges.
        edge_x, edge_y = [], []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines')

        # A single pass over adjacency keeps coordinates, colours and hover
        # text aligned node-by-node.
        node_x, node_y, node_degrees, node_text = [], [], [], []
        for node, neighbours in graph.adjacency():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_degrees.append(len(neighbours))
            node_text.append(f"{node} (# of connections: {len(neighbours)})")

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=node_text,
            marker=dict(
                showscale=False,
                colorscale='YlGnBu',
                reversescale=True,
                color=node_degrees,
                size=10,
                line_width=2))

        # `title.font` replaces the deprecated `titlefont` property, which newer
        # Plotly releases no longer accept. The template's placeholder
        # "Replace with your attribution text" annotation is intentionally gone.
        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title=dict(text='Knowledge Graph', font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

        return fig
    except Exception as e:
        error_message = f"Error creating Plotly plot: {str(e)}"
        logging.exception(error_message)
        return None
|
|
|
|
|
from dataclasses import dataclass |
|
|
|
@dataclass()
class Sample:
    """A pre-filled example offered in the UI's sample dropdown.

    All fields are plain strings. ``entity_types`` and ``predicates`` hold
    comma-separated lists, e.g. "Animal, Object".
    """

    # Free text the triples are extracted from.
    text_input: str
    # Comma-separated entity types.
    entity_types: str
    # Comma-separated predicates.
    predicates: str
|
|
|
# Pre-filled examples for the "Select Sample" dropdown, keyed by display label.
# NOTE: with the canned triplextract in this file only text containing
# "Global warming" (Sample 3) yields triples; the others produce none.
snippets = {
    "Sample 1": Sample(
        text_input="Alice knows Bob.",
        entity_types="Person",
        predicates="knows",
    ),
    "Sample 2": Sample(
        text_input="The cat sat on the mat.",
        entity_types="Animal, Object",
        predicates="sat on",
    ),
    "Sample 3": Sample(
        # Adjacent literals are concatenated at compile time. Unlike a backslash
        # continuation inside the string, source indentation never leaks into
        # the text.
        text_input=(
            "Global warming is causing significant changes to Earth's climate. "
            "The average global temperature has increased by approximately 1C "
            "since the pre-industrial era. This warming is primarily caused by "
            "human activities, particularly the emission of greenhouse gases "
            "like carbon dioxide. The Paris Agreement, signed in 2015, aims to "
            "limit global temperature increase to well below 2°C above "
            "pre-industrial levels. To achieve this goal, many countries are "
            "implementing policies to reduce carbon emissions and transition "
            "to renewable energy sources."
        ),
        entity_types="Event, Measure, Activity, Agreement",
        predicates="causes, increased by, aims to limit",
    ),
}
|
|
|
|
|
|
|
# Maximum number of whitespace-separated words process_text accepts.
WORD_LIMIT = 300


def process_text(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
    """Validate input, extract triples, build the graph and render it.

    Args:
        text: Free text to extract triples from (at most WORD_LIMIT words).
        entity_types: Comma-separated entity types, e.g. "Person, Place".
        predicates: Comma-separated predicates, e.g. "knows, lives in".
        layout_type: Layout name forwarded to the plotting function.
        visualization_type: 'Bokeh' selects Bokeh; any other value selects Plotly.

    Returns:
        ``(graph, figure, message)``. When validation or extraction fails,
        ``graph`` and ``figure`` are None and ``message`` says why. On success
        ``message`` lists entities, relationships and the raw extractor output.
    """
    print(f"process_text input:\ntext: {text}\nentity_types: {entity_types}\npredicates: {predicates}")
    # Whitespace-only input has no words; treat it exactly like empty input.
    if not text or not text.strip():
        return None, None, "Please enter some text."

    words = text.split()
    if len(words) > WORD_LIMIT:
        return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"

    # `or ""` guards against a None textbox value, which would otherwise raise
    # AttributeError here, outside the try below.
    entity_types_list = [et.strip() for et in (entity_types or "").split(",") if et.strip()]
    predicates_list = [p.strip() for p in (predicates or "").split(",") if p.strip()]

    if not entity_types_list:
        return None, None, "Please enter at least one entity type."
    if not predicates_list:
        return None, None, "Please enter at least one predicate."

    try:
        prediction = triplextract(text, entity_types_list, predicates_list)
        # triplextract reports failures in-band with an "Error" prefix.
        if prediction and prediction.startswith("Error"):
            return None, None, prediction

        entities, relationships = parse_triples(prediction)
        if not entities and not relationships:
            return None, None, "No entities or relationships found. Try different text or check your input."

        G = create_graph(entities, relationships)

        if visualization_type == 'Bokeh':
            fig = create_bokeh_plot(G, layout_type)
        else:
            fig = create_plotly_plot(G, layout_type)

        output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
        # The plot helpers return None on failure (and only log it); say so
        # instead of leaving the user with a silently blank plot.
        if fig is None:
            output_text += "\n\nWarning: the graph could not be rendered; check the logs for details."
        return G, fig, output_text
    except Exception as e:
        error_message = f"Error in process_text: {str(e)}"
        logging.exception(error_message)
        return None, None, f"An error occurred: {str(e)}"
|
|
|
def update_graph(G: nx.Graph, layout_type: str, visualization_type: str):
    """Re-render an already-built graph with a new layout or plotting backend.

    Args:
        G: Graph from a previous process_text call, or None if none exists yet.
        layout_type: Layout name forwarded to the plotting function.
        visualization_type: 'Bokeh' selects Bokeh; anything else selects Plotly.

    Returns:
        ``(figure, message)``. The message is "" on success. When *G* is None
        or rendering raises, the figure is None and the message explains why.
    """
    if G is None:
        return None, "Please process text first."

    try:
        render = create_bokeh_plot if visualization_type == 'Bokeh' else create_plotly_plot
        return render(G, layout_type), ""
    except Exception as exc:
        logging.exception(f"Error in update_graph: {str(exc)}")
        return None, f"An error occurred while updating the graph: {str(exc)}"
|
|
|
def update_inputs(sample_name: str):
    """Return the text, entity types and predicates stored for a sample.

    Args:
        sample_name: A key of ``snippets`` (the dropdown's selected label).

    Returns:
        ``(text_input, entity_types, predicates)`` for the sample. For an unknown
        or empty selection, three no-op ``gr.update()`` values are returned so the
        input boxes keep their current contents instead of the handler raising
        KeyError.
    """
    sample = snippets.get(sample_name)
    if sample is None:
        logging.warning(f"update_inputs: unknown sample {sample_name!r}; leaving inputs unchanged.")
        return gr.update(), gr.update(), gr.update()
    return sample.text_input, sample.entity_types, sample.predicates
|
|
|
# Gradio UI. The left column holds the sample picker, inputs and options; the
# right column holds the rendered graph and the textual output. Component
# creation order inside the `with` blocks determines the on-screen layout.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Knowledge Graph Extractor")

    # Pick one sample at random (once, when this module runs) to pre-fill the inputs.
    sample_keys = list(snippets.keys())
    default_sample_name = random.choice(sample_keys) if sample_keys else ""
    default_sample = snippets.get(default_sample_name) if default_sample_name else None

    with gr.Row():
        with gr.Column(scale=1):
            sample_dropdown = gr.Dropdown(choices=sample_keys, label="Select Sample", value=default_sample_name)
            input_text = gr.Textbox(label="Input Text", lines=5, value=default_sample.text_input if default_sample else "")
            entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types if default_sample else "")
            predicates = gr.Textbox(label="Predicates", value=default_sample.predicates if default_sample else "")
            layout_type = gr.Dropdown(
                choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
                label="Layout Type", value='spring')
            visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
            process_btn = gr.Button("Process Text")
        with gr.Column(scale=2):
            output_graph = gr.Plot(label="Knowledge Graph")
            # Despite its name, this textbox shows all textual output, both
            # success summaries and error messages.
            error_message = gr.Textbox(label="Textual Output")

    # Holds the networkx graph returned by process_and_update (its first output),
    # so layout / backend changes can re-render without re-running extraction.
    graph_state = gr.State(None)

    # NOTE: the parameters of the two handlers below shadow the module-level
    # component variables of the same names; inside them they are plain values.
    def process_and_update(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
        """Run process_text and return (graph, figure, text) for the three outputs."""
        G, fig, output = process_text(text, entity_types, predicates, layout_type, visualization_type)
        print(f"process_and_update: G = {G}")
        return G, fig, output

    def update_graph_wrapper(G: nx.Graph, layout_type: str, visualization_type: str):
        """Re-render the stored graph; returns None implicitly when no graph exists yet."""
        print(f"update_graph_wrapper: G = {G}")
        if G is not None:
            # update_graph's status message is discarded; only the plot is updated.
            fig, _ = update_graph(G, layout_type, visualization_type)
            return fig

    # Selecting a sample overwrites the three text inputs.
    sample_dropdown.change(update_inputs, inputs=[sample_dropdown], outputs=[input_text, entity_types, predicates])

    # Full pipeline: extract, store the graph in state, plot, and show text.
    process_btn.click(process_and_update,
                      inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
                      outputs=[graph_state, output_graph, error_message])

    # Layout or backend changes only re-render the stored graph.
    layout_type.change(update_graph_wrapper,
                       inputs=[graph_state, layout_type, visualization_type],
                       outputs=[output_graph])

    visualization_type.change(update_graph_wrapper,
                              inputs=[graph_state, layout_type, visualization_type],
                              outputs=[output_graph])
|
|
|
def _main():
    """Launch the module-level Gradio ``demo`` app."""
    # share=True asks Gradio for a public share link in addition to the local
    # server. NOTE(review): confirm that exposing a public URL is intended.
    demo.launch(share=True)


if __name__ == "__main__":
    _main()