import subprocess
import sys
import os
import traceback

import git
import json
import yaml
import requests
from tqdm.auto import tqdm
from git.remote import RemoteProgress


# Optional path to a git executable, read from the environment
# (consumed by setup_environment()).
git_exe_path = os.environ.get('GIT_EXE_PATH')

# Root of the ComfyUI checkout. When the environment does not provide it,
# fall back to the directory two levels above this script.
comfy_path = os.environ.get('COMFYUI_PATH')

if comfy_path is None:
    print("\nWARN: The `COMFYUI_PATH` environment variable is not set. Assuming `custom_nodes/ComfyUI-Manager/../../` as the ComfyUI path.", file=sys.stderr)
    comfy_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), '..', '..')
    )


def download_url(url, dest_folder, filename=None):
    """Download `url` into `dest_folder`.

    Args:
        url: Address of the file to fetch.
        dest_folder: Directory to save into; created if it does not exist.
        filename: Name of the saved file. Defaults to the last path
            component of `url`.

    On a non-200 response a failure message is printed and nothing is
    written. Returns None in every case.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dest_folder, exist_ok=True)

    if filename is None:
        filename = os.path.basename(url)

    dest_path = os.path.join(dest_folder, filename)

    # A streamed response holds its connection until it is closed, so use it
    # as a context manager to release it on every path (including failures).
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            print(f"Failed to download file from {url}")
            return

        with open(dest_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    file.write(chunk)


# Directory the script was launched from; the snapshot helpers below scan it
# for custom nodes.
working_directory = os.getcwd()

# Catalogue of known custom nodes, expected next to this script.
nodelist_path = os.path.join(os.path.dirname(__file__), "custom-node-list.json")

if os.path.basename(working_directory) != 'custom_nodes':
    # Deliberately non-fatal: only warn and dump some context for debugging.
    print(
        "WARN: This script should be executed in custom_nodes dir",
        f"DBG: INFO {working_directory}",
        f"DBG: INFO {sys.argv}",
        sep="\n",
    )


class GitProgress(RemoteProgress):
    """Progress handler that mirrors git progress counts onto a tqdm bar.

    An instance is passed as `progress=` to `clone_from` in gitclone().
    """

    def __init__(self):
        super().__init__()
        # ASCII-only bar.
        self.pbar = tqdm(ascii=True)

    def update(self, op_code, cur_count, max_count=None, message=''):
        """Copy the reported counts onto the bar and redraw it.

        `op_code` and `message` are accepted but not used.
        """
        bar = self.pbar
        bar.total, bar.n, bar.pos = max_count, cur_count, 0
        bar.refresh()


def gitclone(custom_nodes_path, url, target_hash=None, repo_path=None):
    """Clone `url` (recursively) and optionally check out `target_hash`.

    Args:
        custom_nodes_path: Directory under which the repository is cloned
            when `repo_path` is not given.
        url: Remote URL of the repository.
        target_hash: Commit to check out after cloning, or None to stay on
            whatever the clone checked out.
        repo_path: Explicit clone destination; overrides the default of
            `<custom_nodes_path>/<repo name>`.
    """
    # Strip only a literal '.git' suffix. os.path.splitext() would also cut
    # names such as 'my.repo' down to 'my', and a trailing '/' would produce
    # an empty name (i.e. a clone straight into custom_nodes_path).
    repo_name = os.path.basename(url.rstrip('/'))
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-4]

    if repo_path is None:
        repo_path = os.path.join(custom_nodes_path, repo_name)

    # Clone the repository from the remote URL
    repo = git.Repo.clone_from(url, repo_path, recursive=True, progress=GitProgress())

    try:
        if target_hash is not None:
            print(f"CHECKOUT: {repo_name} [{target_hash}]")
            repo.git.checkout(target_hash)
    finally:
        # Release the repository even when the checkout fails.
        repo.git.clear_cache()
        repo.close()


def gitcheck(path, do_fetch=False):
    """Report on stdout whether the repository at `path` can be updated.

    Exactly one status line is printed:
      * "CUSTOM NODE CHECK: True"  - HEAD is detached, the remote has no
        branch with the local branch's name, or the remote branch tip is a
        different commit with a newer commit date than the local HEAD.
      * "CUSTOM NODE CHECK: False" - local HEAD equals the remote branch tip
        or is not older than it.
      * "CUSTOM NODE CHECK: Error" - the check could not be performed.

    Args:
        path: Path of the git working tree.
        do_fetch: When True, fetch from the tracked remote before comparing.
    """
    repo = None
    try:
        repo = git.Repo(path)

        if repo.head.is_detached:
            print("CUSTOM NODE CHECK: True")
            return

        current_branch = repo.active_branch
        branch_name = current_branch.name

        tracking = current_branch.tracking_branch()
        if tracking is None:
            # Report this explicitly instead of failing with an opaque
            # "'NoneType' object has no attribute 'remote_name'".
            print(f"There is no tracking branch for '{branch_name}'")
            print("CUSTOM NODE CHECK: Error")
            return

        remote_name = tracking.remote_name
        remote = repo.remote(name=remote_name)

        if do_fetch:
            # Fetch the latest commits from the remote repository
            remote.fetch()

        local_commit = repo.head.commit
        remote_ref_name = f'{remote_name}/{branch_name}'

        if remote_ref_name not in repo.refs:
            print("CUSTOM NODE CHECK: True")  # non default branch is treated as updatable
            return

        remote_commit = repo.refs[remote_ref_name].object

        # Behind only if the tips differ AND the remote tip is newer; a local
        # checkout that is ahead of the remote is not an available update.
        is_behind = (
            local_commit.hexsha != remote_commit.hexsha
            and local_commit.committed_datetime < remote_commit.committed_datetime
        )

        if is_behind:
            print("CUSTOM NODE CHECK: True")
        else:
            print("CUSTOM NODE CHECK: False")
    except Exception as e:
        print(e)
        print("CUSTOM NODE CHECK: Error")
    finally:
        if repo is not None:
            repo.close()


def get_remote_name(repo):
    """Pick the remote to use for `repo`.

    Preference order: 'origin', then 'upstream', then the first configured
    remote.

    Args:
        repo: Object exposing `remotes` (items with a `name` attribute) and
            `working_dir`.

    Returns:
        The chosen remote name, or None (after printing a notice) when the
        repository has no remotes.
    """
    available_remotes = [remote.name for remote in repo.remotes]

    if not available_remotes:
        print(f"[ComfyUI-Manager] No remotes are configured for this repository: {repo.working_dir}")
        return None

    for preferred in ('origin', 'upstream'):
        if preferred in available_remotes:
            return preferred

    return available_remotes[0]


def switch_to_default_branch(repo):
    """Check out the default branch of `repo`.

    Tries, in order:
      1. the branch that `refs/remotes/<remote>/HEAD` points to,
      2. an existing local `master` branch,
      3. a new local `master` created from `<remote>/master`.

    Returns:
        True if one of the checkouts succeeded. False if the repository has
        no remote (nothing is attempted) or every attempt failed (a failure
        message is printed).
    """
    remote_name = get_remote_name(repo)
    if remote_name is None:
        return False

    try:
        default_branch = repo.git.symbolic_ref(f'refs/remotes/{remote_name}/HEAD').replace(f'refs/remotes/{remote_name}/', '')
        repo.git.checkout(default_branch)
        return True
    except Exception:
        # Remote HEAD is unknown or the checkout failed; fall back to 'master'.
        pass

    try:
        repo.git.checkout(repo.heads.master)
        return True
    except Exception:
        # No usable local 'master'; try creating one from the remote.
        pass

    try:
        repo.git.checkout('-b', 'master', f'{remote_name}/master')
        return True
    except Exception:
        pass

    print("[ComfyUI Manager] Failed to switch to the default branch")
    return False


def _pull_tracked_branch(repo, commit_hash):
    """Fetch and pull the tracked branch of `repo`, printing the outcome.

    `commit_hash` is the HEAD commit recorded before any branch switch; it is
    compared with the remote tip (to detect "no update") and with HEAD after
    the pull (to detect success).
    """
    if repo.head.is_detached:
        switch_to_default_branch(repo)

    current_branch = repo.active_branch
    branch_name = current_branch.name

    tracking = current_branch.tracking_branch()
    if tracking is None:
        # Report this explicitly instead of failing with an opaque
        # "'NoneType' object has no attribute 'remote_name'".
        print(f"There is no tracking branch for '{branch_name}'")
        print("CUSTOM NODE PULL: Fail")
        return

    remote_name = tracking.remote_name
    remote = repo.remote(name=remote_name)

    if f'{remote_name}/{branch_name}' not in repo.refs:
        switch_to_default_branch(repo)
        branch_name = repo.active_branch.name

    remote.fetch()

    remote_ref_name = f'{remote_name}/{branch_name}'
    if remote_ref_name not in repo.refs:
        print("CUSTOM NODE PULL: Fail")  # update fail
        return

    remote_commit_hash = repo.refs[remote_ref_name].object.hexsha
    if commit_hash == remote_commit_hash:
        print("CUSTOM NODE PULL: None")  # there is no update
        return

    remote.pull()
    repo.git.submodule('update', '--init', '--recursive')

    if commit_hash != repo.head.commit.hexsha:
        print("CUSTOM NODE PULL: Success")  # update success
    else:
        print("CUSTOM NODE PULL: Fail")  # update fail


def gitpull(path):
    """Update the git working tree at `path` from its tracked remote.

    Local modifications are stashed first. The result is reported on stdout
    as "CUSTOM NODE PULL: Success", "CUSTOM NODE PULL: None" (already up to
    date) or "CUSTOM NODE PULL: Fail".

    Raises:
        ValueError: if `path` does not contain a `.git` entry.
    """
    # Check if the path is a git repository
    if not os.path.exists(os.path.join(path, '.git')):
        raise ValueError('Not a git repository')

    repo = git.Repo(path)
    try:
        if repo.is_dirty():
            print(f"STASH: '{path}' is dirty.")
            repo.git.stash()

        commit_hash = repo.head.commit.hexsha
        try:
            _pull_tracked_branch(repo, commit_hash)
        except Exception as e:
            print(e)
            print("CUSTOM NODE PULL: Fail")  # unknown git error
    finally:
        # Always release the repository, including on early exits and errors.
        repo.close()


def checkout_comfyui_hash(target_hash):
    """Check out `target_hash` in the ComfyUI repository at `comfy_path`.

    Nothing is done when HEAD is already at `target_hash`. A failing checkout
    is reported on stdout rather than raised.
    """
    repo = git.Repo(comfy_path)
    try:
        commit_hash = repo.head.commit.hexsha

        if commit_hash != target_hash:
            try:
                print(f"CHECKOUT: ComfyUI [{target_hash}]")
                repo.git.checkout(target_hash)
            except git.GitCommandError as e:
                print(f"Error checking out the ComfyUI: {str(e)}")
    finally:
        # Release the repository handle, as gitclone() and gitpull() do.
        repo.close()


def _repo_name_from_url(url):
    """Return the last path component of `url` without a '.git' suffix."""
    repo_name = url.rstrip('/').split('/')[-1]
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-4]
    return repo_name


def checkout_custom_node_hash(git_custom_node_infos):
    """Bring the git-based custom nodes in `working_directory` in line with a snapshot.

    Args:
        git_custom_node_infos: Mapping of repository URL to an entry with the
            keys 'disabled' (bool) and 'hash' (commit to check out).

    For every git checkout found in `working_directory` (ComfyUI-Manager
    excluded):
      * not in the snapshot, or disabled in it -> renamed to `<dir>.disabled`;
      * enabled in the snapshot -> renamed back if needed and checked out at
        the recorded hash.
    Enabled snapshot entries with no directory are then cloned. A failure on
    one node is reported and does not stop the others.
    """
    repo_name_to_url = {_repo_name_from_url(url): url for url in git_custom_node_infos}

    for path in os.listdir(working_directory):
        if path.endswith("ComfyUI-Manager"):
            continue

        fullpath = os.path.join(working_directory, path)
        if not os.path.isdir(fullpath):
            continue

        is_disabled = path.endswith(".disabled")

        try:
            if not os.path.exists(os.path.join(fullpath, '.git')):
                continue

            # len('.disabled') == 9
            repo_name = path[:-9] if is_disabled else path

            url = repo_name_to_url.get(repo_name)
            item = git_custom_node_infos[url] if url is not None else None

            if item is None or item['disabled']:
                # Unknown to the snapshot, or disabled in it: should be disabled.
                if not is_disabled:
                    print(f"DISABLE: {repo_name}")
                    os.rename(fullpath, fullpath + ".disabled")
                continue

            if is_disabled:
                # enable
                print(f"ENABLE: {repo_name}")
                new_path = fullpath[:-9]
                os.rename(fullpath, new_path)
                fullpath = new_path

            repo = git.Repo(fullpath)
            try:
                if repo.head.commit.hexsha != item['hash']:
                    print(f"CHECKOUT: {repo_name} [{item['hash']}]")
                    repo.git.checkout(item['hash'])
            finally:
                repo.close()

        except Exception as e:
            # Keep going with the remaining nodes, but say why this one failed.
            print(e)
            print(f"Failed to restore snapshots for the custom node '{path}'")

    # clone missing
    for url, item in git_custom_node_infos.items():
        if 'ComfyUI-Manager' in url or item['disabled']:
            continue

        path = os.path.join(working_directory, _repo_name_from_url(url))
        if not os.path.exists(path):
            print(f"CLONE: {path}")
            # Pass the destination explicitly so the directory that was just
            # checked for existence is exactly the one that gets created.
            gitclone(working_directory, url, target_hash=item['hash'], repo_path=path)


def invalidate_custom_node_file(file_custom_node_infos):
    """Bring the single-file custom nodes in `working_directory` in line with a snapshot.

    Args:
        file_custom_node_infos: Iterable of entries with the keys 'filename'
            and 'disabled'.

    `*.py` files not enabled by the snapshot are renamed to `*.py.disabled`,
    and `*.py.disabled` files that the snapshot enables are renamed back.
    Enabled files that are missing are downloaded when the node list at
    `nodelist_path` has a 'copy'-type entry providing a file of that name.
    """
    enabled_files = {
        entry['filename']
        for entry in file_custom_node_infos
        if not entry['disabled']
    }

    for name in os.listdir(working_directory):
        full = os.path.join(working_directory, name)

        if os.path.isdir(full):
            continue

        if full.endswith('.py'):
            if name not in enabled_files:
                print(f"DISABLE: {name}")
                os.rename(full, full + '.disabled')

        elif full.endswith('.py.disabled'):
            # Drop the 9-character '.disabled' suffix.
            enabled_name = name[:-9]
            if enabled_name in enabled_files:
                print(f"ENABLE: {enabled_name}")
                os.rename(full, full[:-9])

    # download missing: just support for 'copy' style
    with open(nodelist_path, 'r', encoding="UTF-8") as json_file:
        catalog = json.load(json_file)

    # Map '<name>.py' -> download URL for every 'copy'-installed node.
    py_to_url = {}
    for node in catalog['custom_nodes']:
        if node['install_type'] != 'copy':
            continue
        for url in node['files']:
            if url.endswith('.py'):
                py_to_url[url.split('/')[-1]] = url

    for entry in file_custom_node_infos:
        filename = entry['filename']
        if entry['disabled']:
            continue

        target_path = os.path.join(working_directory, filename)
        if os.path.exists(target_path) or filename not in py_to_url:
            continue

        print(f"DOWNLOAD: {filename}")
        download_url(py_to_url[filename], working_directory)


def apply_snapshot(path):
    """Restore ComfyUI and its custom nodes to the state recorded in a snapshot.

    Args:
        path: Snapshot file. `.json` files are used as-is; for `.yaml` files
            the snapshot is read from the top-level 'custom_nodes' key.

    The snapshot must provide 'comfyui', 'git_custom_nodes' and
    'file_custom_nodes'. The outcome is reported on stdout as
    "APPLY SNAPSHOT: True" or "APPLY SNAPSHOT: False".

    Returns:
        The snapshot's 'pips' entry if it has one and the snapshot was
        applied; otherwise None. Errors are printed, never raised.
    """
    try:
        if not os.path.exists(path):
            print(f"Snapshot file not found: `{path}`")
            print("APPLY SNAPSHOT: False")
            return None

        if not path.endswith(('.json', '.yaml')):
            # The file exists, so "not found" would be misleading here.
            print(f"Unsupported snapshot file format (expected .json or .yaml): `{path}`")
            print("APPLY SNAPSHOT: False")
            return None

        # Parse first and close the file before the git/network work starts.
        with open(path, 'r', encoding="UTF-8") as snapshot_file:
            if path.endswith('.json'):
                info = json.load(snapshot_file)
            else:
                info = yaml.load(snapshot_file, Loader=yaml.SafeLoader)
                info = info['custom_nodes']

        comfyui_hash = info['comfyui']
        git_custom_node_infos = info['git_custom_nodes']
        file_custom_node_infos = info['file_custom_nodes']

        checkout_comfyui_hash(comfyui_hash)
        checkout_custom_node_hash(git_custom_node_infos)
        invalidate_custom_node_file(file_custom_node_infos)

        print("APPLY SNAPSHOT: True")
        return info.get('pips')
    except Exception as e:
        print(e)
        traceback.print_exc()
        print("APPLY SNAPSHOT: False")

        return None


def _pip_install(requirements):
    """Run `pip install` for `requirements`; return True if it succeeded."""
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', *requirements])
    except (subprocess.SubprocessError, OSError, ValueError):
        return False
    return True


def _pip_install_each(requirements):
    """Install `requirements` one at a time; return those that failed."""
    failures = []
    for requirement in requirements:
        if not _pip_install([requirement]):
            failures.append(requirement)
    return failures


def restore_pip_snapshot(pips, options):
    """Reinstall pip packages recorded in a snapshot.

    Args:
        pips: Mapping of requirement to URL. An empty-string value means a
            plain requirement (the key is installed); a value starting with
            'file:' is a local URL; any other value is a non-local URL (the
            value is installed in both URL cases).
        options: Collection of flags selecting which groups to install:
            '--pip-non-url', '--pip-local-url', '--pip-non-local-url'.

    Installation is best effort: failures are collected and printed at the
    end instead of being raised.
    """
    non_url = []
    local_url = []
    non_local_url = []
    for requirement, url in pips.items():
        if url == "":
            non_url.append(requirement)
        elif url.startswith('file:'):
            local_url.append(url)
        else:
            non_local_url.append(url)

    failed = []

    # Skip the batch when there is nothing to install: `pip install` without
    # requirements is an error and would only trigger a pointless fallback.
    if '--pip-non-url' in options and non_url:
        # try all at once, then fall back to one-by-one to isolate failures
        if not _pip_install(non_url):
            failed.extend(_pip_install_each(non_url))

    if '--pip-non-local-url' in options:
        failed.extend(_pip_install_each(non_local_url))

    if '--pip-local-url' in options:
        failed.extend(_pip_install_each(local_url))

    # Only report when something actually failed.
    if failed:
        print(f"Installation failed for pip packages: {failed}")


def setup_environment():
    """Hand the configured git executable path (`git_exe_path`) to GitPython.

    Does nothing when `git_exe_path` is None.
    """
    if git_exe_path is None:
        return

    # NOTE(review): the environment is updated on a freshly created Git()
    # object that is not kept afterwards - confirm this actually affects the
    # git commands run elsewhere in this script.
    git_cmd = git.Git()
    git_cmd.update_environment(GIT_PYTHON_GIT_EXECUTABLE=git_exe_path)


setup_environment()


# Command-line entry point:
#   --clone <custom_nodes_path> <url> [repo_path]
#   --check <path>            report update availability without fetching
#   --fetch <path>            fetch, then report update availability
#   --pull <path>             update the repository
#   --apply-snapshot <file> [--pip-non-url] [--pip-local-url] [--pip-non-local-url]
# Exit status is 0 unless arguments are missing or a command raises (-1).
try:
    usage = ("Usage: --clone <custom_nodes_path> <url> [repo_path] | "
             "--check <path> | --fetch <path> | --pull <path> | "
             "--apply-snapshot <file> [--pip-non-url] [--pip-local-url] [--pip-non-local-url]")

    if len(sys.argv) < 2:
        # Previously surfaced as a bare "list index out of range".
        print(usage, file=sys.stderr)
        sys.exit(-1)

    command = sys.argv[1]
    known_commands = ("--clone", "--check", "--fetch", "--pull", "--apply-snapshot")
    required_argc = 4 if command == "--clone" else 3

    if command in known_commands and len(sys.argv) < required_argc:
        print(f"Missing argument(s) for {command}", file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(-1)

    if command == "--clone":
        repo_path = sys.argv[4] if len(sys.argv) > 4 else None
        gitclone(sys.argv[2], sys.argv[3], repo_path=repo_path)
    elif command == "--check":
        gitcheck(sys.argv[2], False)
    elif command == "--fetch":
        gitcheck(sys.argv[2], True)
    elif command == "--pull":
        gitpull(sys.argv[2])
    elif command == "--apply-snapshot":
        pip_flags = ('--pip-non-url', '--pip-local-url', '--pip-non-local-url')
        options = {arg for arg in sys.argv if arg in pip_flags}

        pips = apply_snapshot(sys.argv[2])

        if pips and len(options) > 0:
            restore_pip_snapshot(pips, options)
    else:
        # Exit status stays 0 for unknown commands, but no longer silently.
        print(f"WARN: Unknown command ignored: {command}", file=sys.stderr)
    sys.exit(0)
except Exception as e:
    print(e)
    sys.exit(-1)