# Spaces: Running on Zero
import spaces | |
import gradio as gr | |
from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph | |
import rapidjson | |
from pyvis.network import Network | |
import networkx as nx | |
import spacy | |
from spacy import displacy | |
from spacy.tokens import Span | |
import random | |
import os | |
import pickle | |
# Constants
# User-facing header strings: create_ui() renders both in the page header and
# also passes TITLE to gr.Blocks as the page title.
# NOTE(review): the leading characters of both strings look like mis-decoded
# emoji; confirm against the original file before changing them.
TITLE = "π GraphMind: Phi-3 Instruct Graph Explorer"
SUBTITLE = "β¨ Extract and visualize knowledge graphs from any text in multiple languages"
# Enhanced Custom CSS for styling with improved visuals.
# Passed to gr.Blocks(css=...) in create_ui(). The .app-title, .app-subtitle and
# .language-badge rules style raw HTML emitted through gr.Markdown there; the
# .sidebar-container, .visualization-container, .graph-iframe and
# .results-container rules match elem_classes assigned in create_ui().
# NOTE(review): the .gr-* selectors presumably target Gradio's own class names;
# verify they still exist in the Gradio version this app runs on.
CUSTOM_CSS = """
.gradio-container {
font-family: 'Inter', 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(to bottom, #f9fafb, #f3f4f6);
}
.gr-button-primary {
background-color: #4f46e5 !important;
border: none !important;
color: white !important;
border-radius: 8px !important;
}
.gr-button-primary:hover {
background-color: #4338ca !important;
transform: translateY(-1px);
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
}
.gr-button-secondary {
border-color: #4f46e5 !important;
color: #4f46e5 !important;
border-radius: 8px !important;
}
.gr-button-secondary:hover {
background-color: #eef2ff !important;
transform: translateY(-1px);
}
.gr-box, .gr-input, .gr-textarea, .gr-dropdown {
border-radius: 8px !important;
border: 1px solid #e5e7eb !important;
box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05) !important;
}
.gr-padded {
padding: 16px !important;
}
.gr-form {
border: none !important;
background: transparent !important;
}
.gr-input:focus, .gr-textarea:focus, .gr-dropdown:focus {
border-color: #4f46e5 !important;
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.2) !important;
}
.gr-panel {
border-radius: 12px !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
background-color: white !important;
}
.gr-heading {
font-weight: 700 !important;
color: #111827 !important;
}
.gr-examples-table {
border-radius: 8px !important;
overflow: hidden !important;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06) !important;
}
.gr-prose p {
margin-bottom: 0.75rem !important;
color: #4b5563 !important;
}
.gr-prose h1, .gr-prose h2, .gr-prose h3 {
font-weight: 700 !important;
color: #111827 !important;
}
.gr-tab {
border-radius: 8px 8px 0 0 !important;
}
.gr-tab-selected {
border-color: #4f46e5 !important;
color: #4f46e5 !important;
font-weight: 600 !important;
}
.visualization-container {
min-height: 600px !important;
margin-top: 2rem !important;
margin-bottom: 2rem !important;
border-radius: 16px !important;
overflow: hidden !important;
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1) !important;
}
.sidebar-container {
background-color: white !important;
border-radius: 12px !important;
padding: 16px !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
}
.app-title {
background: linear-gradient(90deg, #4f46e5, #8b5cf6) !important;
-webkit-background-clip: text !important;
-webkit-text-fill-color: transparent !important;
font-weight: 800 !important;
font-size: 2.25rem !important;
margin-bottom: 0.5rem !important;
}
.app-subtitle {
color: #6b7280 !important;
font-size: 1.25rem !important;
margin-bottom: 2rem !important;
}
.graph-iframe iframe {
width: 100% !important;
height: 700px !important;
border-radius: 12px !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
}
.results-container {
background-color: #f8fafc !important;
padding: 20px !important;
border-radius: 12px !important;
margin-top: 1rem !important;
margin-bottom: 1rem !important;
border: 1px solid #e2e8f0 !important;
}
.language-badge {
display: inline-block !important;
background-color: #4f46e5 !important;
color: white !important;
padding: 4px 12px !important;
border-radius: 16px !important;
font-weight: 600 !important;
font-size: 0.875rem !important;
margin-right: 8px !important;
}
"""
# Cache directory and file paths.
# The cache file holds the pickled results (graph HTML, entity HTML, JSON data,
# stats line) for the first entry of EXAMPLES so they can be shown without
# re-running extraction.
CACHE_DIR = "cache"  # relative path, so it resolves against the working directory
EXAMPLE_CACHE_FILE = os.path.join(CACHE_DIR, "first_example_cache.pkl")
# Create cache directory if it doesn't exist (this runs when the module is loaded)
os.makedirs(CACHE_DIR, exist_ok=True)
# Color utilities
def get_random_light_color():
    """Return a random pastel colour as a ``#rrggbb`` hex string.

    Each channel is drawn independently from 140..255 (inclusive), in red,
    green, blue order, so the result is always light enough to sit behind
    dark text.
    """
    red, green, blue = (random.randint(140, 255) for _ in range(3))
    return "#" + "".join(format(channel, "02x") for channel in (red, green, blue))
