"""Text-to-SQL running."""
import asyncio
import json
import re
import time
from typing import cast
import duckdb

import structlog
from manifest import Manifest
from manifest.response import Response, Usage
from prompt_formatters import RajkumarFormatter, MotherDuckFormatter
from schema import DEFAULT_TABLE_NAME, TextToSQLModelResponse, TextToSQLParams
from tqdm.auto import tqdm

logger = structlog.get_logger()


def clean_whitespace(sql: str) -> str:
    """Collapse each run of whitespace in ``sql`` into a single space.

    Leading and trailing whitespace is collapsed as well, not stripped, so
    a string that starts or ends with whitespace keeps one space there.
    """
    whitespace_run = re.compile(r"[\t\n\s]+")
    collapsed = whitespace_run.sub(" ", sql)
    return collapsed


def instruction_to_sql(
    params: TextToSQLParams,
    extra_context: list[str],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
) -> TextToSQLModelResponse:
    """Parse the instruction to a sql command.

    Single-item convenience wrapper: wraps ``params`` and ``extra_context``
    in one-element lists, delegates to ``instruction_to_sql_list`` and
    returns the first (only) response.

    Args:
        params: The text-to-SQL request to convert.
        extra_context: Retrieved context strings for this request.
        manifest: Manifest client used to run the prompt.
        prompt_formatter: Formatter used to build the prompt and post-process
            the output. ``instruction_to_sql_list`` raises ``ValueError``
            when it is ``None``.
        overwrite_manifest: Forwarded as-is to ``instruction_to_sql_list``.
        max_tokens: Forwarded as-is to ``instruction_to_sql_list``.
        temperature: Forwarded as-is to ``instruction_to_sql_list``.
        stop_sequences: Forwarded as-is to ``instruction_to_sql_list``.
        num_beams: Forwarded as-is to ``instruction_to_sql_list``.

    Returns:
        The model response for ``params``.
    """
    return instruction_to_sql_list(
        params=[params],
        extra_context=[extra_context],
        manifest=manifest,
        prompt_formatter=prompt_formatter,
        overwrite_manifest=overwrite_manifest,
        max_tokens=max_tokens,
        # Forward the caller's value; a hard-coded 0.1 here made the
        # ``temperature`` argument of this function a silent no-op.
        temperature=temperature,
        stop_sequences=stop_sequences,
        num_beams=num_beams,
    )[0]

def run_motherduck_prompt_sql(params: list[TextToSQLParams]) -> list[TextToSQLModelResponse]:
    """Generate SQL for each param by calling ``prompt_sql`` through DuckDB.

    For every param a connection is opened with ``duckdb.connect('md:')``,
    ``CALL prompt_sql(?)`` is executed with the param's instruction, and the
    first column of the first returned row is used as the SQL query. The
    connection is always closed afterwards.

    This is best-effort: if the call fails for any reason (including an
    empty result), the failure is logged and a fixed fallback query is used
    so that one bad instruction does not abort the whole batch.

    Args:
        params: Requests whose ``instruction`` is sent to ``prompt_sql``.

    Returns:
        One response per param, in order. ``output`` and ``raw_output`` are
        both the generated (or fallback) SQL, ``final_prompt`` is the
        instruction, and token usage is reported as zero.
    """
    fallback_sql = "SELECT * FROM hn.hacker_news LIMIT 1"
    results: list[TextToSQLModelResponse] = []
    for param in params:
        con = duckdb.connect('md:')
        try:
            sql_query = con.execute("CALL prompt_sql(?);", [param.instruction]).fetchall()[0][0]
        except Exception as e:
            # Deliberate best-effort: keep going with a placeholder query.
            logger.warning(f"prompt_sql failed, using fallback query: {e}")
            sql_query = fallback_sql
        finally:
            # Release the connection on both the success and failure path.
            con.close()
        # No token accounting is available on this path, so report zeros.
        usage = Usage(
            completion_tokens=0,
            prompt_tokens=0,
            total_tokens=0,
        )
        results.append(
            TextToSQLModelResponse(
                output=sql_query,
                raw_output=sql_query,
                final_prompt=param.instruction,
                usage=usage,
            )
        )
    return results



