"""Gradio demo: extract (subject, predicate, object) triples from text and
render them as a knowledge graph with either Bokeh or Plotly."""

import ast
import logging
import random
from dataclasses import dataclass

import gradio as gr
import networkx as nx

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# The Bokeh imports are kept after logging.basicConfig(), as in the original
# file layout.
# NOTE(review): Bokeh may inspect the root logger's handlers at import time;
# confirm before hoisting these above the basicConfig() call.
# Several of these names are not used below; they are retained because other
# code may rely on this module importing them.
from bokeh.plotting import figure, show  # noqa: E402,F401
from bokeh.models import (  # noqa: E402,F401
    Circle,
    MultiLine,
    EdgesAndLinkedNodes,
    NodesAndLinkedEdges,
    StaticLayoutProvider,
)
from bokeh.palettes import Category20  # noqa: E402,F401
from bokeh.io import output_notebook  # noqa: E402,F401


# --- lib directory code (graph_extract.py, visualize.py, samples.py would go here) ---

# Placeholder for your NLP pipeline (replace with your actual implementation)
def triplextract(text: str, entity_types: list[str], predicates: list[str]) -> str:
    """
    Extracts triples (subject, predicate, object) from the given text.

    This is a placeholder: it ignores ``entity_types`` and ``predicates`` and
    returns a fixed list of triples when the text contains "Global warming",
    otherwise an empty list.

    Args:
        text: The input text.
        entity_types: A list of entity types to consider.
        predicates: A list of predicates to consider.

    Returns:
        A string representation of the extracted triples (a Python list
        literal of 3-tuples), or a string starting with "Error" if
        extraction fails.
    """
    logging.debug(f"triplextract called with text: {text}, entity_types: {entity_types}, predicates: {predicates}")
    print(f"triplextract input:\ntext: {text}\nentity_types: {entity_types}\npredicates: {predicates}")

    try:
        # TODO: replace this hardcoded result with the actual NLP pipeline
        # (remove this hardcoding once the rest works).
        if "Global warming" in text:
            triples = [
                ('Global warming', 'causes', 'climate changes'),
                ('temperature', 'increased by', '1C'),
                ('warming', 'caused by', 'human activities'),
                ('Paris Agreement', 'aims to limit', 'temperature increase'),
            ]
        else:
            triples = []

        prediction = str(triples)
        print(f"triplextract output: {prediction}")
        return prediction
    except Exception as e:
        error_message = f"Error in triplextract: {str(e)}"
        logging.exception(error_message)  # Log the full exception with traceback
        return f"Error: {error_message}"  # Return an error message as a string


def parse_triples(triples_str: str) -> tuple[list[str], list[tuple[str, str, str]]]:
    """
    Parses the string representation of triples into lists of entities and relationships.

    Args:
        triples_str: A string representation of the triples
            (e.g., "[('Alice', 'knows', 'Bob')]").

    Returns:
        A tuple containing:
            - A list of unique entities, in order of first appearance.
            - A list of relationships (tuples of (subject, predicate, object)).
        Both lists are empty when the string cannot be parsed or does not
        evaluate to a list. Items that are not 3-tuples, or whose subject or
        object cannot be hashed, are skipped with a warning.
    """
    logging.debug(f"parse_triples called with triples_str: {triples_str}")
    print(f"parse_triples input: {triples_str}")

    try:
        triples_list = ast.literal_eval(triples_str)  # Safely evaluate the string as a list
    except (SyntaxError, ValueError) as e:
        error_message = f"Error in parse_triples: Invalid triples string format: {str(e)}"
        logging.error(error_message)
        return [], []  # Return empty lists to prevent further errors
    except Exception as e:
        # literal_eval can also raise e.g. TypeError, MemoryError or RecursionError.
        error_message = f"Unexpected error in parse_triples during literal_eval: {str(e)}"
        logging.exception(error_message)
        return [], []

    if not isinstance(triples_list, list):
        logging.warning(f"Triples list is not a list, but {type(triples_list)}. Returning empty lists.")
        return [], []

    # A dict is used as an insertion-ordered set so the entity order (and
    # therefore any seeded graph layout) is the same on every run.
    entities: dict = {}
    relationships = []
    for triple in triples_list:
        if not (isinstance(triple, tuple) and len(triple) == 3):
            logging.warning(f"Invalid triple format: {triple}. Skipping.")
            continue

        subject, predicate, object_ = triple
        try:
            # Subject and object become graph nodes, so both must be hashable.
            hash((subject, object_))
        except TypeError:
            logging.warning(f"Invalid triple format: {triple}. Skipping.")
            continue

        entities[subject] = None
        entities[object_] = None
        relationships.append((subject, predicate, object_))

    return list(entities), relationships