# Text preprocessing
def handle_text(text):
    """Collapse every run of whitespace in ``text`` into a single space.

    Leading and trailing whitespace is dropped, so a multi-line literal
    becomes one clean line.
    """
    words = text.split()
    return " ".join(words)
# Main processing functions
def extract(text, model):
    """Run graph extraction on ``text`` and return the parsed JSON result.

    Args:
        text: Input text handed to the model's ``extract`` method.
        model: Model identifier passed to ``Phi3InstructGraph(model=...)``.

    Returns:
        The object produced by ``rapidjson.loads`` on the model output.

    Raises:
        gr.Error: If model construction, extraction or JSON parsing fails;
            the message carries the text of the underlying exception.
    """
    try:
        graph_model = Phi3InstructGraph(model=model)
        raw_json = graph_model.extract(text)
        parsed = rapidjson.loads(raw_json)
    except Exception as e:
        raise gr.Error(f"Extraction error: {str(e)}")
    return parsed
def find_token_indices(doc, substring, text):
    """Locate every token-aligned occurrence of ``substring`` in ``text``.

    An occurrence is reported only when it begins exactly at the first
    character of one token and ends exactly at the last character of a token;
    matches that start or stop in the middle of a token are ignored. The
    search resumes after the end of each occurrence, so reported occurrences
    never overlap.

    Args:
        doc: Iterable of tokens exposing ``idx`` (character offset in
            ``text``), ``i`` (token index) and ``len(token)`` (character
            length).
        substring: Text to look for. An empty string yields no matches.
        text: The string the tokens in ``doc`` were produced from.

    Returns:
        A list of ``{"start": int, "end": int}`` dicts holding token indices,
        with ``end`` exclusive, in order of appearance.
    """
    # An empty needle "matches" at the search cursor without ever advancing
    # it, which would loop forever; it can never form a span, so bail out.
    if not substring:
        return []

    start_index = text.find(substring)
    if start_index == -1:
        return []

    # Index token boundaries once instead of rescanning the whole doc for
    # every occurrence of the substring.
    token_starting_at = {}
    token_ending_at = {}
    for token in doc:
        token_starting_at[token.idx] = token.i
        token_ending_at[token.idx + len(token)] = token.i + 1

    result = []
    needle_length = len(substring)
    while start_index != -1:
        end_index = start_index + needle_length
        start_token = token_starting_at.get(start_index)
        end_token = token_ending_at.get(end_index)
        if start_token is not None and end_token is not None:
            result.append({
                "start": start_token,
                "end": end_token
            })
        # Search for next occurrence
        start_index = text.find(substring, end_index)
    return result
