import time

from .ansi_utils import ansi_color_str
from .attr_utils import get_prev_caller_info
from .log_utils import log_info, LogSpinner
from .table_utils import format_table


# NOTE(review): not referenced by any code visible in this module; presumably
# kept for importers elsewhere -- confirm before removing.
time_color = ''

# Profiling helpers: the ProfileSection timer class and module-level registry functions
class ProfileSection:
    """Wall-clock timer that accumulates per-name statistics in a shared registry.

    Each section is identified by ``msg``. Timing state (startTime, endTime,
    elapsedTime, totalTime, avgTime, numCalls) lives in instance attributes.
    The registry entry for a name (created by ``create_or_get_profile``,
    e.g. on the first ``stop()``) holds the accumulated stats. Every later
    instance created with the same ``msg`` aliases that entry's ``__dict__``,
    so all such instances read and write the same attributes.

    Two ways to drive the timer:
      * as a context manager -- ``__enter__``/``__exit__`` call start()/stop()
        unconditionally, whatever ``enable`` is;
      * RAII-style with ``enable=True`` -- ``__init__`` starts the timer and
        ``__del__`` stops it when the object is finalized.

    NOTE(review): using ``enable=True`` *and* ``with`` calls stop() twice
    (``__exit__`` then ``__del__``), so the interval is counted twice.
    NOTE(review): because state is shared per ``msg``, overlapping or
    recursive sections with the same name overwrite each other's startTime.
    """
    Registry = {}  # Class Variable: msg -> ProfileSection holding accumulated stats

    def __init__(self, newMsg=None, enable=True, do_print=False):
        """Create a section named *newMsg* and optionally start timing.

        Args:
            newMsg: section name / registry key. When falsy, the name comes
                from ``get_prev_caller_info()`` (presumably a description of
                the calling site -- confirm in attr_utils).
            enable: when true, start the timer now and stop it in ``__del__``.
                If *newMsg* is already registered, the registry entry's own
                ``enable`` replaces this value (see the ``__dict__`` aliasing).
            do_print: when true, ``stop()`` logs the summary via ``log_info``.
                Also replaced by an existing registry entry's value.
        """
        self.msg = newMsg if newMsg else get_prev_caller_info()
        self.enable = enable
        self.startTime = 0.0
        self.endTime = 0.0
        self.elapsedTime = 0.0
        self.totalTime = 0.0
        self.avgTime = 0.0
        self.numCalls = 0
        self.do_print = do_print
        if self.msg in ProfileSection.Registry:
            p = ProfileSection.Registry[self.msg]
            # Alias the registry entry's attribute dict: from here on this
            # instance and the entry share every attribute, and the values
            # assigned above (including enable/do_print) are discarded.
            self.__dict__ = p.__dict__
            self.numCalls += 1  # bumps the shared call count
        else:
            self.numCalls += 1  # first instance for this name: numCalls == 1

        if (self.enable):
            self.start()

    def __del__(self):
        # Auto-stop for enable=True sections when the object is finalized.
        # This runs even if stop() was already called (e.g. by __exit__).
        if (self.enable):
            self.stop()

    def __enter__(self):
        # Starts unconditionally; the enable flag is not consulted here.
        self.start()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Stops unconditionally; returns None so exceptions from the body propagate.
        self.stop()

    def start(self):
        """Record the current ``time.time()`` as the start of the interval."""
        self.startTime = time.time()
        # if self.do_print : log_info(str(self))

    def stop(self):
        """Record the end time, update stats, and sync them into the registry.

        The registry entry for ``self.msg`` is fetched or created via
        ``create_or_get_profile`` and then overwritten with this instance's
        attributes. This is a no-op when the instance already aliases the
        entry. Logs the entry's summary when ``do_print`` is set.

        Returns:
            The ``str(self)`` summary line computed right after updating stats.
        """
        self.endTime = time.time()
        self.update_stats()
        stopMsg = str(self)

        p = create_or_get_profile(self.msg)
        p.__dict__.update(self.__dict__)
        if self.do_print: log_info(str(p))

        # st.toast(stopMsg)
        return stopMsg

    def update_stats(self):
        """Compute elapsedTime and fold it into totalTime and avgTime.

        avgTime is totalTime / numCalls. __init__ always leaves numCalls >= 1,
        so the division is safe for instances built normally.
        """
        self.elapsedTime = self.endTime - self.startTime
        self.totalTime += self.elapsedTime
        self.avgTime = self.totalTime / float(self.numCalls)

    def __str__(self, use_color=True):
        """Return a one-line summary of this section's timing stats.

        Args:
            use_color: when true, wrap the elapsed/total/avg fields in bright
                cyan and the call count in yellow using ``ansi_color_str``.
        """
        msg_str = self.msg # uncoloured; a green ansi_color_str(self.msg, fg='green') call was disabled here
        elapsed_time_str = make_time_str('elapsed', self.elapsedTime)
        total_time_str = make_time_str('total', self.totalTime)
        avg_time_str = make_time_str('avg', self.avgTime)
        calls_str = f'calls={self.numCalls}'
        if use_color:
            elapsed_time_str = ansi_color_str(elapsed_time_str, fg='bright_cyan')  # Cyan color for elapsed time
            total_time_str = ansi_color_str(total_time_str, fg='bright_cyan')  # Cyan color for total time
            avg_time_str = ansi_color_str(avg_time_str, fg='bright_cyan')  # Cyan color for average time
            calls_str = ansi_color_str(calls_str, fg='yellow')  # Yellow color for calls information

        return f"{msg_str} ~ Elapsed Time: {elapsed_time_str} | Total Time: {total_time_str} | Avg Time: {avg_time_str} | {calls_str}"


