# Spaces: Sleeping
import asyncio
import html
import json
import re
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import gradio as gr
from transformers import pipeline
# Initialize NLP pipelines. Both are created at import time, so the model
# weights are loaded (and downloaded if not cached) when this module loads.

# Token-level NER: no aggregation strategy is passed, so results come back per
# (sub-word) token with IOB-style tags such as "I-PER".
ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")

# Pin the checkpoint instead of relying on the library's task default, which
# transformers warns about and which may change between releases.
# NOTE(review): `classifier` is not referenced anywhere else in this file as
# shown; if nothing imports it, dropping it would save startup time and memory.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
class OntologyRegistry:
    """Regex rules for temporal and location mentions, plus NER label names."""

    def __init__(self):
        # Clock times ("14:30", "9:00 am"), "Month day, year" dates, and a
        # couple of relative expressions ("tomorrow", "in 3 weeks").
        self.temporal_patterns = [
            r'\b\d{1,2}:\d{2}\s*(?:AM|PM|am|pm)?\b',
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]* \d{1,2}(?:st|nd|rd|th)?,? \d{4}\b',
            r'\btomorrow\b',
            r'\bin \d+ (?:days?|weeks?|months?)\b',
        ]
        # The first rule captures a capitalised name (optionally followed by a
        # two-letter code such as ", DC") after in/at/from/to; group 1 is the
        # place itself. It is purely lexical, so "from Tech Corp." yields "Tech".
        self.location_patterns = [
            r'\b(?:in|at|from|to) ([A-Z][a-zA-Z]+(,? [A-Z]{2})?)\b',
            r'\b[A-Z][a-zA-Z]+ Base\b',
            r'\bHeadquarters\b',
            r'\bHQ\b',
        ]
        # CoNLL-style NER label suffixes mapped to readable type names.
        self.entity_types = {
            'PER': 'person',
            'ORG': 'organization',
            'LOC': 'location',
            'MISC': 'miscellaneous',
        }

    def validate_pattern(self, text, pattern_type):
        """Return the substrings of *text* matched by one family of patterns.

        Args:
            text: Text to scan.
            pattern_type: Prefix of a ``<pattern_type>_patterns`` attribute,
                e.g. ``"temporal"`` or ``"location"``. Unknown prefixes yield
                an empty list.

        Returns:
            list[str]: Matches in pattern order, then in text order for each
            pattern. Duplicates are kept. For a pattern that defines capture
            groups, group 1 is returned rather than the whole match, so the
            leading preposition in "in Washington, DC" is not reported as part
            of the location.
        """
        patterns = getattr(self, f"{pattern_type}_patterns", [])
        found = []
        for pattern in patterns:
            compiled = re.compile(pattern)
            group = 1 if compiled.groups else 0
            found.extend(m.group(group) for m in compiled.finditer(text))
        return found
