import fire
import subprocess
import os
import time
import signal
import subprocess
import atexit


def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose ``ps -ef`` line contains a substring.

    The match is a plain substring test against the whole ``ps -ef`` output
    line (not only the command column), so callers should pass a specific
    string such as a full command line. The current process is never
    signalled, and an empty ``cmd_substring`` is ignored because it would
    match every line.

    Signals are sent without waiting for the targets to exit. A process that
    disappears before it can be signalled, or that this user may not signal,
    is reported and skipped instead of aborting the caller.

    Args:
        cmd_substring: Text to search for in each ``ps -ef`` output line.
    """
    if not cmd_substring:
        # An empty pattern is contained in every line and would target every
        # visible process.
        return

    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
    own_pid = os.getpid()

    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue
        parts = line.split()
        # The second field of `ps -ef` is the PID; the digit check also skips
        # the header row and any malformed line.
        if len(parts) < 2 or not parts[1].isdigit():
            continue
        pid = int(parts[1])
        if pid == own_pid:
            # Never terminate the launcher itself.
            continue
        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited between `ps` and `kill`; nothing left to do.
            print(f"Process {pid} already exited")
        except PermissionError:
            print(f"No permission to kill process {pid}; skipping")


def _launch_process(label, cmd_args):
    """Start ``cmd_args`` as a child process after killing stale copies of it.

    Any process whose ``ps -ef`` line contains the space-joined command line
    is sent SIGTERM first, so relaunching does not leave an earlier instance
    running. The new child's ``terminate`` is registered with ``atexit``.

    Args:
        label: Name shown in the "Launching ..." message.
        cmd_args: Argument vector (strings) for ``subprocess.Popen``.

    Returns:
        The ``subprocess.Popen`` handle of the started child.
    """
    cmd_line = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmd_line)
    print(f"Launching {label}: ", cmd_line)
    process = subprocess.Popen(cmd_args)
    # Ask the child to stop when this launcher shuts down normally
    # (atexit handlers do not run if the launcher is killed by a signal).
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    worker_base_port=10088,
    **kwargs,
):
    """Launch the controller, model workers, Gradio server and SD worker.

    Each enabled component runs as a separate child process started with
    ``python_path``. Processes already running the same command line are
    killed first. After launching, this function blocks until every started
    child has exited.

    Args:
        python_path: Interpreter used to run each component script.
        run_controller: Start ``controller.py`` on ``controller_port``.
        run_worker: Start one ``model_worker.py`` per entry in ``worker_names``.
        run_gradio: Start ``gradio_web_server.py`` on ``gradio_port``.
        controller_port: Port the controller listens on. Also used to build
            the controller URL handed to workers and the web server.
        gradio_port: Port for the Gradio web server.
        worker_names: Model paths, one worker each. A single string is
            treated as a one-element list.
        run_sd_worker: Start ``sd_worker.py``.
        worker_base_port: Port of the first model worker. Further workers use
            consecutive ports.
        **kwargs: Unrecognized options. They are reported and otherwise ignored.
    """
    if kwargs:
        # Unknown CLI flags land here; surface them instead of silently
        # ignoring what is likely a typo.
        print(f"Warning: ignoring unrecognized options: {', '.join(sorted(kwargs))}")

    # A single value (e.g. one `--worker_names=<path>` on the command line) can
    # arrive as a plain string; iterating it would launch one worker per
    # character.
    if isinstance(worker_names, str):
        worker_names = [worker_names]

    controller_url = f"http://0.0.0.0:{controller_port}"

    controller_process = None
    if run_controller:
        controller_process = _launch_process(
            "controller",
            [
                f"{python_path}",
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        for offset, worker_name in enumerate(worker_names):
            worker_processes.append(
                _launch_process(
                    "worker",
                    [
                        f"{python_path}",
                        "model_worker.py",
                        "--port",
                        f"{worker_base_port + offset}",
                        "--controller-url",
                        controller_url,
                        "--model-path",
                        f"{worker_name}",
                        "--load-8bit",
                    ],
                )
            )

    # Fixed delay (not a readiness check) before starting the remaining
    # components — presumably so the controller and workers are up and
    # registered first.
    time.sleep(10)

    gradio_process = None
    if run_gradio:
        gradio_process = _launch_process(
            "gradio",
            [
                f"{python_path}",
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        sd_worker_process = _launch_process(
            "sd_worker", [f"{python_path}", "sd_worker.py"]
        )

    # Block until every launched child has exited.
    for process in [
        *worker_processes,
        controller_process,
        gradio_process,
        sd_worker_process,
    ]:
        if process is not None:
            process.wait()


if __name__ == "__main__":
    # Expose `main` as a command-line interface: each keyword parameter of
    # `main` becomes a `--flag` parsed by Fire.
    fire.Fire(component=main)