def create_graph(entities: list[str], relationships: list[tuple[str, str, str]]) -> nx.Graph:
    """
    Creates a networkx graph from the given entities and relationships.

    Each relationship becomes an undirected edge between subject and object,
    with the predicate stored in the edge's ``'predicate'`` attribute. If the
    same pair of nodes appears in several relationships, the last predicate wins.

    Args:
        entities: A list of entity names (strings).
        relationships: A list of tuples representing relationships
            (subject, predicate, object).

    Returns:
        A networkx Graph object; an empty graph if construction fails.
    """
    logging.debug(f"create_graph called with entities: {entities}, relationships: {relationships}")
    print(f"create_graph input:\nentities: {entities}\nrelationships: {relationships}")

    try:
        G = nx.Graph()
        G.add_nodes_from(entities)
        for subject, predicate, object_ in relationships:
            G.add_edge(subject, object_, predicate=predicate)
        return G
    except Exception as e:
        error_message = f"Error in create_graph: {str(e)}"
        logging.exception(error_message)
        return nx.Graph()  # Return an empty graph


# visualize.py (Implement your Bokeh and Plotly visualizations here)

# Layout name -> function computing node positions. Layouts that accept a seed
# are seeded so repeated renders of the same graph look the same.
_LAYOUTS = {
    'spring': lambda g: nx.spring_layout(g, seed=42),
    'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, seed=42),
    'circular': nx.circular_layout,
    'random': lambda g: nx.random_layout(g, seed=42),
    'spectral': nx.spectral_layout,
    'shell': nx.shell_layout,
}


def _compute_layout(graph: nx.Graph, layout_type: str) -> dict:
    """Return node positions for ``layout_type``, defaulting to the spring layout."""
    layout_fn = _LAYOUTS.get(layout_type, _LAYOUTS['spring'])
    return layout_fn(graph)


def create_bokeh_plot(graph: nx.Graph, layout_type: str):
    """
    Creates a Bokeh plot of the given networkx graph.

    Args:
        graph: The graph to draw.
        layout_type: One of the keys of ``_LAYOUTS``; unknown values fall
            back to the spring layout.

    Returns:
        A Bokeh figure, or None if the input is not a networkx graph, the
        graph has no nodes, or plotting fails.
    """
    try:
        if not isinstance(graph, nx.Graph):
            logging.error("Error: create_bokeh_plot received an invalid networkx graph")
            return None  # Or return a placeholder plot

        if graph.number_of_nodes() == 0:
            # Axis ranges below are derived from node positions, so there is
            # nothing to draw for an empty graph.
            logging.warning("create_bokeh_plot received a graph with no nodes")
            return None

        pos = _compute_layout(graph, layout_type)

        node_indices = list(graph.nodes())
        x = [pos[node][0] for node in node_indices]
        y = [pos[node][1] for node in node_indices]
        node_data = dict(index=node_indices, x=x, y=y, name=node_indices,
                         size=[15] * len(node_indices))

        # One two-point polyline per edge.
        edge_xs = [[pos[u][0], pos[v][0]] for u, v in graph.edges()]
        edge_ys = [[pos[u][1], pos[v][1]] for u, v in graph.edges()]

        plot = figure(title="Knowledge Graph", width=600, height=600,
                      tools="pan,wheel_zoom,box_zoom,reset,save",
                      x_range=(min(x) - 0.1, max(x) + 0.1),
                      y_range=(min(y) - 0.1, max(y) + 0.1))

        plot.scatter("x", "y", size="size", source=node_data, name="nodes")
        plot.multi_line(xs=edge_xs, ys=edge_ys, color="navy", alpha=0.5)

        return plot
    except Exception as e:
        error_message = f"Error creating Bokeh plot: {str(e)}"
        logging.exception(error_message)
        return None  # Or return a placeholder plot


