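# NOTE: the next three lines are web-page residue from the hosting site's file viewer, not part of app.py.
# Anupam251272's picture
# Update app.py
# 5f86c2f verified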
import gradio as gr
import networkx as nx
import random
import logging
import ast
# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# --- lib directory code (graph_extract.py, visualize.py, samples.py would go here) ---
# Placeholder for your NLP pipeline (replace with your actual implementation)
def triplextract(text: str, entity_types: list[str], predicates: list[str]) -> str:
    """
    Placeholder triple extractor standing in for the real NLP pipeline.

    The current implementation only logs ``entity_types`` and ``predicates``;
    it answers with a fixed set of triples whenever the input contains the
    phrase "Global warming" and with an empty list otherwise.

    Args:
        text: The input text.
        entity_types: A list of entity types to consider (currently only logged).
        predicates: A list of predicates to consider (currently only logged).
    Returns:
        ``str()`` of a list of (subject, predicate, object) tuples, "[]" when
        nothing matches, or a message starting with "Error" if extraction fails.
    """
    logging.debug(f"triplextract called with text: {text}, entity_types: {entity_types}, predicates: {predicates}")
    print(f"triplextract input:\ntext: {text}\nentity_types: {entity_types}\npredicates: {predicates}")
    try:
        # Canned demo output; replace with the real pipeline call once available.
        if "Global warming" not in text:
            prediction = "[]"
        else:
            canned_triples = [
                ('Global warming', 'causes', 'climate changes'),
                ('temperature', 'increased by', '1C'),
                ('warming', 'caused by', 'human activities'),
                ('Paris Agreement', 'aims to limit', 'temperature increase'),
            ]
            prediction = str(canned_triples)
        print(f"triplextract output: {prediction}")
        return prediction
    except Exception as exc:
        failure = f"Error in triplextract: {exc}"
        # logging.exception records the traceback alongside the message.
        logging.exception(failure)
        return f"Error: {failure}"
def parse_triples(triples_str: str) -> tuple[list[str], list[tuple[str, str, str]]]:
    """
    Parses the string representation of triples into lists of entities and relationships.

    Malformed input never raises: an unparsable string or a non-list literal
    yields two empty lists, and individual entries that are not 3-tuples (or
    whose subject/object cannot be used as graph nodes) are skipped with a
    warning while the remaining triples are kept.

    Args:
        triples_str: A string representation of the triples (e.g., "[('Alice', 'knows', 'Bob')]").
    Returns:
        A tuple containing:
        - A list of unique entities, in order of first appearance.
        - A list of relationships (tuples of (subject, predicate, object)).
    """
    logging.debug(f"parse_triples called with triples_str: {triples_str}")

    # literal_eval only accepts Python literals, so arbitrary code in the
    # string is rejected rather than executed.
    try:
        triples_list = ast.literal_eval(triples_str)
    except (SyntaxError, ValueError) as e:
        logging.error(f"Error in parse_triples: Invalid triples string format: {e}")
        return [], []
    except Exception as e:
        logging.exception(f"Unexpected error in parse_triples during literal_eval: {e}")
        return [], []

    if not isinstance(triples_list, list):
        logging.warning(f"Triples list is not a list, but {type(triples_list)}. Returning empty lists.")
        return [], []

    # A dict keeps insertion order, so the entity list is deterministic across
    # runs (a set of strings is not, because of hash randomisation).
    entities: dict = {}
    relationships = []
    for triple in triples_list:
        if not (isinstance(triple, tuple) and len(triple) == 3):
            logging.warning(f"Invalid triple format: {triple}. Skipping.")
            continue
        subject, predicate, object_ = triple
        try:
            # Entities become dict keys here and graph nodes later, so both
            # must be hashable; check before mutating any state.
            hash((subject, object_))
        except TypeError:
            logging.warning(f"Invalid triple format: {triple}. Skipping.")
            continue
        entities[subject] = None
        entities[object_] = None
        relationships.append((subject, predicate, object_))
    return list(entities), relationships
import networkx as nx
def create_graph(entities: list[str], relationships: list[tuple[str, str, str]]) -> nx.Graph:
    """
    Creates a networkx graph from the given entities and relationships.

    Every entity becomes a node and every relationship becomes an undirected
    edge between its subject and object, with the predicate stored in the
    edge's ``predicate`` attribute. Because the graph is undirected and not a
    multigraph, a later relationship between the same two entities overwrites
    the predicate of an earlier one.

    Args:
        entities: A list of entity names (strings).
        relationships: A list of tuples representing relationships (subject, predicate, object).
    Returns:
        A networkx Graph object; an empty graph if construction fails.
    """
    logging.debug(f"create_graph called with entities: {entities}, relationships: {relationships}")
    try:
        G = nx.Graph()
        G.add_nodes_from(entities)
        # add_edge creates the edge and sets its attribute in one step.
        # NOTE: no validity check via `nx.is_graph` - that name is not a
        # networkx function, and calling it made this function always fall
        # into the except branch and return an empty graph.
        for subject, predicate, object_ in relationships:
            G.add_edge(subject, object_, predicate=predicate)
        return G
    except Exception as e:
        logging.exception(f"Error in create_graph: {e}")
        return nx.Graph()  # Return an empty graph
# visualize.py (Implement your Bokeh and Plotly visualizations here)
from bokeh.plotting import figure, show
from bokeh.models import Circle, MultiLine, EdgesAndLinkedNodes, NodesAndLinkedEdges, StaticLayoutProvider
from bokeh.palettes import Category20
from bokeh.io import output_notebook
import networkx as nx
def create_bokeh_plot(graph: nx.Graph, layout_type: str):
    """
    Creates a Bokeh plot of the given networkx graph.

    Args:
        graph: The graph to draw.
        layout_type: One of 'spring', 'fruchterman_reingold', 'circular',
            'random', 'spectral' or 'shell'; any other value falls back to
            the spring layout.
    Returns:
        A Bokeh figure, or None when the argument is not a networkx graph,
        the graph has no nodes, or drawing fails (the failure is logged).
    """
    try:
        # isinstance replaces the former `nx.is_graph` call, which is not a
        # networkx function and made every call end in the except branch.
        if not isinstance(graph, nx.Graph):
            logging.error("Error: create_bokeh_plot received an invalid networkx graph")
            return None
        if graph.number_of_nodes() == 0:
            # Axis ranges below need at least one coordinate.
            logging.warning("create_bokeh_plot received an empty graph; nothing to draw")
            return None

        # Randomised layouts are seeded so repeated renders look the same.
        layout_builders = {
            'spring': lambda g: nx.spring_layout(g, seed=42),
            'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, seed=42),
            'circular': nx.circular_layout,
            'random': lambda g: nx.random_layout(g, seed=42),
            'spectral': nx.spectral_layout,
            'shell': nx.shell_layout,
        }
        build_layout = layout_builders.get(layout_type, layout_builders['spring'])
        pos = build_layout(graph)

        node_indices = list(graph.nodes())
        x, y = zip(*(pos[node] for node in node_indices))
        node_data = dict(index=node_indices, x=x, y=y, name=node_indices, size=[15] * len(node_indices))

        # One [start, end] coordinate pair per edge, as multi_line expects.
        edge_xs = [[pos[u][0], pos[v][0]] for u, v in graph.edges()]
        edge_ys = [[pos[u][1], pos[v][1]] for u, v in graph.edges()]

        plot = figure(title="Knowledge Graph", width=600, height=600,
                      tools="pan,wheel_zoom,box_zoom,reset,save",
                      x_range=(min(x) - 0.1, max(x) + 0.1), y_range=(min(y) - 0.1, max(y) + 0.1))
        plot.scatter("x", "y", size="size", source=node_data, name="nodes")
        plot.multi_line(xs=edge_xs, ys=edge_ys, color="navy", alpha=0.5)
        return plot
    except Exception as e:
        logging.exception(f"Error creating Bokeh plot: {e}")
        return None
def create_plotly_plot(graph: nx.Graph, layout_type: str):
    """
    Creates a Plotly plot of the given networkx graph.

    Nodes are coloured by their number of connections, and the hover text
    shows the node name together with that count.

    Args:
        graph: The graph to draw.
        layout_type: One of 'spring', 'fruchterman_reingold', 'circular',
            'random', 'spectral' or 'shell'; any other value falls back to
            the spring layout.
    Returns:
        A plotly Figure, or None when the argument is not a networkx graph
        or drawing fails (the failure is logged).
    """
    try:
        # isinstance replaces the former `nx.is_graph` call, which is not a
        # networkx function and made every call end in the except branch.
        if not isinstance(graph, nx.Graph):
            logging.error("Error: create_plotly_plot received an invalid networkx graph")
            return None
        import plotly.graph_objects as go

        # Randomised layouts are seeded so repeated renders look the same.
        layout_builders = {
            'spring': lambda g: nx.spring_layout(g, seed=42),
            'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, seed=42),
            'circular': nx.circular_layout,
            'random': lambda g: nx.random_layout(g, seed=42),
            'spectral': nx.spectral_layout,
            'shell': nx.shell_layout,
        }
        build_layout = layout_builders.get(layout_type, layout_builders['spring'])
        pos = build_layout(graph)

        # All edges share one trace; a None entry breaks the line between
        # consecutive segments.
        edge_x = []
        edge_y = []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend((x0, x1, None))
            edge_y.extend((y0, y1, None))
        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines')

        node_x = []
        node_y = []
        for node in graph.nodes():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)

        # adjacency() yields (node, neighbours) in the same order as nodes().
        node_adjacencies = []
        node_text = []
        for name, neighbours in graph.adjacency():
            node_adjacencies.append(len(neighbours))
            node_text.append(f"{name} (# of connections: {len(neighbours)})")

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=node_text,
            marker=dict(
                showscale=False,
                colorscale='YlGnBu',
                reversescale=True,
                color=node_adjacencies,
                size=10,
                line=dict(width=2)))

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                # `titlefont` is deprecated; title.font is the supported spelling.
                title=dict(text='Knowledge Graph', font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[dict(
                    text="Replace with your attribution text",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002)],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
        return fig
    except Exception as e:
        logging.exception(f"Error creating Plotly plot: {e}")
        return None
# samples.py
from dataclasses import dataclass
@dataclass
class Sample:
    """A canned example: an input text plus the entity types and predicates to use with it."""

    # Free text the triples are extracted from.
    text_input: str
    # Entity type names held in one string, comma-separated when there are several.
    entity_types: str
    # Predicate names held in one string, comma-separated when there are several.
    predicates: str
# Built-in examples, keyed by the label used to select them.
snippets = {
    "Sample 1": Sample("Alice knows Bob.", "Person", "knows"),
    "Sample 2": Sample("The cat sat on the mat.", "Animal, Object", "sat on"),
    "Sample 3": Sample(
        # Adjacent literals are concatenated at compile time; each piece keeps
        # its trailing space so the joined text reads as one paragraph.
        (
            "Global warming is causing significant changes to Earth's climate. The average global "
            "temperature has increased by approximately 1C since the pre-industrial era. This warming is "
            "primarily caused by human activities, particularly the emission of greenhouse gases like carbon dioxide. "
            "The Paris Agreement, signed in 2015, aims to limit global temperature increase to well below 2°C above "
            "pre-industrial levels. To achieve this goal, many countries are implementing policies to reduce carbon "
            "emissions and transition to renewable energy sources."
        ),
        "Event, Measure, Activity, Agreement",
        "causes, increased by, aims to limit",
    ),
}
# --- Gradio Interface Code ---
WORD_LIMIT = 300  # Maximum number of whitespace-separated words accepted per request.


def process_text(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
    """
    Runs the full pipeline: validate input, extract triples, build the graph, draw it.

    Args:
        text: The text to extract triples from; at most WORD_LIMIT words.
        entity_types: Comma-separated entity type names.
        predicates: Comma-separated predicate names.
        layout_type: Layout name forwarded to the plotting function.
        visualization_type: 'Bokeh' selects the Bokeh renderer; any other
            value selects Plotly.
    Returns:
        A tuple ``(graph, figure, message)``. On a validation or processing
        failure the first two items are None and ``message`` explains why.
        On success ``message`` lists the entities, relationships and raw
        extractor output; if only the rendering step failed, ``figure`` is
        None and a warning is appended to ``message``.
    """
    logging.debug(f"process_text input:\ntext: {text}\nentity_types: {entity_types}\npredicates: {predicates}")

    # Whitespace-only input carries no words, so treat it like empty input.
    if not text or not text.strip():
        return None, None, "Please enter some text."
    words = text.split()
    if len(words) > WORD_LIMIT:
        return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"

    # `or ""` tolerates None coming from a cleared input field.
    entity_types_list = [et.strip() for et in (entity_types or "").split(",") if et.strip()]
    predicates_list = [p.strip() for p in (predicates or "").split(",") if p.strip()]
    if not entity_types_list:
        return None, None, "Please enter at least one entity type."
    if not predicates_list:
        return None, None, "Please enter at least one predicate."

    try:
        prediction = triplextract(text, entity_types_list, predicates_list)
        # The extractor reports failures as a string starting with "Error".
        if prediction and prediction.startswith("Error"):
            return None, None, prediction
        entities, relationships = parse_triples(prediction)
        if not entities and not relationships:
            return None, None, "No entities or relationships found. Try different text or check your input."
        G = create_graph(entities, relationships)
        if visualization_type == 'Bokeh':
            fig = create_bokeh_plot(G, layout_type)
        else:
            fig = create_plotly_plot(G, layout_type)
        output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
        if fig is None:
            # The plotting helpers signal failure by returning None; say so
            # instead of silently showing an empty plot area.
            output_text += "\n\nWarning: the graph could not be rendered; see the application log for details."
        return G, fig, output_text
    except Exception as e:
        logging.exception(f"Error in process_text: {e}")
        return None, None, f"An error occurred: {str(e)}"
def update_graph(G: nx.Graph, layout_type: str, visualization_type: str):
    """
    Re-draws an already built graph with a different layout or renderer.

    Args:
        G: The graph produced by a previous processing run, or None.
        layout_type: Layout name forwarded to the plotting function.
        visualization_type: 'Bokeh' selects the Bokeh renderer; any other
            value selects Plotly.
    Returns:
        A tuple ``(figure, message)``. ``message`` is empty on success; when
        there is no graph yet or drawing fails, ``figure`` is None and
        ``message`` describes the problem.
    """
    if G is None:
        return None, "Please process text first."
    try:
        if visualization_type == 'Bokeh':
            fig = create_bokeh_plot(G, layout_type)
        else:
            fig = create_plotly_plot(G, layout_type)
    except Exception as e:
        logging.exception(f"Error in update_graph: {e}")
        return None, f"An error occurred while updating the graph: {str(e)}"
    if fig is None:
        # The plotting helpers log their own errors and return None, so an
        # empty message here would make a failure look like success.
        return None, "An error occurred while updating the graph: the plot could not be created."
    return fig, ""
def update_inputs(sample_name: str):
    """Looks up the named sample and returns its (text, entity types, predicates)."""
    chosen = snippets[sample_name]
    field_values = (chosen.text_input, chosen.entity_types, chosen.predicates)
    return field_values
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Knowledge Graph Extractor")

    # Pre-fill the form from a randomly chosen sample; fall back to blanks
    # when no samples are defined.
    sample_keys = list(snippets.keys())
    default_sample_name = ""
    if sample_keys:
        default_sample_name = random.choice(sample_keys)
    default_sample = None
    if default_sample_name:
        default_sample = snippets.get(default_sample_name)
    if default_sample:
        default_text = default_sample.text_input
        default_entity_types = default_sample.entity_types
        default_predicates = default_sample.predicates
    else:
        default_text = default_entity_types = default_predicates = ""

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            sample_dropdown = gr.Dropdown(choices=sample_keys, label="Select Sample", value=default_sample_name)
            input_text = gr.Textbox(label="Input Text", lines=5, value=default_text)
            entity_types = gr.Textbox(label="Entity Types", value=default_entity_types)
            predicates = gr.Textbox(label="Predicates", value=default_predicates)
            layout_type = gr.Dropdown(
                choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
                label="Layout Type",
                value='spring',
            )
            visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
            process_btn = gr.Button("Process Text")
        # Right column: rendered graph and textual result.
        with gr.Column(scale=2):
            output_graph = gr.Plot(label="Knowledge Graph")
            error_message = gr.Textbox(label="Textual Output")

    # Holds the most recently built graph so it can be re-drawn without re-processing.
    graph_state = gr.State(None)

    def process_and_update(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
        """Button handler: runs process_text and passes its three results through."""
        graph, figure, message = process_text(text, entity_types, predicates, layout_type, visualization_type)
        print(f"process_and_update: G = {graph}")  # Debug
        return graph, figure, message

    def update_graph_wrapper(G: nx.Graph, layout_type: str, visualization_type: str):
        """Change handler: re-draws the stored graph; yields None when there is none."""
        print(f"update_graph_wrapper: G = {G}")  # Debug
        if G is None:
            return None
        figure, _ = update_graph(G, layout_type, visualization_type)
        return figure

    sample_dropdown.change(update_inputs, inputs=[sample_dropdown], outputs=[input_text, entity_types, predicates])
    process_btn.click(
        process_and_update,
        inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
        outputs=[graph_state, output_graph, error_message],
    )
    # Changing either the layout or the renderer triggers the same re-draw.
    layout_type.change(
        update_graph_wrapper,
        inputs=[graph_state, layout_type, visualization_type],
        outputs=[output_graph],
    )
    visualization_type.change(
        update_graph_wrapper,
        inputs=[graph_state, layout_type, visualization_type],
        outputs=[output_graph],
    )

if __name__ == "__main__":
    demo.launch(share=True)