def create_custom_entity_viz(data, full_text):
    """Render the extracted entities as highlighted spans over ``full_text``.

    The text is tokenized with a blank multi-language spaCy pipeline. Each
    node's ``id`` is searched for in the text, and every token-aligned hit
    that does not overlap an already accepted span is labelled with the
    node's ``type``. Each distinct type gets one random light colour.

    Args:
        data: Dict with a ``"nodes"`` list; each node provides ``"id"`` and
            ``"type"``.
        full_text: The text the nodes were extracted from.

    Returns:
        An HTML string: the displaCy span rendering wrapped in a styled div.
    """
    doc = spacy.blank("xx")(full_text)
    token_count = len(doc)
    accepted = []
    label_colors = {}

    for node in data["nodes"]:
        for match in find_token_indices(doc, node["id"], full_text):
            first, stop = match["start"], match["end"]
            if first >= token_count or stop > token_count:
                continue
            # First come, first served: a later hit that overlaps an accepted
            # span is dropped.
            if any(prev.start < stop and first < prev.end for prev in accepted):
                continue
            accepted.append(Span(doc, first, stop, label=node["type"]))
            if node["type"] not in label_colors:
                label_colors[node["type"]] = get_random_light_color()

    doc.set_ents(accepted, default="unmodified")
    doc.spans["sc"] = accepted

    render_options = {
        "colors": label_colors,
        "ents": list(label_colors),
        "style": "ent",
        "manual": True
    }
    markup = displacy.render(doc, style="span", options=render_options)

    # Add custom styling to the entity visualization
    return f"""
<div style="padding: 20px; border-radius: 12px; background-color: white; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
{markup}
</div>
"""
def create_graph(json_data):
    """Build an interactive pyvis graph and return it embedded in an iframe.

    Args:
        json_data: Dict with a ``"nodes"`` list (each node has ``"id"``,
            ``"type"`` and ``"detailed_type"``) and an ``"edges"`` list (each
            edge has ``"from"``, ``"to"`` and ``"label"``).

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` attribute carries the
        complete generated page, so the graph's styles and scripts stay
        isolated from the host page.
    """
    # Imported locally under its own name because ``html`` is a natural name
    # for page markup and would otherwise shadow the module.
    from html import escape

    # Relations are directional ("from" -> "to") and are drawn with arrows, so
    # the graph must be a DiGraph. An undirected Graph forgets which endpoint
    # was the source and merges A->B with B->A.
    graph = nx.DiGraph()

    # Add nodes with tooltips
    for node in json_data['nodes']:
        graph.add_node(node['id'], title=f"{node['type']}: {node['detailed_type']}")

    # Add edges with labels
    for edge in json_data['edges']:
        graph.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    # Create network visualization
    nt = Network(
        width="100%",
        height="700px",
        directed=True,
        notebook=False,
        bgcolor="#f8fafc",
        font_color="#1e293b"
    )

    # Configure network display
    nt.from_nx(graph)
    nt.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    # Customize edge appearance
    for edge in nt.edges:
        edge['width'] = 2
        edge['arrows'] = {'to': {'enabled': True, 'type': 'arrow'}}
        edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
        edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}

    # Customize node appearance
    for node in nt.nodes:
        node['color'] = {'background': '#e0e7ff', 'border': '#6366f1', 'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'}}
        node['font'] = {'size': 14, 'color': '#1e293b'}
        node['shape'] = 'dot'
        node['size'] = 25

    # Generate HTML with iframe to isolate styles
    page = nt.generate_html()
    # Escape the page for use inside a double-quoted attribute; the browser
    # decodes the entities again when it reads srcdoc. Swapping every ' for "
    # instead would corrupt apostrophes in labels (e.g. "Tyler's") and break
    # the embedded script.
    srcdoc = escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 700px; margin: 0 auto; border-radius: 12px; box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1);"
name="result" allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
def process_and_visualize(text, model, progress=gr.Progress()):
    """Extract a knowledge graph from ``text`` and build all UI outputs.

    Args:
        text: Input text to analyse.
        model: Model identifier forwarded to ``extract``.
        progress: Gradio progress tracker used to report the pipeline stages.

    Returns:
        A 4-tuple ``(graph_html, entities_viz, json_data, stats)`` matching
        the outputs wired to the submit button.

    Raises:
        gr.Error: If ``text`` or ``model`` is empty, or if extraction fails.
    """
    if not text or not model:
        raise gr.Error("β οΈ Both text and model must be provided.")

    # The on-disk cache holds a single result with no record of the model that
    # produced it. It is therefore only valid for the first example run with
    # the default model; any other model must go through real extraction and
    # must not overwrite the cached entry.
    default_model = MODEL_LIST[0] if MODEL_LIST else None
    use_cache = text == EXAMPLES[0][0] and model == default_model

    # Try to load from cache if it's the first example
    if use_cache and os.path.exists(EXAMPLE_CACHE_FILE):
        try:
            progress(0.3, desc="Loading from cache...")
            # NOTE(security): pickle.load can execute arbitrary code. This is
            # acceptable only because the file is written by this application;
            # never point EXAMPLE_CACHE_FILE at externally supplied data.
            with open(EXAMPLE_CACHE_FILE, 'rb') as f:
                cache_data = pickle.load(f)
            progress(1.0, desc="Loaded from cache!")
            return cache_data["graph_html"], cache_data["entities_viz"], cache_data["json_data"], cache_data["stats"]
        except Exception as e:
            print(f"Cache loading error: {str(e)}")
            # Continue with normal processing if cache fails

    progress(0, desc="Starting extraction...")
    json_data = extract(text, model)

    progress(0.5, desc="Creating entity visualization...")
    entities_viz = create_custom_entity_viz(json_data, text)

    progress(0.8, desc="Building knowledge graph...")
    graph_html = create_graph(json_data)

    node_count = len(json_data["nodes"])
    edge_count = len(json_data["edges"])
    stats = f"π Extracted {node_count} entities and {edge_count} relationships"

    # Save to cache if it's the first example (default model only)
    if use_cache:
        try:
            cache_data = {
                "graph_html": graph_html,
                "entities_viz": entities_viz,
                "json_data": json_data,
                "stats": stats
            }
            with open(EXAMPLE_CACHE_FILE, 'wb') as f:
                pickle.dump(cache_data, f)
        except Exception as e:
            # Caching is best-effort; a failed write must not fail the request.
            print(f"Cache saving error: {str(e)}")

    progress(1.0, desc="Complete!")
    return graph_html, entities_viz, json_data, stats
