import os
import logging
from datetime import datetime, timedelta
import json
from collections import defaultdict
import shutil
import re
import argparse
from typing import Dict, Set, Tuple, Optional
from log_reader import RemoteLogReader

# Whitelisted annotators, each mapped to the IP address strings (IPv4 or
# IPv6) attributed to them. A conversation is attributed to an annotator when
# the 'ip' field of its first JSONL record is exactly equal to one of these
# strings (plain string comparison, no address normalization).
# NOTE(review): 192.168.1.8, 10.0.0.243 and 10.145.76.56 are private
# (RFC 1918) addresses; if the logged 'ip' is the client's public address
# these entries may never match — confirm.
WHITELIST_IPS_DICT = {
    "Chen Gong": ["128.143.67.19"],
    "Juyong Jiang": ["175.159.122.63"],
    "Kenneth Hamilton": ["109.245.193.97"],
    "Marek Suppa": ["158.195.18.232"],
    "Max Tian": ["2607:fea8:4f40:4b00:e5b9:9806:6b69:233b", "2607:fea8:4f40:4b00:bcef:571:6124:f01", "2607:fea8:7c9d:3800:d9c0:7295:3e2e:6287"],
    "Mengzhao Jia": ["66.254.231.49"],
    "Noah Ziems": ["2601:245:c500:92c0:633c:c0d2:dcc1:1f48", "2601:245:c500:92c0:961e:9ac7:e02:c266"],
    "Sabina A": ["175.196.44.217", "58.235.174.122", "14.52.175.55"],
    "Wenhao Yu": ["2601:600:8d00:9510:1d77:b610:9358:f443"],
    "Vaisakhi Mishra": ["74.90.222.68"],
    "Kumar Shridhar": ["129.132.145.250"],
    "Viktor Gal": ["2a02:169:3e9:0:6ce8:e76f:faed:c830"],
    "Guangyu Song": ["70.50.179.57"],
    "Bhupesh Bishnoi": ["2a02:842a:24:5a01:8cd6:5b22:1189:6035","192.168.1.8"],
    "Zheng Liu": ["128.143.71.67"],
    "Ming Xu": ["10.0.0.243"],
    "Ayush Sunil Munot": ["10.145.76.56"]
}

# Flat list of every whitelisted IP, kept for backward compatibility with
# code that checks membership against a single list (a linear scan).
WHITELIST_IPS = [ip for ips in WHITELIST_IPS_DICT.values() for ip in ips]

# The root logger is set to WARNING, so this module's log.info(...) messages
# are not emitted unless the level is lowered.
logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)

def get_ip_from_jsonl(content: str) -> Optional[str]:
    """Return the ``'ip'`` field of the first record in JSONL *content*.

    Blank (or whitespace-only) lines are skipped, so the "first record" is
    the first non-empty line — the same convention ``get_file_data`` uses.

    Args:
        content: Newline-separated JSON records.

    Returns:
        The first record's ``'ip'`` value, or ``None`` when *content* is
        empty/``None``/blank, the first record is not valid JSON (logged),
        the first record is not a JSON object, or it has no ``'ip'`` key.
    """
    if not content:
        return None
    for raw_line in content.split('\n'):
        line = raw_line.strip()
        if not line:
            # Tolerate leading blank lines instead of failing to parse ''.
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError as e:
            log.error(f"Error extracting IP from content: {e}")
            return None
        # Only the first non-empty record is consulted.
        return data.get('ip') if isinstance(data, dict) else None
    return None

def get_chat_session_id(file_name: str, content: Optional[str] = None) -> Optional[str]:
    """Extract the chat_session_id for a log file based on its location.

    - Paths containing ``conv_logs``: parsed from the file's base name,
      which must look like ``conv-log-<lowercase hex id>.json``.
    - Paths containing ``sandbox_logs``: read from the JSON *content* at
      ``sandbox_state.chat_session_id``.

    Args:
        file_name: Path or name of the log file.
        content: Raw file content; only consulted for sandbox logs.

    Returns:
        The session id, or ``None`` when it cannot be determined. Malformed
        JSON or non-string inputs are logged and yield ``None``.
    """
    try:
        if 'conv_logs' in file_name:
            # Match against the final path component only: ``re.match``
            # anchors at the start of the string, so applying it to the full
            # ``.../conv_logs/conv-log-<id>.json`` path effectively never
            # matched. Treat both '/' and '\\' as separators so remote
            # (POSIX-style) and Windows paths behave the same.
            base_name = file_name.replace('\\', '/').rsplit('/', 1)[-1]
            match = re.match(r'conv-log-([a-f0-9]+)\.json', base_name)
            return match.group(1) if match else None
        if 'sandbox_logs' in file_name and content:
            data = json.loads(content)
            state = data.get('sandbox_state') if isinstance(data, dict) else None
            if not isinstance(state, dict):
                # Missing or malformed sandbox_state: no id, not an error.
                return None
            return state.get('chat_session_id')
        return None
    except (TypeError, ValueError) as e:
        # ValueError covers json.JSONDecodeError; TypeError covers non-str input.
        log.error(f"Error getting chat_session_id from {file_name}: {e}")
        return None

