#from ..mew_core import prettify
from ..mew_log.ansi_utils import ansi_color_str
from ..mew_log.attr_utils import get_caller_info, get_var_value
from ..mew_log.table_utils import format_table

import sys
import types 
import objgraph

# Number of bytes in one megabyte (1024 * 1024).
one_mb = 2**20


def bytes_to_mb(size_in_bytes, inv_bytes_in_mb=(1.0 / (1024 * 1024))):
    """
    Return *size_in_bytes* expressed in megabytes (units of 1024 * 1024 bytes).

    ``inv_bytes_in_mb`` is the precomputed reciprocal of the unit size, so
    the conversion is a single multiplication; pass another factor to
    rescale.
    """
    megabytes = size_in_bytes * inv_bytes_in_mb
    return megabytes


def bytes_to_kb(size_in_bytes, inv_bytes_in_kb=(1.0 / (1024))):
    """
    Return *size_in_bytes* expressed in kilobytes (units of 1024 bytes).

    ``inv_bytes_in_kb`` is the precomputed reciprocal of the unit size, so
    the conversion is a single multiplication; pass another factor to
    rescale.
    """
    kilobytes = size_in_bytes * inv_bytes_in_kb
    return kilobytes

def format_mem_size_value(size_bytes, use_color=True):
    """
    Render a byte count as a human-readable ``"Mem Size: ..."`` label.

    The unit is chosen by magnitude: megabytes when the value exceeds
    0.1 MB, otherwise kilobytes when it exceeds 0.1 KB, otherwise raw
    bytes. MB and KB values are shown with four decimal places.

    :param size_bytes: size in bytes.
    :param use_color: when True, the size text is wrapped in cyan via
        ``ansi_color_str``.
    :return: the formatted label.
    """
    in_mb = bytes_to_mb(size_bytes)
    in_kb = bytes_to_kb(size_bytes)

    if in_mb > 0.1:
        shown = f"{in_mb:.4f} MB"
    elif in_kb > 0.1:
        shown = f"{in_kb:.4f} KB"
    else:
        shown = f"{size_bytes} B"

    if use_color:
        shown = ansi_color_str(shown, fg='cyan')

    return f"Mem Size: {shown}"


def format_size_bytes(obj, use_color=True):
    """
    Measure *obj* with ``get_mem_size`` and return a ``"Mem Size: ..."`` label.

    Formatting is delegated to ``format_mem_size_value`` so there is a
    single copy of the unit-selection rule (MB above 0.1 MB, KB above
    0.1 KB, otherwise bytes) shared by both functions.

    :param obj: object to measure.
    :param use_color: when True, the size text is wrapped in cyan.
    :return: the formatted label.
    """
    # get_mem_size's MB value is bytes_to_mb(size_b); format_mem_size_value
    # recomputes exactly the same value, so only the byte count is needed.
    size_b, _size_mb = get_mem_size(obj)
    return format_mem_size_value(size_b, use_color=use_color)



def get_mem_size(obj) -> tuple:
    """
    Get the size of an object as a tuple ``(bytes, megabytes)``.

    The byte count is ``sys.getsizeof(obj)`` plus, recursively:

    * for a dict, the sizes of all of its keys and values;
    * for any other object with a ``__dict__``, the size of that ``__dict__``.

    Everything else (lists, tuples, sets, strings, ...) is measured with
    ``sys.getsizeof`` alone. An object reachable through several
    references is counted once per reference. The exception is a reference
    back to an object that is currently being expanded (a reference cycle):
    it adds nothing, so self-referencing structures do not recurse forever.

    Generators are measured with ``sys.getsizeof`` only and are never
    iterated. Iterating would exhaust the caller's generator (or never end
    for an infinite one), and values not yet yielded are not held in memory.

    :param obj: any object.
    :return: ``(size_in_bytes, size_in_megabytes)``.
    """
    def _get_mem_size(node, expanding):
        node_id = id(node)
        if node_id in expanding:
            # Cycle back to an ancestor whose size is already being counted.
            return 0

        size = sys.getsizeof(node)

        if isinstance(node, types.GeneratorType):
            # Shallow only: do not consume the generator.
            return size

        expanding.add(node_id)
        try:
            if isinstance(node, dict):
                size += sum(
                    _get_mem_size(key, expanding) + _get_mem_size(value, expanding)
                    for key, value in node.items()
                )
            elif hasattr(node, '__dict__'):
                # For objects with __dict__, include the size of their attributes
                size += _get_mem_size(node.__dict__, expanding)
        finally:
            # Only ancestors on the current path stay in the set, so shared
            # (non-cyclic) references are still counted each time they appear.
            expanding.discard(node_id)

        return size

    size_b = _get_mem_size(obj, set())
    size_mb = bytes_to_mb(size_b)

    return size_b, size_mb

