import gradio as gr
from transformers import pipeline
import json
from datetime import datetime
import sqlite3
import asyncio
from concurrent.futures import ThreadPoolExecutor
import re

# Initialize NLP pipelines; aggregation merges word-piece tokens into whole entity spans
ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english",
                        aggregation_strategy="simple")
classifier = pipeline("zero-shot-classification")  # loaded here but not used in the current analysis path

class OntologyRegistry:
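    """Regex patterns and NER label mappings used for temporal, location, and entity validation."""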
    def __init__(self):
        self.temporal_patterns = [
            r'\b\d{1,2}:\d{2}\s*(?:AM|PM|am|pm)?\b',
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]* \d{1,2}(?:st|nd|rd|th)?,? \d{4}\b',
            r'\btomorrow\b',
            r'\bin \d+ (?:days?|weeks?|months?)\b'
        ]
        
        self.location_patterns = [
            r'\b(?:in|at|from|to) ([A-Z][a-zA-Z]+(,? [A-Z]{2})?)\b',
            r'\b[A-Z][a-zA-Z]+ Base\b',
            r'\bHeadquarters\b',
            r'\bHQ\b'
        ]
        
        self.entity_types = {
            'PER': 'person',
            'ORG': 'organization',
            'LOC': 'location',
            'MISC': 'miscellaneous'
        }

    def validate_pattern(self, text, pattern_type):
        """Return all matches for the named pattern family ('temporal' or 'location')."""
        patterns = getattr(self, f"{pattern_type}_patterns", [])
        matches = []
        for pattern in patterns:
            for match in re.finditer(pattern, text):
                # Prefer the first capture group so "in Washington, DC" yields "Washington, DC"
                matches.append(match.group(1) if match.groups() else match.group())
        return matches

class RelationshipEngine:
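    """SQLite-backed store for analyzed events and the confidence-weighted links between them."""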
    def __init__(self, db_path=':memory:'):
        # check_same_thread=False allows reuse of this connection from Gradio's worker thread
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.setup_database()

    def setup_database(self):
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS events (
                id INTEGER PRIMARY KEY,
                text TEXT,
                timestamp DATETIME,
                confidence REAL
            )
        ''')
        
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS relationships (
                id INTEGER PRIMARY KEY,
                source_event_id INTEGER,
                target_event_id INTEGER,
                relationship_type TEXT,
                confidence REAL,
                FOREIGN KEY (source_event_id) REFERENCES events(id),
                FOREIGN KEY (target_event_id) REFERENCES events(id)
            )
        ''')
        self.conn.commit()

    def find_related_events(self, event_data):
        # Match stored events that mention any of the new event's entities;
        # fall back to matching on the raw text when nothing was extracted
        entities = event_data.get('entities', {})
        terms = [
            term
            for key in ('people', 'organizations', 'locations')
            for term in entities.get(key, [])
        ]
        if not terms:
            terms = [event_data.get('text', '')]
        clauses = ' OR '.join('text LIKE ?' for _ in terms)
        cursor = self.conn.execute(
            f'SELECT * FROM events WHERE {clauses} ORDER BY timestamp DESC LIMIT 5',
            [f"%{term}%" for term in terms]
        )
        return cursor.fetchall()

    def calculate_relationship_confidence(self, event1, event2):
        # Entity overlap drives confidence; stored events only keep their raw
        # text, so fall back to substring checks when entities are unavailable
        def overlap(key, weight):
            known = set(event1.get('entities', {}).get(key, []))
            other = event2.get('entities', {}).get(key)
            if other is not None:
                return weight if known & set(other) else 0.0
            other_text = event2.get('text', '')
            return weight if any(entity in other_text for entity in known) else 0.0

        base_confidence = (
            overlap('people', 0.3) +
            overlap('organizations', 0.3) +
            overlap('locations', 0.4)
        )
        return min(base_confidence, 1.0)

class EventAnalyzer:
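    """Coordinates entity, temporal, and location extraction and scores overall event confidence."""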
    def __init__(self):
        self.ontology = OntologyRegistry()
        self.relationship_engine = RelationshipEngine()
        self.executor = ThreadPoolExecutor(max_workers=3)

    async def extract_entities(self, text):
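        """Run the NER pipeline in a worker thread and bucket results by entity type."""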
        def _extract():
            return ner_pipeline(text)
        
        # Run NER in thread pool
        ner_results = await asyncio.get_running_loop().run_in_executor(
            self.executor, _extract
        )
        
        entities = {
            "people": [],
            "organizations": [],
            "locations": [],
            "hashtags": [word for word in text.split() if word.startswith('#')]
        }
        
        # With aggregation enabled, each result carries an "entity_group" label and a full-word span
        for item in ner_results:
            if item["entity_group"] == "PER":
                entities["people"].append(item["word"])
            elif item["entity_group"] == "ORG":
                entities["organizations"].append(item["word"])
            elif item["entity_group"] == "LOC":
                entities["locations"].append(item["word"])
                
        return entities

    async def extract_temporal(self, text):
        return self.ontology.validate_pattern(text, 'temporal')

    async def extract_locations(self, text):
        # Await the entity coroutine before reading from it, then merge model and pattern matches
        ml_entities = await self.extract_entities(text)
        ml_locations = ml_entities.get('locations', [])
        pattern_locations = self.ontology.validate_pattern(text, 'location')
        return list(set(ml_locations + pattern_locations))

    async def analyze_event(self, text):
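        """Run all extractors concurrently, score confidence, attach related events, and persist strong results."""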
        try:
            # Parallel extraction
            entities_task = self.extract_entities(text)
            temporal_task = self.extract_temporal(text)
            locations_task = self.extract_locations(text)
            
            # Gather results
            entities, temporal, locations = await asyncio.gather(
                entities_task, temporal_task, locations_task
            )
            
            # Merge location results
            entities['locations'] = locations
            entities['temporal'] = temporal
            
            # Calculate initial confidence
            confidence = min(1.0, (
                0.2 * bool(entities["people"]) +
                0.2 * bool(entities["organizations"]) +
                0.3 * bool(entities["locations"]) +
                0.3 * bool(temporal)
            ))
            
            # Find related events
            related_events = self.relationship_engine.find_related_events({
                'text': text,
                'entities': entities
            })
            
            # Adjust confidence based on relationships
            if related_events:
                relationship_confidence = max(
                    self.relationship_engine.calculate_relationship_confidence(
                        {'entities': entities}, 
                        {'text': event[1]} # event[1] is the text field
                    )
                    for event in related_events
                )
                confidence = (confidence + relationship_confidence) / 2
            
            result = {
                "text": text,
                "entities": entities,
                "confidence": confidence,
                "verification_needed": confidence < 0.6,
                "related_events": [
                    {
                        "text": event[1],
                        "timestamp": event[2],
                        "confidence": event[3]
                    }
                    for event in related_events
                ]
            }
            
            # Store event if confidence is sufficient
            if confidence >= 0.6:
                self.relationship_engine.conn.execute(
                    'INSERT INTO events (text, timestamp, confidence) VALUES (?, ?, ?)',
                    (text, datetime.now().isoformat(), confidence)
                )
                self.relationship_engine.conn.commit()
            
            return result
            
        except Exception as e:
            return {"error": str(e)}

# Initialize analyzer
analyzer = EventAnalyzer()

# Custom CSS for UI
css = """
.container { max-width: 1200px; margin: auto; padding: 20px; }
.results { padding: 20px; border: 1px solid #ddd; border-radius: 8px; margin-top: 20px; }
.confidence-high { color: #22c55e; font-weight: bold; }
.confidence-low { color: #f97316; font-weight: bold; }
.entity-section { margin: 15px 0; }
.alert-warning { background: #fff3cd; padding: 10px; border-radius: 5px; margin: 10px 0; }
.alert-success { background: #d1fae5; padding: 10px; border-radius: 5px; margin: 10px 0; }
.related-events { background: #f3f4f6; padding: 15px; border-radius: 5px; margin-top: 15px; }
"""

def format_results(analysis_result):
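    """Render an analysis result dictionary as the HTML shown in the Gradio output panel."""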
    if "error" in analysis_result:
        return f"<div style='color: red'>Error: {analysis_result['error']}</div>"
    
    confidence_class = "confidence-high" if analysis_result["confidence"] >= 0.6 else "confidence-low"
    
    html = f"""
    <div class="results">
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 20px;">
            <h3 style="margin: 0;">Analysis Results</h3>
            <div>
                Confidence Score: <span class="{confidence_class}">{int(analysis_result['confidence'] * 100)}%</span>
            </div>
        </div>
        
        {f'''
        <div class="alert-warning">
            ⚠️ <strong>Verification Required:</strong> Low confidence score detected. Please verify the extracted information.
        </div>
        ''' if analysis_result["verification_needed"] else ''}
        
        <div class="entity-section">
            <h4>👤 People Detected</h4>
            <ul>{''.join(f'<li>{person}</li>' for person in analysis_result['entities']['people']) or '<li>None detected</li>'}</ul>
        </div>
        
        <div class="entity-section">
            <h4>🏢 Organizations</h4>
            <ul>{''.join(f'<li>{org}</li>' for org in analysis_result['entities']['organizations']) or '<li>None detected</li>'}</ul>
        </div>
        
        <div class="entity-section">
            <h4>📍 Locations</h4>
            <ul>{''.join(f'<li>{loc}</li>' for loc in analysis_result['entities']['locations']) or '<li>None detected</li>'}</ul>
        </div>
        
        <div class="entity-section">
            <h4>🕒 Temporal References</h4>
            <ul>{''.join(f'<li>{time}</li>' for time in analysis_result['entities']['temporal']) or '<li>None detected</li>'}</ul>
        </div>
        
        <div class="entity-section">
            <h4># Hashtags</h4>
            <ul>{''.join(f'<li>{tag}</li>' for tag in analysis_result['entities']['hashtags']) or '<li>None detected</li>'}</ul>
        </div>
        
        {f'''
        <div class="alert-success">
            ✅ <strong>Event Validated:</strong> The extracted information meets confidence thresholds.
        </div>
        ''' if not analysis_result["verification_needed"] else ''}
        
        {f'''
        <div class="related-events">
            <h4>Related Events</h4>
            <ul>
                {''.join(f'<li>{event["text"]} ({event["timestamp"]}) - Confidence: {int(event["confidence"] * 100)}%</li>' for event in analysis_result['related_events'])}
            </ul>
        </div>
        ''' if analysis_result.get('related_events') else ''}
    </div>
    """
    return html

async def process_input(text):
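    """Gradio handler: analyze the submitted text and return formatted HTML."""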
    result = await analyzer.analyze_event(text)
    return format_results(result)

demo = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(
            label="Event Text",
            placeholder="Enter text to analyze (e.g., 'John from Tech Corp. is attending the meeting in Washington, DC tomorrow at 14:30 #tech')",
            lines=3
        )
    ],
    outputs=gr.HTML(),
    title="DoD Event Analysis System",
    description="Analyze text to extract entities, assess confidence, and identify key event information with relationship tracking.",
    css=css,
    theme=gr.themes.Soft(),
    examples=[
        ["John from Tech Corp. is attending the meeting in Washington, DC tomorrow at 14:30 #tech"],
        ["Sarah Johnson and Mike Smith from Defense Systems Inc. are conducting training in Norfolk, VA on June 15th #defense #training"],
        ["Team meeting at headquarters with @commander_smith at 0900 #briefing"]
    ]
)

if __name__ == "__main__":
    demo.launch()