from flask import Flask, request, jsonify
import uuid
import time
import docker
import requests
import atexit
import socket 
import argparse
import logging
from pydantic import BaseModel, Field, ValidationError


# Flask app that fronts a pool of docker-ised Jupyter kernel containers;
# the started containers are attached later as `app.containers` in main.
app = Flask(__name__)
app.logger.setLevel(logging.INFO)


# CLI function to parse arguments
# CLI function to parse arguments
def parse_args(argv=None):
    """Parse command-line options for the kernel-manager server.

    Args:
        argv: Optional list of argument strings; defaults to
            ``sys.argv[1:]``. Passing a list makes the function testable.

    Returns:
        argparse.Namespace with ``n_instances``, ``n_cpus``, ``mem``,
        ``execution_timeout`` and ``port`` attributes.
    """
    parser = argparse.ArgumentParser(description="Jupyter server.")
    # required=True: without it a missing flag left n_instances as None,
    # which crashed later in create_kernel_containers (range(None)).
    parser.add_argument('--n_instances', type=int, required=True,
                        help="Number of Jupyter instances.")
    parser.add_argument('--n_cpus', type=int, default=2,
                        help="Number of CPUs per Jupyter instance.")
    parser.add_argument('--mem', type=str, default="2g",
                        help="Amount of memory per Jupyter instance.")
    parser.add_argument('--execution_timeout', type=int, default=10,
                        help="Timeout period for a code execution.")
    parser.add_argument('--port', type=int, default=5001,
                        help="Port of main server")
    return parser.parse_args(argv)


def get_unused_port(start=50000, end=65535, exclusion=None):
    """Return the first TCP port in [start, end] that can be bound.

    Args:
        start: First port to try (inclusive).
        end: Last port to try (inclusive).
        exclusion: Optional collection of ports to skip (e.g. ports already
            handed out to containers that have not bound them yet).

    Returns:
        A currently-free port number.

    Raises:
        IOError: If no port in the range could be bound.
    """
    # None default avoids the shared-mutable-default pitfall; a set gives
    # O(1) membership tests.
    excluded = set(exclusion) if exclusion else set()
    for port in range(start, end + 1):
        if port in excluded:
            continue
        try:
            # Context manager guarantees the probe socket is closed even
            # when bind()/listen() raises (the original leaked it then).
            with socket.socket() as sock:
                sock.bind(("", port))
                sock.listen(1)
            return port
        except OSError:
            continue
    raise IOError("No free ports available in range {}-{}".format(start, end))


def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10):
    """Build the kernel image and start one docker container per instance.

    Each container gets its own disjoint CPU range, a memory limit, and a
    free host port mapped to the container's port 5000. The function then
    polls every container's /health endpoint until all respond.

    Args:
        n_instances: Number of Jupyter kernel containers to start.
        n_cpus: CPU cores pinned to each container.
        mem: Memory limit per container (docker syntax, e.g. "2g").
        execution_timeout: Per-execution timeout (seconds) exported to each
            container via the EXECUTION_TIMEOUT environment variable.

    Returns:
        List of dicts with keys "container" (docker container handle) and
        "port" (host port for that kernel).

    Raises:
        TimeoutError: If the containers are not all healthy within 60 s.
        IOError: If no free host port is available.
    """
    docker_client = docker.from_env()
    app.logger.info("Building docker image...")
    image, logs = docker_client.images.build(path='./', tag='jupyter-kernel:latest')
    app.logger.info("Building docker image complete.")

    containers = []
    port_exclusion = []
    for i in range(n_instances):

        free_port = get_unused_port(exclusion=port_exclusion)
        # Containers take a while to bind their port, so remember every
        # handed-out port to avoid assigning the same one twice.
        port_exclusion.append(free_port)
        app.logger.info(f"Starting container {i} on port {free_port}...")
        container = docker_client.containers.run(
            "jupyter-kernel:latest",
            detach=True,
            mem_limit=mem,
            # Pin each container to its own disjoint core range, e.g. for
            # n_cpus=2: instance 0 -> "0-1", instance 1 -> "2-3", ...
            cpuset_cpus=f"{i*n_cpus}-{(i+1)*n_cpus-1}",
            remove=True,
            ports={'5000/tcp': free_port},
            environment={"EXECUTION_TIMEOUT": execution_timeout},
        )

        containers.append({"container": container, "port": free_port})

    start_time = time.time()

    containers_ready = []

    while len(containers_ready) < n_instances:
        app.logger.info("Pinging Jupyter containers to check readiness.")
        if time.time() - start_time > 60:
            raise TimeoutError("Container took too long to startup.")
        for i in range(n_instances):
            if i in containers_ready:
                continue
            url = f"http://localhost:{containers[i]['port']}/health"
            try:
                # Explicit timeout so a wedged container cannot hang this
                # probe forever and blow past the 60 s startup budget.
                response = requests.get(url, timeout=5)
                if response.status_code == 200:
                    containers_ready.append(i)
            except requests.exceptions.RequestException:
                # Container not accepting connections yet -- retry on the
                # next sweep. (Narrowed from bare Exception so real bugs
                # in this loop are no longer silently swallowed.)
                pass
        time.sleep(0.5)
    app.logger.info("Containers ready!")
    return containers