class RelationshipEngine:
    """SQLite store for analyzed events, their entities, and co-occurrence links.

    Tables:
        * ``events``: stored event text, timestamp and confidence.
        * ``entities``: ``store_entities`` keeps one row per (text, type) pair,
          with a running ``frequency`` count.
        * ``event_entities``: links between events and entities.
        * ``entity_relationships``: pairs of differently-typed entities that
          appeared in the same event.
    """

    def __init__(self, db_path=':memory:'):
        # check_same_thread=False lets threads other than the creating one use
        # this connection. Per the sqlite3 docs, writes must then be serialized
        # by the caller.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.setup_database()

    def setup_database(self):
        """Create the four tables if they do not exist yet, then commit."""
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS events (
                id INTEGER PRIMARY KEY,
                text TEXT,
                timestamp DATETIME,
                confidence REAL
            )
        ''')
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS entities (
                id INTEGER PRIMARY KEY,
                entity_text TEXT,
                entity_type TEXT, -- person, organization, location, hashtag, temporal
                first_seen DATETIME,
                last_seen DATETIME,
                frequency INTEGER DEFAULT 1,
                confidence REAL
            )
        ''')
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS event_entities (
                event_id INTEGER,
                entity_id INTEGER,
                FOREIGN KEY (event_id) REFERENCES events(id),
                FOREIGN KEY (entity_id) REFERENCES entities(id),
                PRIMARY KEY (event_id, entity_id)
            )
        ''')
        # Entity relationships (e.g. person-organization affiliations).
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS entity_relationships (
                id INTEGER PRIMARY KEY,
                source_entity_id INTEGER,
                target_entity_id INTEGER,
                relationship_type TEXT,
                confidence REAL,
                first_seen DATETIME,
                last_seen DATETIME,
                FOREIGN KEY (source_entity_id) REFERENCES entities(id),
                FOREIGN KEY (target_entity_id) REFERENCES entities(id)
            )
        ''')
        self.conn.commit()

    def store_entities(self, event_id, entities_dict):
        """Upsert entities and link them to *event_id*, then commit.

        Args:
            event_id: Row id of an already-inserted ``events`` row.
            entities_dict: Maps an entity type (stored verbatim in
                ``entity_type``) to a list of entity strings. Values that are
                not lists are skipped.

        Existing (text, type) rows have ``frequency`` incremented and
        ``last_seen`` refreshed. New rows start with the column default
        frequency and confidence 1.0.
        """
        now = datetime.now().isoformat()
        for entity_type, entities in entities_dict.items():
            if not isinstance(entities, list):
                continue
            for entity_text in entities:
                row = self.conn.execute(
                    'SELECT id FROM entities WHERE entity_text = ? AND entity_type = ?',
                    (entity_text, entity_type),
                ).fetchone()
                if row:
                    entity_id = row[0]
                    # Increment inside SQL instead of writing back a value read earlier.
                    self.conn.execute(
                        'UPDATE entities SET frequency = frequency + 1, last_seen = ? WHERE id = ?',
                        (now, entity_id),
                    )
                else:
                    entity_id = self.conn.execute(
                        '''
                        INSERT INTO entities (entity_text, entity_type, first_seen, last_seen, confidence)
                        VALUES (?, ?, ?, ?, ?)
                        ''',
                        (entity_text, entity_type, now, now, 1.0),
                    ).lastrowid
                # The composite primary key makes repeated links a no-op.
                self.conn.execute(
                    'INSERT OR IGNORE INTO event_entities (event_id, entity_id) VALUES (?, ?)',
                    (event_id, entity_id),
                )
        self.conn.commit()

    def find_related_events(self, event_data):
        """Return up to 5 stored events sharing entity text with *event_data*.

        Args:
            event_data: Dict whose ``'entities'`` value maps names to lists of
                entity strings. Non-list values are ignored.

        Returns:
            list[tuple]: ``(id, text, timestamp, confidence, shared_entities)``
            rows, ordered by most shared entities and then by newest timestamp.
            Matching is on ``entity_text`` only, regardless of entity type. An
            empty list is returned when there is nothing to match.
        """
        entity_texts = []
        for entities in event_data.get('entities', {}).values():
            if isinstance(entities, list):
                entity_texts.extend(entities)
        if not entity_texts:
            return []
        placeholders = ','.join('?' * len(entity_texts))
        query = f'''
            SELECT DISTINCT e.*, COUNT(ee.entity_id) as shared_entities
            FROM events e
            JOIN event_entities ee ON e.id = ee.event_id
            JOIN entities ent ON ee.entity_id = ent.id
            WHERE ent.entity_text IN ({placeholders})
            GROUP BY e.id
            ORDER BY shared_entities DESC, e.timestamp DESC
            LIMIT 5
        '''
        return self.conn.execute(query, entity_texts).fetchall()

    def find_entity_relationships(self, entity_id):
        """Return relationship rows where *entity_id* is the source or target.

        Each row is ``entity_relationships.*`` followed by source text, source
        type, target text and target type.
        """
        query = '''
            SELECT er.*,
                   e1.entity_text as source_text, e1.entity_type as source_type,
                   e2.entity_text as target_text, e2.entity_type as target_type
            FROM entity_relationships er
            JOIN entities e1 ON er.source_entity_id = e1.id
            JOIN entities e2 ON er.target_entity_id = e2.id
            WHERE er.source_entity_id = ? OR er.target_entity_id = ?
        '''
        return self.conn.execute(query, (entity_id, entity_id)).fetchall()

    def update_entity_relationships(self, event_id):
        """Link every pair of differently-typed entities attached to *event_id*.

        A new link starts at confidence 0.5. Each later co-occurrence, in either
        direction, refreshes ``last_seen`` and adds 0.1 to the confidence. The
        confidence is capped at 1.0 because it is shown as a percentage.
        Commits at the end.
        """
        entities = self.conn.execute('''
            SELECT e.id, e.entity_text, e.entity_type
            FROM entities e
            JOIN event_entities ee ON e.id = ee.entity_id
            WHERE ee.event_id = ?
        ''', (event_id,)).fetchall()
        now = datetime.now().isoformat()

        for i, (first_id, _, first_type) in enumerate(entities):
            for second_id, _, second_type in entities[i + 1:]:
                # Only cross-type pairs (e.g. person/location) are recorded.
                if first_type == second_type:
                    continue
                existing = self.conn.execute('''
                    SELECT id FROM entity_relationships
                    WHERE (source_entity_id = ? AND target_entity_id = ?)
                       OR (source_entity_id = ? AND target_entity_id = ?)
                ''', (first_id, second_id, second_id, first_id)).fetchone()
                if existing:
                    self.conn.execute('''
                        UPDATE entity_relationships
                        SET last_seen = ?, confidence = MIN(1.0, confidence + 0.1)
                        WHERE id = ?
                    ''', (now, existing[0]))
                else:
                    self.conn.execute('''
                        INSERT INTO entity_relationships
                        (source_entity_id, target_entity_id, relationship_type, confidence, first_seen, last_seen)
                        VALUES (?, ?, ?, ?, ?, ?)
                    ''', (first_id, second_id, f"{first_type}_to_{second_type}", 0.5, now, now))
        self.conn.commit()