def get_mem_size_breakdown(obj, do_recursive=False) -> dict:
    """
    Get a breakdown of the sizes of the properties of an object.

    The result always contains:

    * ``'name'`` -- the object itself;
    * ``'type'`` -- ``type(obj).__name__``;
    * ``'size (B, MB)'`` -- the ``get_mem_size`` result as text;
    * ``'value'`` -- whatever ``get_var_value(obj)`` returns.

    With ``do_recursive=True``, exactly one more level is added. Children
    are summarised with the default ``do_recursive=False`` and are not
    expanded further:

    * dict -> each value under ``breakdown['[<key>]']``;
    * object with ``__dict__`` -> each attribute under
      ``breakdown['vars']['[<attr>]']``.

    Entries whose key or attribute name is ``'stat'`` are skipped in both
    cases. This function never iterates a generator, so generators are
    not expanded.
    """
    size_b, size_mb = get_mem_size(obj)
    breakdown = {
        'name': obj,
        'type': type(obj).__name__,
        'size (B, MB)': f"{size_b} B | {size_mb:.3f}  MB",
        'value': get_var_value(obj),
    }

    if not do_recursive or isinstance(obj, types.GeneratorType):
        # Generators are not expanded: enumerating their items would consume
        # the caller's generator.
        return breakdown

    if isinstance(obj, dict):
        for key, value in obj.items():
            # NOTE(review): 'stat' entries are excluded on purpose, presumably
            # because they are large or noisy -- confirm with the author.
            if key != 'stat':
                breakdown[f'[{key}]'] = get_mem_size_breakdown(value)
    elif hasattr(obj, '__dict__'):
        breakdown['vars'] = {
            f'[{key}]': get_mem_size_breakdown(value)
            for key, value in obj.__dict__.items()
            if key != 'stat'
        }
    return breakdown

def save_obj_graph_img(obj, out_dir='data/image'):
    """
    Render the references held by *obj* to a PNG via ``objgraph.show_refs``.

    The image is written to ``<out_dir>/obj_graph_<ClassName>.png``. The
    directory is created first if it does not exist, because the renderer
    cannot write into a missing directory. With the default ``out_dir`` the
    path is relative to the current working directory.

    :param obj: object whose outgoing references are graphed.
    :param out_dir: target directory (default ``'data/image'``).
    :return: the path of the image file.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    # f-string (not os.path.join) keeps the returned path identical to the
    # historic 'data/image/obj_graph_<Class>.png' form on every platform.
    graph_filename = f'{out_dir}/obj_graph_{obj.__class__.__name__}.png'
    objgraph.show_refs([obj], filename=graph_filename)
    return graph_filename

def take_mem_growth_snapshot(limit=10):
    """
    Report object types whose peak instance count grew since the previous call.

    objgraph keeps the peak counts between calls, so each call reports only
    the growth since the last one.

    :param limit: maximum number of types to report (objgraph's default is 10).
    :return: ``{type_name: {'prev_count': int, 'growth_count': int,
        'total_count': int}}``, or None when nothing grew.
        ``prev_count`` is the previous peak count, ``growth_count`` the
        increase, and ``total_count`` the current count.
    """
    # objgraph.show_growth() only prints its table and returns None;
    # objgraph.growth() returns the same data as (type_name, count, delta).
    growth_info = objgraph.growth(limit=limit)
    if not growth_info:
        return None

    return {
        obj_type: {
            'prev_count': total_count - growth_count,
            'growth_count': growth_count,
            'total_count': total_count,
        }
        for obj_type, total_count, growth_count in growth_info
    }