def get_sandbox_session_ids(reader: 'RemoteLogReader', date_str: str) -> Set[str]:
    """Collect the chat_session_ids referenced by one day's sandbox logs.

    Args:
        reader: Log reader exposing ``get_sandbox_logs(date_str)``.
        date_str: Day to query, formatted ``YYYY_MM_DD``.

    Returns:
        The set of truthy ``sandbox_state.chat_session_id`` values. Entries
        that are not dicts, or whose ``sandbox_state`` is missing or not a
        dict, are skipped. If fetching or iterating the logs raises, the
        error is logged and the ids gathered so far (possibly none) are
        returned.
    """
    session_ids: Set[str] = set()
    try:
        # The loop variable must NOT be named ``log``: that would make ``log``
        # a local name for the entire function, and the module logger used in
        # the ``except`` below would become unreachable.
        for entry in reader.get_sandbox_logs(date_str):
            if not isinstance(entry, dict):
                continue
            state = entry.get('sandbox_state')
            if not isinstance(state, dict):
                continue
            session_id = state.get('chat_session_id')
            if session_id:
                session_ids.add(session_id)
    except Exception as e:
        # Best-effort boundary around a remote call: log and return what we have.
        log.error(f"Error getting sandbox session IDs for date {date_str}: {e}")

    return session_ids

def get_file_data(content: str) -> Tuple[Optional[str], bool]:
    """Parse JSONL conversation content into ``(ip, vote_conditions_met)``.

    Blank lines are ignored. The first record supplies the IP; the last
    record decides whether the conversation ended in a qualifying vote.

    Returns:
        ``(None, False)`` when there are no records, when the first record's
        ``'ip'`` is not in ``WHITELIST_IPS``, or on any unexpected error
        (which is logged). If the first record is not valid JSON, the IP is
        ``None`` and the whitelist check is skipped. The vote flag is
        ``True`` only when the last record has ``type == 'vote'`` and a dict
        ``feedback`` with exactly six entries.
    """
    try:
        records = [
            stripped
            for stripped in (raw.strip() for raw in content.split('\n'))
            if stripped
        ]
        if len(records) == 0:
            return None, False

        # Header record: source of the IP. An unparseable header leaves the
        # IP as None and bypasses the whitelist gate entirely.
        ip = None
        try:
            header = json.loads(records[0])
        except json.JSONDecodeError:
            pass
        else:
            ip = header.get('ip')
            if ip not in WHITELIST_IPS:
                return None, False

        # Trailer record: a vote whose feedback dict has exactly six entries.
        try:
            trailer = json.loads(records[-1])
        except json.JSONDecodeError:
            voted = False
        else:
            feedback = trailer.get('feedback')
            voted = (
                trailer.get('type') == 'vote'
                and isinstance(feedback, dict)
                and len(feedback) == 6
            )

        return ip, voted
    except Exception as e:
        log.error(f"Error processing file content: {e}")
        return None, False

def count_files_per_ip(reader: 'RemoteLogReader', start_date_str: str = "2025_02_18") -> Dict[str, int]:
    """Tally qualifying battle_anony conversations per whitelisted annotator.

    Visits every day from *start_date_str* (``YYYY_MM_DD``) through today,
    fetches that day's ``battle_anony`` conversation logs, and counts a
    conversation for an annotator when ``get_file_data`` reports a
    whitelisted IP and met vote conditions. An error for one day is logged
    and that day is skipped.

    Despite the function name, the result is keyed by annotator name.

    Returns:
        Mapping of annotator name -> number of qualifying conversations;
        annotators with none are absent.
    """
    first_day = datetime.strptime(start_date_str, "%Y_%m_%d")
    tally = defaultdict(int)

    try:
        now = datetime.now()
        owner_of = {
            address: person
            for person, addresses in WHITELIST_IPS_DICT.items()
            for address in addresses
        }

        # first_day is at midnight and now includes the time of day, so
        # offsets 0..days are exactly the days with first_day + offset <= now.
        # A start date in the future gives a non-positive span: no days.
        span = (now - first_day).days + 1
        for offset in range(span):
            day_label = (first_day + timedelta(days=offset)).strftime("%Y_%m_%d")

            try:
                anony = reader.get_conv_logs(day_label).get('battle_anony', {})
                for _conv_id, convo in anony.items():
                    if not convo:
                        continue
                    # Rebuild JSONL text so get_file_data can parse it.
                    blob = '\n'.join(json.dumps(record) for record in convo)
                    address, voted = get_file_data(blob)
                    if voted and address and address in owner_of:
                        tally[owner_of[address]] += 1
            except Exception as e:
                log.error(f"Error processing logs for date {day_label}: {e}")
    except Exception as e:
        log.error(f"Error accessing logs: {e}")

    return dict(tally)