from functools import wraps
from threading import Thread
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, TaskProgressColumn
from rich.console import Console
from rich.style import Style
from rich.panel import Panel


def profile_function(func):
    """
    Decorator that times each call of *func* and prints a rich summary panel.

    Each call runs inside a ``LogSpinner`` and a ``ProfileSection`` named
    after ``func.__name__``, so statistics accumulate in
    ``ProfileSection.Registry`` under that name. After the call returns, a
    panel with the elapsed, total and average times and the call count is
    printed to a rich ``Console``. The decorated function does not need to
    report progress itself. Its return value is passed through unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # enable=False: the ``with`` block below is the only start/stop pair.
        # With enable=True, __init__ would also start the timer and __del__
        # would call stop() a second time once this local is released,
        # adding the same interval to totalTime twice on every call.
        # NOTE(review): if another caller registered this name with
        # enable=True, the shared registry state still carries that flag.
        profile_section = ProfileSection(func.__name__, enable=False, do_print=False)
        console = Console()
        with LogSpinner(func.__name__):
            with profile_section:
                result = func(*args, **kwargs)

        # Display timing information using rich.
        panel = Panel.fit(
            f"[bold green]{profile_section.msg}[/bold green]\n"
            f"[cyan]Elapsed Time:[/cyan] {make_time_str('elapsed', profile_section.elapsedTime)}\n"
            f"[cyan]Total Time:[/cyan] {make_time_str('total', profile_section.totalTime)}\n"
            f"[cyan]Avg Time:[/cyan] {make_time_str('avg', profile_section.avgTime)}\n"
            f"[yellow]Calls:[/yellow] {profile_section.numCalls}",
            title="Profile Report",
            border_style="bright_blue"
        )
        console.print(panel)

        return result
    return wrapper


def make_time_str(msg, value):
    """Format a duration given in seconds as ``"<msg>=<number> <unit>"``.

    Durations of 60 s or more are shown in minutes, durations below 0.01 s
    (including negatives) in milliseconds, and everything else in seconds.
    The number is always rendered with two decimals.
    """
    if value >= 60:
        scaled, unit = value / 60, 'min'
    elif value < 0.01:
        scaled, unit = value * 1000, 'ms'
    else:
        scaled, unit = value, 's'
    # Whole numbers go through int(), but the ".2f" spec below still applies
    # to them, so the visible effect is only normalising -0.0 to 0.
    shown = scaled if scaled % 1 != 0 else int(scaled)
    return f"{msg}={shown:.2f} {unit}"


def create_or_get_profile(key, enable=False, do_print=False):
    """Return the registered ProfileSection for *key*, registering a new one if absent.

    ``enable`` and ``do_print`` are used only when a new section has to be
    constructed. An existing entry is returned unchanged.
    """
    registry = ProfileSection.Registry
    if key in registry:
        return registry[key]
    fresh = ProfileSection(key, enable, do_print)
    registry[key] = fresh
    return fresh


def profile_start(msg, enable=True, do_print=False):
    """Start (or restart) the registered timer named *msg*; pair with ``profile_stop(msg)``.

    Args:
        msg: registry key of the profile.
        enable: passed to ProfileSection only when the entry is first
            created. Such an entry also has ``enable`` set, so
            ``ProfileSection.__del__`` calls stop() when it is finalized.
        do_print: passed to ProfileSection only when the entry is first
            created. When true, stop() logs the summary.

    Returns:
        The registered ProfileSection for *msg*.
    """
    is_repeat = msg in ProfileSection.Registry
    p = create_or_get_profile(msg, enable, do_print)
    if is_repeat:
        # A freshly constructed section counts itself in __init__. A reused
        # registry entry does not, so count it here to keep
        # avgTime == totalTime / numCalls meaningful.
        p.numCalls += 1
    # Always (re)set the start time. Restarting only when enable is False
    # would leave an existing entry's startTime at an earlier call's value.
    p.start()
    return p


def profile_stop(msg):
    """Stop the registered timer named *msg* and fold the interval into its stats.

    Args:
        msg: registry key previously passed to ``profile_start``.

    Returns:
        The summary string produced by ``ProfileSection.stop()``, or None when
        no profile named *msg* is registered.
    """
    # Look the name up directly so an unknown name is a no-op instead of
    # silently creating an empty registry entry.
    profile = ProfileSection.Registry.get(msg)
    if profile is None:
        return None
    return profile.stop()


def get_profile_registry():
    """Return ``ProfileSection.Registry`` itself (the live dict, not a copy)."""
    registry = ProfileSection.Registry
    return registry

from loguru import logger


def get_profile_reports():
    """Return every registered ProfileSection, slowest first.

    Sections are ordered by ``(totalTime, avgTime)`` descending. Ties keep
    registry insertion order because ``sorted`` is stable.

    Returns:
        A new list. Reordering it does not affect the registry.
    """
    return sorted(
        ProfileSection.Registry.values(),
        key=lambda report: (report.totalTime, report.avgTime),
        reverse=True,
    )
    
def log_profile_registry(use_color=True):
    """Print the formatted profile report (header, one line per profile, footer) to stdout."""
    report_text = format_profile_registry(use_color=use_color)
    print(report_text)



def allow_curly_braces(original_string):
    """Double every ``{`` and ``}`` so the text survives ``str.format`` unchanged."""
    escaped = original_string
    for brace in ("{", "}"):
        escaped = escaped.replace(brace, brace * 2)
    return escaped

def format_profile_registry(use_color=True):
    """Render every registered profile as a multi-line text report.

    Args:
        use_color: forwarded to ``ProfileSection.__str__``. When False, the
            timing fields are emitted without ANSI colour codes.

    Returns:
        A string containing a header line, one line per profile (in the
        order ``get_profile_reports`` returns them) and a footer line, each
        terminated by a newline.
    """
    lines = ['=== Profile Reports ===']
    # str(report) always colourises; call __str__ directly so the use_color
    # flag is actually honoured.
    lines.extend(report.__str__(use_color=use_color) for report in get_profile_reports())
    lines.append('===>_<===')
    return ''.join(line + '\n' for line in lines)


# import random

# def do_this(y):
#     p = ProfileSection("do_this", True) #only way to use the auto destruct method
#     x = random.randint(0, (y+1)*2)
#     print(f'do_this: {x} - {y}')

# for index in range(1000):
#     do_this(index)

# log_profile_registry()