def shutdown_cleanup():
    """atexit hook: stop and remove every kernel container, best-effort.

    Guards against ``app.containers`` never having been assigned -- the
    hook is registered at import time, so it also fires when
    create_kernel_containers() failed before main could set the attribute.
    """
    app.logger.info("Shutting down. Stopping and removing all containers...")
    for instance in getattr(app, "containers", []):
        try:
            instance['container'].stop()
            # Containers are started with remove=True, so the daemon may
            # auto-remove them after stop(); a NotFound here is expected
            # and handled by the except below.
            instance['container'].remove()
        except Exception as e:
            app.logger.info(f"Error stopping/removing container: {str(e)}")
    app.logger.info("All containers stopped and removed.")


class ServerRequest(BaseModel):
    """Schema of the JSON body accepted by the /execute endpoint."""
    # Python source code to run in the target kernel (required).
    code: str = Field(..., example="print('Hello World!')")
    # Index into app.containers selecting which kernel runs the code.
    instance_id: int = Field(0, example=0)
    # When true, restart the kernel before executing the code.
    restart: bool = Field(False, example=False)


@app.route('/execute', methods=['POST'])
def execute_code():
    """Proxy a code-execution request to the selected kernel container.

    Expects a JSON body matching ServerRequest. Optionally restarts the
    target kernel first, then forwards the code to that container's
    /execute endpoint and relays the container's JSON response.

    Returns:
        The container's JSON result; 400 for an invalid body or unknown
        instance_id; 500 with {'result': 'error', 'output': ...} when the
        proxy call itself fails.
    """
    # silent=True turns a missing or non-JSON body into None -> {} so it is
    # reported as a 400 validation error rather than an unhandled
    # TypeError (500) from ServerRequest(**None).
    try:
        payload = ServerRequest(**(request.get_json(silent=True) or {}))
    except ValidationError as e:
        return jsonify(e.errors()), 400

    # Reject unknown instance ids explicitly instead of letting the list
    # index raise (IndexError -> opaque 500, negative ids silently wrap).
    if not 0 <= payload.instance_id < len(app.containers):
        return jsonify({
            'result': 'error',
            'output': f'instance_id {payload.instance_id} out of range'
        }), 400

    port = app.containers[payload.instance_id]["port"]

    app.logger.info(f"Received request for instance {payload.instance_id} (port={port}).")

    try:
        if payload.restart:
            response = requests.post(f'http://localhost:{port}/restart', json={})
            if response.status_code == 200:
                app.logger.info(f"Kernel for instance {payload.instance_id} restarted.")
            else:
                app.logger.info(f"Error when restarting kernel of instance {payload.instance_id}: {response.json()}.")

        response = requests.post(f'http://localhost:{port}/execute', json={'code': payload.code})
        result = response.json()
        return result

    except Exception as e:
        # Boundary handler: any proxy failure is reported to the client.
        app.logger.info(f"Error in execute_code: {str(e)}")
        return jsonify({
            'result': 'error',
            'output': str(e)
        }), 500


# Registered at import time so containers are cleaned up even if startup
# fails part-way through container creation below.
atexit.register(shutdown_cleanup)

if __name__ == '__main__':
    args = parse_args()
    # Stash the started containers on the app object so the /execute
    # handler and the atexit hook can reach them.
    app.containers = create_kernel_containers(
        args.n_instances,
        n_cpus=args.n_cpus, 
        mem=args.mem, 
        execution_timeout=args.execution_timeout
        )
    # don't use debug=True --> it will run main twice and thus start double the containers
    app.run(debug=False, host='0.0.0.0', port=args.port) 


# TODO:
# how to mount data at runtime into the container? idea: mount a (read only) 
# folder into the container at startup and copy the data in there. before starting 
# the kernel we could cp the necessary data into the pwd.