# Example texts in different languages.
# Each entry is a one-element list, the row format given to gr.Examples in
# create_ui(). handle_text() collapses the line breaks of the literals into
# single spaces. EXAMPLES[0][0] is also the text whose results are cached on disk.
# NOTE(review): the last two entries look like mis-decoded non-English text
# (the UI advertises Korean); confirm against the original file.
EXAMPLES = [
    [handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
lead singer Steven Tyler's unrecoverable vocal cord injury.
The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
which he suffered in September 2023.""")],
    [handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
pleaded not guilty to the charges.""")],
    [handle_text("""μΈκ³μ μΈ κΈ°μ κΈ°μ μΌμ±μ μλ μλ‘μ΄ μΈκ³΅μ§λ₯ κΈ°λ° μ€λ§νΈν°μ μ¬ν΄ νλ°κΈ°μ μΆμν μμ μ΄λΌκ³ λ°ννλ€.
μ΄ μ€λ§νΈν°μ νμ¬ κ°λ° μ€μΈ κ°€λμ μ리μ¦μ μ΅μ μμΌλ‘, κ°λ ₯ν AI κΈ°λ₯κ³Ό νμ μ μΈ μΉ΄λ©λΌ μμ€ν μ νμ¬ν κ²μΌλ‘ μλ €μ‘λ€.
μΌμ±μ μμ CEOλ μ΄λ² μ μ νμ΄ μ€λ§νΈν° μμ₯μ μλ‘μ΄ νμ μ κ°μ Έμ¬ κ²μ΄λΌκ³ μ λ§νλ€.""")],
    [handle_text("""νκ΅ μν 'κΈ°μμΆ©'μ 2020λ μμΉ΄λ°λ―Έ μμμμμ μνμ, κ°λ μ, κ°λ³Έμ, κ΅μ μνμ λ± 4κ° λΆλ¬Έμ μμνλ©° μμ¬λ₯Ό μλ‘ μΌλ€.
λ΄μ€νΈ κ°λ μ΄ μ°μΆν μ΄ μνλ νκ΅ μν μ΅μ΄λ‘ μΉΈ μνμ ν©κΈμ’ λ €μλ μμνμΌλ©°, μ μΈκ³μ μΌλ‘ μμ²λ ν₯νκ³Ό
νλ¨μ νΈνμ λ°μλ€.""")]
]
# Function to preprocess the first example when the app starts
def generate_first_example_cache():
    """Load the cached results for the first example, building them if needed.

    If the cache file exists and can be read, its content is returned. If the
    file is missing, or exists but cannot be loaded, the first example is
    processed with the first model in ``MODEL_LIST`` and the result is written
    to the cache file.

    Returns:
        A dict with the keys ``"graph_html"``, ``"entities_viz"``,
        ``"json_data"`` and ``"stats"``, or ``None`` when no model is
        configured or processing fails. Failures are printed, not raised, so
        the app can still start without a cache.
    """
    if os.path.exists(EXAMPLE_CACHE_FILE):
        print("First example cache already exists")
        try:
            # NOTE(security): pickle.load can execute arbitrary code. This is
            # acceptable only because the file is written by this application.
            with open(EXAMPLE_CACHE_FILE, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            print(f"Error loading existing cache: {str(e)}")
            # Fall through and rebuild: an unreadable file would otherwise
            # disable the cache on every start until deleted by hand.

    print("Generating cache for first example...")
    try:
        text = EXAMPLES[0][0]
        model = MODEL_LIST[0] if MODEL_LIST else None
        if model:
            # Extract data
            json_data = extract(text, model)
            entities_viz = create_custom_entity_viz(json_data, text)
            graph_html = create_graph(json_data)
            node_count = len(json_data["nodes"])
            edge_count = len(json_data["edges"])
            stats = f"π Extracted {node_count} entities and {edge_count} relationships"

            # Save to cache
            cache_data = {
                "graph_html": graph_html,
                "entities_viz": entities_viz,
                "json_data": json_data,
                "stats": stats
            }
            with open(EXAMPLE_CACHE_FILE, 'wb') as f:
                pickle.dump(cache_data, f)
            print("First example cache generated successfully")
            return cache_data
    except Exception as e:
        print(f"Error generating first example cache: {str(e)}")
    return None
def create_ui():
    """Build the Gradio Blocks application and return it (not yet launched).

    Layout: a header, a two-column row (model/text inputs, buttons and stats
    on the left; examples and raw JSON on the right), a tabbed visualization
    row, and a footer. The submit button runs ``process_and_visualize``; the
    clear button resets the four outputs only, leaving the input text as is.
    When a cached result for the first example is available, it is pushed
    into the outputs on page load.
    """
    # Try to generate/load the first example cache
    first_example_cache = generate_first_example_cache()

    with gr.Blocks(css=CUSTOM_CSS, title=TITLE) as demo:
        # Header with enhanced styling
        with gr.Row(elem_classes=["header-container"]):
            with gr.Column():
                gr.Markdown(f"<h1 class='app-title'>{TITLE}</h1>")
                gr.Markdown(f"<p class='app-subtitle'>{SUBTITLE}</p>")
                with gr.Row():
                    gr.Markdown("<span class='language-badge'>English</span><span class='language-badge'>Korean</span><span class='language-badge'>+ More</span>")

        # Main content area - redesigned layout
        with gr.Row():
            # Left panel - Input controls
            with gr.Column(scale=1, elem_classes=["sidebar-container"]):
                input_model = gr.Dropdown(
                    MODEL_LIST,
                    label="π€ Select Model",
                    info="Choose a model to process your text",
                    value=MODEL_LIST[0] if MODEL_LIST else None,
                    elem_classes=["control-item"]
                )
                input_text = gr.TextArea(
                    label="π Input Text",
                    info="Enter text in any language to extract a knowledge graph",
                    placeholder="Enter text here...",
                    lines=8,
                    value=EXAMPLES[0][0],  # Pre-fill with first example
                    elem_classes=["control-item"]
                )
                with gr.Row():
                    submit_button = gr.Button("π Extract & Visualize", variant="primary", scale=2)
                    clear_button = gr.Button("π Clear", variant="secondary", scale=1)
                # Statistics will appear here
                stats_output = gr.Markdown("", label="π Analysis Results", elem_classes=["results-container"])

            # Right panel - Examples moved to right side
            with gr.Column(scale=1, elem_classes=["sidebar-container"]):
                gr.Markdown("<h3>π Example Texts</h3>")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=input_text,
                    label="",
                    elem_classes=["examples-panel"]
                )
                # JSON output moved to right side as well
                with gr.Accordion("π JSON Data", open=False):
                    output_json = gr.JSON(label="")

        # Full width visualization area at the bottom
        with gr.Row(elem_classes=["visualization-container"]):
            with gr.Column():
                # Tab container for visualizations
                with gr.Tabs():
                    with gr.Tab("𧩠Knowledge Graph"):
                        output_graph = gr.HTML(label="", elem_classes=["graph-iframe"])
                    with gr.Tab("π·οΈ Entity Recognition"):
                        output_entity_viz = gr.HTML(label="")

        # Functionality
        # The outputs list order must match the tuple returned by
        # process_and_visualize: graph HTML, entity HTML, JSON, stats.
        submit_button.click(
            fn=process_and_visualize,
            inputs=[input_text, input_model],
            outputs=[output_graph, output_entity_viz, output_json, stats_output]
        )
        # Resets the four outputs; the input text and model selection are kept.
        clear_button.click(
            fn=lambda: [None, None, None, ""],
            inputs=[],
            outputs=[output_graph, output_entity_viz, output_json, stats_output]
        )

        # Set initial values from cache if available
        if first_example_cache:
            # Use this to set initial values when the app loads
            demo.load(
                lambda: [
                    first_example_cache["graph_html"],
                    first_example_cache["entities_viz"],
                    first_example_cache["json_data"],
                    first_example_cache["stats"]
                ],
                inputs=None,
                outputs=[output_graph, output_entity_viz, output_json, stats_output]
            )

        # Footer
        with gr.Row(elem_classes=["footer-container"]):
            gr.Markdown("---")
            gr.Markdown("π **Instructions:** Enter text in any language, select a model, and click 'Extract & Visualize' to generate a knowledge graph.")
            gr.Markdown("π οΈ Powered by Phi-3 Instruct Graph | Emergent Methods")

    return demo
# Build the UI when the module is loaded (this also loads or generates the
# first-example cache) and start the Gradio server without a public share link.
demo = create_ui()
demo.launch(share=False)