def create_plotly_plot(graph: nx.Graph, layout_type: str):
    """
    Creates a Plotly plot of the given networkx graph.

    Nodes are coloured by their number of neighbours, which is also shown in
    the hover text.

    Args:
        graph: The graph to draw.
        layout_type: One of the keys of ``_LAYOUTS``; unknown values fall
            back to the spring layout.

    Returns:
        A Plotly figure, or None if the input is not a networkx graph or
        plotting fails.
    """
    try:
        if not isinstance(graph, nx.Graph):
            logging.error("Error: create_plotly_plot received an invalid networkx graph")
            return None  # Or return a placeholder plot

        # Imported lazily so Plotly is only required when this backend is used.
        import plotly.graph_objects as go

        pos = _compute_layout(graph, layout_type)

        # All edges go into a single trace; a None entry breaks the line
        # between consecutive edges.
        edge_x = []
        edge_y = []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend((x0, x1, None))
            edge_y.extend((y0, y1, None))

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines')

        node_x = []
        node_y = []
        for node in graph.nodes():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)

        node_adjacencies = []
        node_text = []
        for node, neighbours in graph.adjacency():
            node_adjacencies.append(len(neighbours))
            node_text.append(f"{node} (# of connections: {len(neighbours)})")

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=node_text,
            marker=dict(
                showscale=False,
                colorscale='YlGnBu',
                reversescale=True,
                color=node_adjacencies,
                size=10,
                line_width=2))

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                # The title font is set via title.font; the older
                # `titlefont` layout property is deprecated.
                title=dict(text='Knowledge Graph', font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[dict(
                    text="Replace with your attribution text",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002)],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
        return fig
    except Exception as e:
        error_message = f"Error creating Plotly plot: {str(e)}"
        logging.exception(error_message)
        return None  # Or return a placeholder plot


# samples.py
@dataclass
class Sample:
    """Pre-filled values for the three input text boxes."""

    text_input: str    # Text to extract triples from.
    entity_types: str  # Comma-separated entity types.
    predicates: str    # Comma-separated predicates.


snippets = {
    "Sample 1": Sample(
        text_input="Alice knows Bob.",
        entity_types="Person",
        predicates="knows"
    ),
    "Sample 2": Sample(
        text_input="The cat sat on the mat.",
        entity_types="Animal, Object",
        predicates="sat on"
    ),
    "Sample 3": Sample(
        text_input=(
            "Global warming is causing significant changes to Earth's climate. The average global "
            "temperature has increased by approximately 1C since the pre-industrial era. This warming is "
            "primarily caused by human activities, particularly the emission of greenhouse gases like carbon dioxide. "
            "The Paris Agreement, signed in 2015, aims to limit global temperature increase to well below 2°C above "
            "pre-industrial levels. To achieve this goal, many countries are implementing policies to reduce carbon "
            "emissions and transition to renewable energy sources."
        ),
        entity_types="Event, Measure, Activity, Agreement",
        predicates="causes, increased by, aims to limit"
    ),
}


# --- Gradio Interface Code ---

WORD_LIMIT = 300


def process_text(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
    """
    Runs the full pipeline: validate input, extract triples, build the graph, plot it.

    Args:
        text: The input text (at most WORD_LIMIT whitespace-separated words).
        entity_types: Comma-separated entity types.
        predicates: Comma-separated predicates.
        layout_type: Layout name passed to the plotting function.
        visualization_type: 'Bokeh' selects the Bokeh backend; any other
            value selects Plotly.

    Returns:
        A tuple ``(graph, figure, message)``. On validation or processing
        errors, ``graph`` and ``figure`` are None and ``message`` explains why.
    """
    print(f"process_text input:\ntext: {text}\nentity_types: {entity_types}\npredicates: {predicates}")

    if not text or not text.strip():
        return None, None, "Please enter some text."

    words = text.split()
    if len(words) > WORD_LIMIT:
        return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"

    # `or ""` guards against a text box delivering None instead of "".
    entity_types_list = [et.strip() for et in (entity_types or "").split(",") if et.strip()]
    predicates_list = [p.strip() for p in (predicates or "").split(",") if p.strip()]

    if not entity_types_list:
        return None, None, "Please enter at least one entity type."
    if not predicates_list:
        return None, None, "Please enter at least one predicate."

    try:
        prediction = triplextract(text, entity_types_list, predicates_list)  # Pass lists, not strings
        if prediction and prediction.startswith("Error"):  # Check for errors
            return None, None, prediction

        entities, relationships = parse_triples(prediction)

        if not entities and not relationships:
            return None, None, "No entities or relationships found. Try different text or check your input."

        G = create_graph(entities, relationships)

        if visualization_type == 'Bokeh':
            fig = create_bokeh_plot(G, layout_type)
        else:
            fig = create_plotly_plot(G, layout_type)

        output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
        if fig is None:
            # The plot functions log the failure and return None; surface it
            # to the user instead of silently showing an empty plot.
            output_text += "\n\nWarning: the graph could not be rendered. See the logs for details."
        return G, fig, output_text
    except Exception as e:
        error_message = f"Error in process_text: {str(e)}"
        logging.exception(error_message)
        return None, None, f"An error occurred: {str(e)}"


def update_graph(G: nx.Graph, layout_type: str, visualization_type: str):
    """
    Re-renders an already extracted graph with a different layout or backend.

    Args:
        G: The graph produced by process_text, or None if nothing was processed yet.
        layout_type: Layout name passed to the plotting function.
        visualization_type: 'Bokeh' selects the Bokeh backend; any other
            value selects Plotly.

    Returns:
        A tuple ``(figure, message)``; ``message`` is empty on success.
    """
    if G is None:
        return None, "Please process text first."

    try:
        if visualization_type == 'Bokeh':
            fig = create_bokeh_plot(G, layout_type)
        else:
            fig = create_plotly_plot(G, layout_type)
        return fig, ""
    except Exception as e:
        error_message = f"Error in update_graph: {str(e)}"
        logging.exception(error_message)
        return None, f"An error occurred while updating the graph: {str(e)}"


def update_inputs(sample_name: str):
    """
    Returns the (text, entity types, predicates) values of the selected sample.

    If ``sample_name`` is not a known sample (e.g. the dropdown was cleared),
    the three inputs are left unchanged.
    """
    sample = snippets.get(sample_name)
    if sample is None:
        return gr.update(), gr.update(), gr.update()
    return sample.text_input, sample.entity_types, sample.predicates


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Knowledge Graph Extractor")

    # Provide a fallback in case snippets is empty
    sample_keys = list(snippets.keys())
    default_sample_name = random.choice(sample_keys) if sample_keys else ""
    default_sample = snippets.get(default_sample_name) if default_sample_name else None  # Safely get the sample

    with gr.Row():
        with gr.Column(scale=1):
            sample_dropdown = gr.Dropdown(choices=sample_keys, label="Select Sample", value=default_sample_name)
            input_text = gr.Textbox(label="Input Text", lines=5,
                                    value=default_sample.text_input if default_sample else "")
            entity_types = gr.Textbox(label="Entity Types",
                                      value=default_sample.entity_types if default_sample else "")
            predicates = gr.Textbox(label="Predicates",
                                    value=default_sample.predicates if default_sample else "")
            layout_type = gr.Dropdown(
                choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
                label="Layout Type", value='spring')
            visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
            process_btn = gr.Button("Process Text")
        with gr.Column(scale=2):
            output_graph = gr.Plot(label="Knowledge Graph")
            error_message = gr.Textbox(label="Textual Output")

    # Holds the most recently extracted graph so it can be re-rendered
    # without re-running extraction.
    graph_state = gr.State(None)

    def process_and_update(text: str, entity_types: str, predicates: str, layout_type: str, visualization_type: str):
        """Click handler: run the pipeline and return (graph state, figure, text output)."""
        G, fig, output = process_text(text, entity_types, predicates, layout_type, visualization_type)
        print(f"process_and_update: G = {G}")  # Debug
        return G, fig, output

    def update_graph_wrapper(G: nx.Graph, layout_type: str, visualization_type: str):
        """Change handler: re-render the stored graph; returns None if there is none yet."""
        print(f"update_graph_wrapper: G = {G}")  # Debug
        if G is None:
            return None
        fig, _ = update_graph(G, layout_type, visualization_type)
        return fig

    sample_dropdown.change(update_inputs, inputs=[sample_dropdown],
                           outputs=[input_text, entity_types, predicates])

    process_btn.click(process_and_update,
                      inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
                      outputs=[graph_state, output_graph, error_message])

    layout_type.change(update_graph_wrapper,
                       inputs=[graph_state, layout_type, visualization_type],
                       outputs=[output_graph])

    visualization_type.change(update_graph_wrapper,
                              inputs=[graph_state, layout_type, visualization_type],
                              outputs=[output_graph])


if __name__ == "__main__":
    demo.launch(share=True)