def instruction_to_sql_list(
    params: list[TextToSQLParams],
    extra_context: list[list[str]],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
    verbose: bool = False,
) -> list[TextToSQLModelResponse]:
    """Parse the list of instructions to sql commands.

    A single request is run synchronously through ``_run_manifest`` and
    retried up to five times; multiple requests are sent as one batch via
    ``manifest.arun_batch`` without retries.

    Args:
        params: Requests to convert, one prompt is built per entry.
        extra_context: Retrieved context per request, indexed like
            ``params``. If empty/falsy, no context is used for any request.
        manifest: Manifest client used to run the prompts.
        prompt_formatter: Formatter that builds prompts and post-processes
            model output. A ``MotherDuckFormatter`` short-circuits to
            ``run_motherduck_prompt_sql``.
        overwrite_manifest: Passed to manifest as ``overwrite_cache``.
        max_tokens: Passed to manifest as ``max_tokens``.
        temperature: Sampling temperature passed to manifest; sampling is
            enabled only when it is greater than zero.
        stop_sequences: Stop sequences for generation and for truncating
            the output. When falsy, the formatter's ``stop_sequences`` are
            sent to the model; when ``None``, they are also used for
            truncation.
        num_beams: Passed to manifest as ``num_beams``.
        verbose: Show a progress bar while constructing prompts.

    Returns:
        One response per request, in the same order as ``params``.

    Raises:
        ValueError: If ``prompt_formatter`` is ``None``.
        RuntimeError: If the single-request path fails on every attempt;
            the last underlying error is attached as the cause.
    """
    # Exact-type check is kept as in the original dispatch so subclasses of
    # MotherDuckFormatter are not silently rerouted.
    if type(prompt_formatter) is MotherDuckFormatter:
        return run_motherduck_prompt_sql(params)

    if prompt_formatter is None:
        raise ValueError("Prompt formatter is required.")

    def construct_params(
        params: TextToSQLParams,
        context: list[str],
    ) -> str | list[dict]:
        """Turn params into prompt."""
        if prompt_formatter.clean_whitespace:
            instruction = clean_whitespace(params.instruction)
        else:
            instruction = params.instruction

        table_texts = prompt_formatter.format_all_tables(
            params.tables, instruction=instruction
        )
        # table_texts can be list of chat messages. Only join list of str.
        if table_texts:
            if isinstance(table_texts[0], str):
                table_text = prompt_formatter.table_sep.join(table_texts)
            else:
                table_text = table_texts
        else:
            table_text = ""

        if context:
            context_text = prompt_formatter.format_retrieved_context(context)
        else:
            # Keep the empty context the same shape as the table text
            # (string prompt vs. list of chat messages).
            context_text = "" if isinstance(table_text, str) else []
        return prompt_formatter.format_prompt(
            instruction,
            table_text,
            context_text,
        )

    # If no inputs, return nothing
    if not params:
        return []

    # Stitch together demonstrations and params
    prompts: list[str | list[dict]] = []
    for i, param in tqdm(
        enumerate(params),
        total=len(params),
        desc="Constructing prompts",
        disable=not verbose,
    ):
        predict_str = construct_params(param, extra_context[i] if extra_context else [])
        if isinstance(predict_str, str):
            prompt = predict_str.lstrip()
        else:
            prompt = predict_str
        prompts.append(prompt)

    # Tokens used to truncate model output. An explicitly passed list is
    # honored as-is; only the default ``None`` (which used to crash when
    # iterated) falls back to the formatter's stop sequences.
    if stop_sequences is not None:
        truncation_tokens = stop_sequences
    else:
        truncation_tokens = prompt_formatter.stop_sequences or []

    manifest_params = dict(
        max_tokens=max_tokens,
        overwrite_cache=overwrite_manifest,
        num_beams=num_beams,
        logprobs=5,
        # Honor the caller's temperature instead of a hard-coded 0.1.
        temperature=temperature,
        do_sample=temperature > 0,
        stop_sequences=stop_sequences or prompt_formatter.stop_sequences,
    )

    ret: list[TextToSQLModelResponse] = []
    if len(params) == 1:
        prompt = prompts[0]
        max_attempts = 5
        last_error: Exception | None = None
        for attempt in range(1, max_attempts + 1):
            try:
                model_response = _run_manifest(
                    prompt,
                    manifest_params,
                    prompt_formatter,
                    manifest,
                    stop_sequences=truncation_tokens,
                )
                break
            except Exception as e:
                # Catch Exception (not a bare except) so KeyboardInterrupt
                # and SystemExit still propagate immediately.
                last_error = e
                logger.warning(
                    f"Manifest run failed (attempt {attempt}/{max_attempts}): {e}"
                )
        else:
            # Every attempt failed: surface the real error rather than an
            # UnboundLocalError on ``model_response``.
            raise RuntimeError(
                f"Manifest run failed after {max_attempts} attempts."
            ) from last_error

        ret.append(model_response)
    else:
        # We do not handle retry logic on parallel requests right now
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            response = cast(
                Response,
                loop.run_until_complete(
                    manifest.arun_batch(
                        prompts,
                        **manifest_params,  # type: ignore
                    ),
                ),
            )
        finally:
            # Close the loop even when the batch call raises.
            loop.close()

        response_usage = response.get_usage()
        response_text = response.get_parsed_response()
        for prompt, resp in zip(prompts, response_text):
            # This will restitch the query in the case we force it to start with SELECT
            sql_query = prompt_formatter.format_model_output(cast(str, resp), prompt)
            for token in truncation_tokens:
                sql_query = sql_query.split(token)[0]
            logger.info(f"FINAL OUTPUT: {sql_query}")
            ret.append(
                TextToSQLModelResponse(
                    output=sql_query,
                    raw_output=cast(str, resp),
                    final_prompt=prompt,
                    usage=response_usage,
                )
            )

    return ret