class EventAnalyzer:
    """Extract entities from free text, score them, and persist confident events.

    Combines the module-level ``ner_pipeline`` with the regex rules from
    ``OntologyRegistry``, and stores results through ``RelationshipEngine``.
    """

    # Events scoring at or above this value are stored. Lower scores are
    # reported with ``verification_needed``.
    CONFIDENCE_THRESHOLD = 0.6

    # Maps the label part of an NER tag (the "PER" in "I-PER") to a result key.
    _NER_BUCKETS = {"PER": "people", "ORG": "organizations", "LOC": "locations"}

    def __init__(self):
        self.ontology = OntologyRegistry()
        self.relationship_engine = RelationshipEngine()
        # NER inference is blocking, so it runs on this pool rather than on the
        # event loop.
        self.executor = ThreadPoolExecutor(max_workers=3)

    @staticmethod
    def _merge_ner_tokens(ner_results, text):
        """Collapse token-level NER output into ``(label, surface_text)`` spans.

        ``ner_pipeline`` is called without an aggregation strategy, so a
        multi-word name can arrive as several tokens and rare words as
        word pieces prefixed with "##". Tokens with consecutive ``index``
        values and the same label are merged, unless the later token is
        tagged ``B-``. The span text is sliced from *text* using the
        ``start``/``end`` offsets when available; otherwise it is rebuilt from
        the token strings. Already-aggregated items (``entity_group`` key, no
        ``index``) each become their own span.
        """
        spans = []
        for item in ner_results:
            tag = item.get("entity_group") or item.get("entity") or ""
            boundary, _, label = tag.rpartition("-")  # "I-PER" -> ("I", "-", "PER")
            word = item.get("word", "")
            is_piece = word.startswith("##")
            index = item.get("index")
            last = spans[-1] if spans else None
            continues = (
                last is not None
                and last["label"] == label
                and boundary != "B"
                and (
                    index == last["index"] + 1
                    if index is not None and last["index"] is not None
                    else is_piece
                )
            )
            if continues:
                last["index"] = index
                last["end"] = item.get("end")
                if is_piece:
                    last["words"][-1] += word[2:]
                else:
                    last["words"].append(word)
            else:
                spans.append({
                    "label": label,
                    "index": index,
                    "start": item.get("start"),
                    "end": item.get("end"),
                    "words": [word[2:] if is_piece else word],
                })

        merged = []
        for span in spans:
            if span["start"] is not None and span["end"] is not None:
                surface = text[span["start"]:span["end"]]
            else:
                surface = " ".join(span["words"])
            merged.append((span["label"], surface))
        return merged

    async def extract_entities(self, text):
        """Run NER on *text* (in the thread pool) and bucket the results.

        Returns:
            dict: ``people``, ``organizations`` and ``locations`` hold merged
            NER spans. ``hashtags`` holds whitespace-separated words starting
            with ``#``. Other NER labels (e.g. MISC) are dropped.
        """
        loop = asyncio.get_running_loop()
        ner_results = await loop.run_in_executor(self.executor, ner_pipeline, text)

        entities = {
            "people": [],
            "organizations": [],
            "locations": [],
            "hashtags": [word for word in text.split() if word.startswith('#')],
        }
        for label, surface in self._merge_ner_tokens(ner_results, text):
            bucket = self._NER_BUCKETS.get(label)
            if bucket:
                entities[bucket].append(surface)
        return entities

    def extract_temporal(self, text):
        """Return regex-matched temporal expressions found in *text*."""
        return self.ontology.validate_pattern(text, 'temporal')

    def _combine_locations(self, ml_locations, text):
        """Union NER locations with regex locations, de-duplicated in first-seen order."""
        pattern_locations = self.ontology.validate_pattern(text, 'location')
        return list(dict.fromkeys([*ml_locations, *pattern_locations]))

    async def extract_locations(self, text):
        """Return NER plus regex locations for *text* (runs NER itself)."""
        entities = await self.extract_entities(text)
        return self._combine_locations(entities.get('locations', []), text)

    def calculate_confidence(self, entities, temporal_data, related_events):
        """Score an extraction. The three components are summed and capped at 1.0.

        * base: 0.2 if any people, 0.2 if any organizations, 0.3 if any
          locations, and 0.3 if any temporal references;
        * frequency boost: 0.05 per unit of average stored ``frequency`` above
          1 for the person/organization/location strings, at most 0.2;
        * relationship boost: 0.1 per entity shared with the best-matching
          related event, at most 0.3.

        Args:
            entities: Dict with lists under ``people``, ``organizations`` and
                ``locations``.
            temporal_data: List of temporal matches.
            related_events: Rows from ``RelationshipEngine.find_related_events``:
                ``(id, text, timestamp, confidence, shared_entities)``.
        """
        base_confidence = min(1.0, (
            0.2 * bool(entities["people"]) +
            0.2 * bool(entities["organizations"]) +
            0.3 * bool(entities["locations"]) +
            0.3 * bool(temporal_data)
        ))

        entity_params = [
            *entities["people"],
            *entities["organizations"],
            *entities["locations"],
        ]
        frequency_boost = 0.0
        if entity_params:
            placeholders = ','.join('?' * len(entity_params))
            row = self.relationship_engine.conn.execute(
                f'SELECT AVG(frequency) FROM entities WHERE entity_text IN ({placeholders})',
                entity_params,
            ).fetchone()
            avg_frequency = row[0] or 1  # None when none of the strings is stored yet
            frequency_boost = min(0.2, (avg_frequency - 1) * 0.05)

        # The shared-entity count with the *current* event is already column 4
        # of each related row. The current event is not stored yet, so it
        # cannot be recomputed through event_entities here.
        shared_counts = [event[4] for event in related_events if len(event) > 4 and event[4]]
        relationship_confidence = 0
        if shared_counts:
            relationship_confidence = max(min(0.3, count * 0.1) for count in shared_counts)

        return min(1.0, base_confidence + frequency_boost + relationship_confidence)

    async def analyze_event(self, text):
        """Analyze *text* end to end and return a result dict for the UI.

        Runs NER once, adds regex temporal and location matches, finds related
        stored events, and scores the result. When the score reaches
        ``CONFIDENCE_THRESHOLD``, it also stores the event, its entities and
        their co-occurrence links.

        Returns:
            dict: ``text``, ``entities``, ``confidence``,
            ``verification_needed``, ``related_events`` and
            ``entity_relationships``; or ``{"error": message}`` if anything
            raised.
        """
        try:
            entities = await self.extract_entities(text)
            temporal_data = self.extract_temporal(text)
            entities['locations'] = self._combine_locations(entities['locations'], text)
            entities['temporal'] = temporal_data

            engine = self.relationship_engine
            related_events = engine.find_related_events({'text': text, 'entities': entities})
            confidence = self.calculate_confidence(entities, temporal_data, related_events)

            event_id = None
            if confidence >= self.CONFIDENCE_THRESHOLD:
                cursor = engine.conn.execute(
                    'INSERT INTO events (text, timestamp, confidence) VALUES (?, ?, ?)',
                    (text, datetime.now().isoformat(), confidence),
                )
                event_id = cursor.lastrowid
                engine.store_entities(event_id, {
                    'person': entities['people'],
                    'organization': entities['organizations'],
                    'location': entities['locations'],
                    'temporal': temporal_data,
                    'hashtag': entities['hashtags'],
                })
                engine.update_entity_relationships(event_id)
                engine.conn.commit()

            # Relationships touching any entity of the stored event. Columns are
            # named explicitly so the mapping below cannot drift.
            entity_relationships = []
            if event_id:
                entity_relationships = engine.conn.execute('''
                    SELECT DISTINCT er.id, er.relationship_type, er.confidence,
                           e1.entity_text, e2.entity_text
                    FROM event_entities ee
                    JOIN entity_relationships er
                      ON ee.entity_id IN (er.source_entity_id, er.target_entity_id)
                    JOIN entities e1 ON er.source_entity_id = e1.id
                    JOIN entities e2 ON er.target_entity_id = e2.id
                    WHERE ee.event_id = ?
                ''', (event_id,)).fetchall()

            return {
                "text": text,
                "entities": entities,
                "confidence": confidence,
                "verification_needed": confidence < self.CONFIDENCE_THRESHOLD,
                "related_events": [
                    {
                        "text": event[1],
                        "timestamp": event[2],
                        "confidence": event[3],
                        "shared_entities": event[4] if len(event) > 4 else None,
                    }
                    for event in related_events
                ],
                "entity_relationships": [
                    {
                        "type": rel_type,
                        "source": source_text,
                        "target": target_text,
                        "confidence": rel_confidence,
                    }
                    for _rel_id, rel_type, rel_confidence, source_text, target_text in entity_relationships
                ],
            }
        except Exception as e:
            # UI boundary: report any failure as an error payload that
            # format_results can render, instead of failing the request.
            return {"error": str(e)}

    def get_entity_statistics(self):
        """Return aggregate statistics about stored entities and relationships.

        Returns:
            dict: ``entity_counts`` holds (type, count, avg frequency) rows.
            ``frequent_entities`` holds the top 10 (text, type, frequency)
            rows. ``relationship_stats`` holds (type, count, avg confidence)
            rows.
        """
        conn = self.relationship_engine.conn
        stats = {}
        stats['entity_counts'] = conn.execute('''
            SELECT entity_type, COUNT(*) as count, AVG(frequency) as avg_frequency
            FROM entities
            GROUP BY entity_type
        ''').fetchall()
        stats['frequent_entities'] = conn.execute('''
            SELECT entity_text, entity_type, frequency
            FROM entities
            ORDER BY frequency DESC
            LIMIT 10
        ''').fetchall()
        stats['relationship_stats'] = conn.execute('''
            SELECT relationship_type, COUNT(*) as count, AVG(confidence) as avg_confidence
            FROM entity_relationships
            GROUP BY relationship_type
        ''').fetchall()
        return stats
