import logging
import os
from collections import Counter, defaultdict
import multiprocessing
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple
import gc

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse

from api.code_execution import untrusted_check

Result = Tuple[str, List[bool]]  # (status, per-test pass/fail details) for one solution

def create_app() -> FastAPI:
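    """Build and return the FastAPI application for the evaluation service."""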

    level = os.environ.get("LOG_LEVEL", logging.INFO)
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    @app.post("/evaluate/")
    def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples data.
        """
        if parallel < 1:
            # Default to at most two worker processes, but never fewer than one.
            n_workers = max(1, min(2, multiprocessing.cpu_count() // 2))
        else:
            n_workers = parallel

        if not no_gt:
            expected_time = get_groundtruth()
        else:
            expected_time = {}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
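            # Submit one correctness check per sample to the process pool and track pending identifiers.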
            futures = []
            completion_id = Counter()  # per-task counter; orders each task's solutions in the output
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of result dicts from check_correctness
            remainings = set()  # identifiers of samples still awaiting a result

            for i, sample in enumerate(samples):
                # TODO: investigate why HTTPException detail is not passed to client.

                required = ["task_id", "res_id", "test", "solution", "entry_point"]
                if calibrate:
                    required.append("code_prompt")  # calibrated evaluation also needs the code prompt
                for key in required:
                    if key not in sample:
                        raise HTTPException(status_code=400, detail=f"'{key}' not in sample {i}!")

                if not isinstance(sample["solution"], str):
                    raise HTTPException(status_code=400, detail=f"Solution in sample {i} must be a string!")

                sample["_identifier"] = (
                    sample["task_id"] + f" (line {i+1} )"
                )

                task_id = sample["task_id"]
                solution = sample["solution"]

                if calibrate:
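                    # Prepend the task's code prompt, closed with a stub "pass" body, to the solution.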
                    solution = sample["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    sample["res_id"],
                    task_id,
                    solution,
                    sample["test"],
                    sample["entry_point"],
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    expected_time.get(task_id) or 20  # ground-truth time limit; default 20 when no timing data is available
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Mismatch between submitted samples and pending identifiers"

            for future in as_completed(futures):
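                # Group each finished result under its task_id and mark its identifier as done.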
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
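                # Drop local references and force a GC pass between results, intended to limit memory growth.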
                del future, result
                gc.collect()
        
        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
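                # Unpack the (status, per-test details) Result produced by the worker.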
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
        return results

    return app

def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Any]:
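    """Run untrusted_check for a single sample; executed inside a pool worker process."""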
    ret = {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
    }
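    # "base" holds the Result returned by untrusted_check: (status, per-test details).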
    ret["base"] = untrusted_check(
        solution,
        test,
        entry_point,
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret


def get_groundtruth() -> Dict[str, float]:
    """Return per-task ground-truth execution times; not implemented in this service."""
    raise HTTPException(status_code=501, detail="Groundtruth execution is not implemented yet!")