def _run_manifest(
    prompt: str | list[dict],
    manifest_params: dict,
    prompt_formatter: RajkumarFormatter,
    manifest: Manifest,
    stop_sequences: list[str] | None = None,
) -> TextToSQLModelResponse:
    """Run manifest for prompt format.

    Logs the parameters and prompt, runs ``manifest.run`` with a 1800 second
    client timeout, sums token usage over all returned usages, lets the
    formatter post-process the raw output, and truncates the result at the
    first occurrence of each stop sequence.

    Args:
        prompt: Either a plain prompt string or a list of chat messages
            (dicts with ``role`` and ``content`` keys).
        manifest_params: Keyword arguments forwarded to ``manifest.run``.
        prompt_formatter: Formatter whose ``format_model_output`` turns the
            raw response into the final SQL.
        manifest: Manifest client used to run the prompt.
        stop_sequences: Tokens at which the output is cut. When ``None``,
            the ``stop_sequences`` entry of ``manifest_params`` (if any) is
            used instead.

    Returns:
        The model response with truncated SQL, raw output, the prompt and
        the summed token usage.
    """
    logger.info(f"PARAMS: {manifest_params}")
    if isinstance(prompt, list):
        for p in prompt:
            logger.info(f"PROMPT: {p['role']}: {p['content']}")
    else:
        logger.info(f"PROMPT: {prompt}")
    start_time = time.time()
    # Run result
    response = cast(
        Response,
        manifest.run(
            prompt,
            return_response=True,
            client_timeout=1800,
            **manifest_params,  # type: ignore
        ),
    )
    logger.info(f"TIME: {time.time() - start_time: .2f}")

    response_usage = response.get_usage_obj()
    summed_usage = Usage()
    for usage in response_usage.usages:
        summed_usage.completion_tokens += usage.completion_tokens
        summed_usage.prompt_tokens += usage.prompt_tokens
        summed_usage.total_tokens += usage.total_tokens

    raw_output = cast(str, response.get_response())
    # This will restitch the query in the case we force it to start with SELECT
    sql_query = prompt_formatter.format_model_output(raw_output, prompt)

    # The parameter defaults to None, which cannot be iterated. Fall back to
    # the stop sequences that were sent to the model, or to no truncation.
    if stop_sequences is not None:
        truncation_tokens = stop_sequences
    else:
        truncation_tokens = manifest_params.get("stop_sequences") or []
    for token in truncation_tokens:
        sql_query = sql_query.split(token)[0]
    logger.info(f"OUTPUT: {sql_query}")
    return TextToSQLModelResponse(
        output=sql_query,
        raw_output=raw_output,
        final_prompt=prompt,
        usage=summed_usage,
    )