def download_files_by_name(reader: 'RemoteLogReader', start_date_str: str = "2025_02_18", check_sandbox: bool = True) -> None:
    """Download battle_anony conversation logs and file them by annotator name.

    For every day from *start_date_str* through today, each non-empty
    ``battle_anony`` conversation whose first record's IP belongs to a
    whitelisted annotator is written (as JSONL) to
    ``./data/<annotator name>/{valid,invalid}/conv-log-<conv_id>.json``.

    Args:
        reader: RemoteLogReader instance
        start_date_str: The start date in YYYY_MM_DD format (a malformed
            value raises ValueError before any work is done)
        check_sandbox: When True, a conversation goes to ``valid`` only if its
            conversation id is among that day's sandbox-log chat_session_ids,
            otherwise to ``invalid``. When False, every matched conversation
            goes to ``valid``.

    Errors for a single day (including the sandbox lookup) or a single file
    are logged and skipped; processing continues with the next item.
    """
    # Create base data directory
    data_dir = os.path.join(os.getcwd(), "data")
    os.makedirs(data_dir, exist_ok=True)

    # Reverse mapping of IP -> annotator name
    ip_to_name = {ip: name for name, ips in WHITELIST_IPS_DICT.items() for ip in ips}

    start_date = datetime.strptime(start_date_str, "%Y_%m_%d")

    try:
        current_date = start_date
        today = datetime.now()

        # start_date is at midnight and today is "now", so today is included.
        while current_date <= today:
            date_str = current_date.strftime("%Y_%m_%d")

            try:
                # Looked up inside the per-day try so that a failing sandbox
                # lookup only skips this day instead of aborting every
                # remaining day.
                # NOTE(review): only the same day's sandbox logs are consulted;
                # a conversation whose sandbox entry was logged on a different
                # day will be filed as invalid — confirm that is acceptable.
                sandbox_session_ids = get_sandbox_session_ids(reader, date_str) if check_sandbox else set()

                conv_logs = reader.get_conv_logs(date_str)
                battle_anony_logs = conv_logs.get('battle_anony', {})

                for conv_id, messages in battle_anony_logs.items():
                    if not messages:
                        continue

                    # Re-serialize the records as JSONL; this is also the file body.
                    content = '\n'.join(json.dumps(msg) for msg in messages)
                    ip = get_ip_from_jsonl(content)
                    if not ip or ip not in ip_to_name:
                        continue
                    name = ip_to_name[ip]

                    # Create directory structure for this name
                    name_dir = os.path.join(data_dir, name)
                    valid_dir = os.path.join(name_dir, "valid")
                    invalid_dir = os.path.join(name_dir, "invalid")
                    os.makedirs(valid_dir, exist_ok=True)
                    os.makedirs(invalid_dir, exist_ok=True)

                    if check_sandbox and conv_id not in sandbox_session_ids:
                        target_dir = invalid_dir
                    else:
                        # Sandbox match found, or sandbox checking disabled.
                        target_dir = valid_dir

                    file_name = f"conv-log-{conv_id}.json"
                    local_file_path = os.path.join(target_dir, file_name)
                    try:
                        # Pin the encoding so output does not depend on the
                        # platform's locale default.
                        with open(local_file_path, 'w', encoding='utf-8') as f:
                            f.write(content)
                        log.info(f"Saved {file_name} to {target_dir}")
                    except OSError as e:
                        log.error(f"Error saving file {file_name}: {e}")

            except Exception as e:
                log.error(f"Error processing logs for date {date_str}: {e}")

            current_date += timedelta(days=1)

    except Exception as e:
        log.error(f"Error accessing logs: {e}")

def main():
    """Command-line entry point.

    Optionally downloads conversation files organized by annotator
    (``--download``; add ``--sandbox-check`` to split them into valid/invalid
    by matching sandbox logs). It then prints per-annotator counts, highest
    first. ``--start-date`` (``YYYY_MM_DD``, default ``2025_02_18``) sets the
    first day processed by both steps.
    """
    # Parse arguments before constructing the reader, so that --help and
    # argument errors do not need a RemoteLogReader. Whether its constructor
    # touches the network or credentials is not visible here.
    parser = argparse.ArgumentParser(description='Download and organize conversation files by annotator name')
    parser.add_argument('--sandbox-check', action='store_true', help='Check for matching sandbox logs')
    parser.add_argument('--download', action='store_true', help='Enable file download')
    parser.add_argument('--start-date', default="2025_02_18",
                        help='First day to process, in YYYY_MM_DD format (default: %(default)s)')
    args = parser.parse_args()

    # Validate up front so a bad date fails with a usage message, not a traceback.
    try:
        datetime.strptime(args.start_date, "%Y_%m_%d")
    except ValueError:
        parser.error(f"--start-date must be in YYYY_MM_DD format, got {args.start_date!r}")

    reader = RemoteLogReader()

    if args.download:
        print("\nDownloading files and organizing by annotator name...")
        download_files_by_name(reader, start_date_str=args.start_date, check_sandbox=args.sandbox_check)

    # Count and display statistics, highest count first.
    name_counts = count_files_per_ip(reader, start_date_str=args.start_date)
    print("\nFile counts per annotator:")
    for name, count in sorted(name_counts.items(), key=lambda x: x[1], reverse=True):
        print(f"Name: {name:<20} Count: {count}")


if __name__ == "__main__":
    main()