# One analyzer instance shared by every request this process handles.
analyzer = EventAnalyzer()

# Stylesheet handed to the Gradio interface for the HTML report.
css = (
    "\n"
    ".container { max-width: 1200px; margin: auto; padding: 20px; }\n"
    ".results { padding: 20px; border: 1px solid #ddd; border-radius: 8px; margin-top: 20px; }\n"
    ".confidence-high { color: #22c55e; font-weight: bold; }\n"
    ".confidence-low { color: #f97316; font-weight: bold; }\n"
    ".entity-section { margin: 15px 0; }\n"
    ".alert-warning { background: #fff3cd; padding: 10px; border-radius: 5px; margin: 10px 0; }\n"
    ".alert-success { background: #d1fae5; padding: 10px; border-radius: 5px; margin: 10px 0; }\n"
    ".related-events { background: #f3f4f6; padding: 15px; border-radius: 5px; margin-top: 15px; }\n"
)
def format_results(analysis_result):
    """Render an ``analyze_event`` result dict as an HTML fragment for ``gr.HTML``.

    Every value that can come from user input or from previously stored events
    is passed through ``html.escape`` before interpolation. That covers entity
    strings, relationship endpoints, related-event text and timestamps, and the
    error message, so submitted text cannot inject markup or scripts into the
    page.

    Args:
        analysis_result: Either ``{"error": message}`` or a dict with
            ``confidence`` (float), ``verification_needed`` (bool), and
            ``entities`` (lists under ``people``, ``organizations``,
            ``locations``, ``temporal`` and ``hashtags``). It may also contain
            ``entity_relationships`` and ``related_events`` lists of dicts.

    Returns:
        str: The rendered HTML markup.
    """
    escape = html.escape

    if "error" in analysis_result:
        return f"<div style='color: red'>Error: {escape(str(analysis_result['error']))}</div>"

    entities = analysis_result["entities"]
    confidence = analysis_result["confidence"]
    verification_needed = analysis_result["verification_needed"]
    relationships = analysis_result.get("entity_relationships") or []
    related_events = analysis_result.get("related_events") or []
    confidence_class = "confidence-high" if confidence >= 0.6 else "confidence-low"

    def list_items(values):
        # One escaped <li> per value, or a placeholder when nothing was found.
        return "".join(f"<li>{escape(str(value))}</li>" for value in values) or "<li>None detected</li>"

    def entity_section(title, key):
        # Titled bullet list for one entity category.
        return (
            '<div class="entity-section">'
            f"<h4>{title}</h4>"
            f"<ul>{list_items(entities[key])}</ul>"
            "</div>"
        )

    warning_html = ""
    validated_html = ""
    if verification_needed:
        warning_html = (
            '<div class="alert-warning">'
            "⚠ <strong>Verification Required:</strong> Low confidence score detected. "
            "Please verify the extracted information."
            "</div>"
        )
    else:
        validated_html = (
            '<div class="alert-success mt-4">'
            "✅ <strong>Event Validated:</strong> The extracted information meets confidence thresholds."
            "</div>"
        )

    relationships_html = ""
    if relationships:
        rel_items = []
        for rel in relationships:
            source = escape(str(rel["source"]))
            target = escape(str(rel["target"]))
            rel_label = escape(rel["type"].replace("_to_", " to "))
            rel_pct = int(rel["confidence"] * 100)
            rel_items.append(
                '<li class="mb-2">'
                f"<strong>{source}</strong> → "
                f'<span class="text-blue-600">{rel_label}</span> → '
                f"<strong>{target}</strong><br/>"
                f'<small class="text-gray-600">Confidence: {rel_pct}%</small>'
                "</li>"
            )
        rel_list = "".join(rel_items)
        relationships_html = (
            '<div class="entity-section"><h4>Entity Relationships</h4>'
            f"<ul>{rel_list}</ul></div>"
        )

    related_html = ""
    if related_events:
        event_items = []
        for event in related_events:
            shared = event.get("shared_entities")
            shared_note = f" | Shared Entities: {escape(str(shared))}" if shared else ""
            event_text = escape(str(event["text"]))
            event_time = escape(str(event["timestamp"]))
            event_pct = int(event["confidence"] * 100)
            event_items.append(
                '<li class="mb-2"><div class="flex justify-between items-center">'
                f"<div>{event_text}</div>"
                '<div class="text-sm text-gray-600">'
                f"{event_time} | Confidence: {event_pct}%{shared_note}"
                "</div></div></li>"
            )
        event_list = "".join(event_items)
        related_html = (
            '<div class="related-events"><h4>Related Events</h4>'
            f"<ul>{event_list}</ul></div>"
        )

    # NOTE(review): this breakdown is a fixed 70/30 presentation of the final
    # score. It is not the base/frequency/relationship split that
    # calculate_confidence actually computes.
    base_pct = int(confidence * 70)
    boost_pct = int((confidence - 0.7 if confidence > 0.7 else 0) * 100)
    category_keys = ('people', 'organizations', 'locations', 'temporal', 'hashtags')
    types_detected = len([key for key in category_keys if entities.get(key)])
    total_entities = sum(len(values) for values in entities.values() if isinstance(values, list))

    return f"""
<div class="results">
  <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 20px;">
    <h3 style="margin: 0;">Analysis Results</h3>
    <div>
      Confidence Score: <span class="{confidence_class}">{int(confidence * 100)}%</span>
    </div>
  </div>
  {warning_html}
  <div class="grid grid-cols-2 gap-4">
    <div class="space-y-4">
      {entity_section('People Detected', 'people')}
      {entity_section('Organizations', 'organizations')}
      {entity_section('Locations', 'locations')}
    </div>
    <div class="space-y-4">
      {entity_section('Temporal References', 'temporal')}
      {entity_section('Hashtags', 'hashtags')}
      {relationships_html}
    </div>
  </div>
  {validated_html}
  {related_html}
  <div class="entity-stats mt-4 p-4 bg-gray-50 rounded-lg">
    <h4 class="mb-2">Analysis Metrics</h4>
    <div class="grid grid-cols-3 gap-4 text-sm">
      <div>
        <strong>Confidence Breakdown:</strong>
        <ul class="mt-1">
          <li>Base Confidence: {base_pct}%</li>
          <li>Entity Boost: {boost_pct}%</li>
        </ul>
      </div>
      <div>
        <strong>Entity Coverage:</strong>
        <ul class="mt-1">
          <li>Types Detected: {types_detected}</li>
          <li>Total Entities: {total_entities}</li>
        </ul>
      </div>
      <div>
        <strong>Relationships:</strong>
        <ul class="mt-1">
          <li>Direct: {len(relationships)}</li>
          <li>Related Events: {len(related_events)}</li>
        </ul>
      </div>
    </div>
  </div>
</div>
"""
async def process_input(text):
    """Gradio callback: analyze *text* and return the rendered HTML report.

    Declared ``async`` because ``analyze_event`` is a coroutine. Gradio
    accepts coroutine functions as event handlers.
    """
    analysis = await analyzer.analyze_event(text)
    rendered = format_results(analysis)
    return rendered
# Gradio UI: one multi-line text box in, rendered HTML report out.
_event_textbox = gr.Textbox(
    label="Event Text",
    placeholder="Enter text to analyze (e.g., 'John from Tech Corp. is attending the meeting in Washington, DC tomorrow at 14:30 #tech')",
    lines=3,
)

_example_texts = [
    "John from Tech Corp. is attending the meeting in Washington, DC tomorrow at 14:30 #tech",
    "Sarah Johnson and Mike Smith from Defense Systems Inc. are conducting training in Norfolk, VA on June 15th #defense #training",
    "Team meeting at headquarters with @commander_smith at 0900 #briefing",
]

demo = gr.Interface(
    fn=process_input,
    inputs=[_event_textbox],
    outputs=gr.HTML(),
    title="ToY Event Analysis System",
    description="Analyze text to extract entities, assess confidence, and identify key event information with relationship tracking.",
    css=css,
    theme=gr.themes.Soft(),
    # Each example row supplies a value for the single input component.
    examples=[[example] for example in _example_texts],
)

if __name__ == "__main__":
    demo.launch()