from __future__ import annotations as _annotations

import asyncio
import json
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncGenerator

import asyncpg
import gradio as gr
from gradio.utils import get_space
import numpy as np
import pydantic_core
from gradio_webrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    WebRTC,
    audio_to_bytes,
    get_twilio_turn_credentials,
)
from groq import Groq
from openai import AsyncOpenAI
from pydantic import BaseModel
from pydantic_ai import RunContext
from pydantic_ai.agent import Agent
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn

if not get_space():
    # Outside a Hugging Face Space (per gradio.utils.get_space), load settings
    # such as API keys and DB_URL from a local .env file. This runs before the
    # clients below are built, since they presumably read their keys from the
    # environment.
    from dotenv import load_dotenv

    load_dotenv()

# Documentation index. stream_from_agent reads each entry's "id" and "path".
# A context manager closes the handle. An explicit encoding keeps decoding
# independent of the platform's default locale.
with open("gradio_docs.json", encoding="utf-8") as _docs_file:
    DOCS = json.load(_docs_file)

groq_client = Groq()
openai = AsyncOpenAI()


@dataclass()
class Deps:
    """Dependencies handed to agent tools through ``RunContext.deps``."""

    # Async OpenAI client; ``retrieve`` uses it to embed search queries.
    openai: AsyncOpenAI
    # asyncpg pool; ``retrieve`` uses it to query the ``doc_sections`` table.
    pool: asyncpg.Pool


# System prompt for the docs agent, built from adjacent string literals that
# Python concatenates into a single string.
SYSTEM_PROMPT = (
    "You are an assistant designed to help users answer questions about Gradio. "
    "You have a retrieve tool that can provide relevant documentation sections based on the user query. "
    "Be courteous and helpful to the user but feel free to refuse to answer questions that are not about Gradio. "
)


# Docs-QA agent on OpenAI gpt-4o. Its tools (see ``retrieve``) receive a
# ``Deps`` instance as ``context.deps``.
agent = Agent(
    "openai:gpt-4o",
    system_prompt=SYSTEM_PROMPT,
    deps_type=Deps,
)


class RetrievalResult(BaseModel):
    """Result of the ``retrieve`` tool; travels to the UI as JSON."""

    # Retrieved sections rendered as "# title\ncontent" blocks.
    content: str
    # ``doc_sections.id`` of each retrieved row, in the order they were fetched.
    ids: list[int]


@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    """Yield an asyncpg pool connected to the ``gradio_ai_rag`` database.

    The server DSN is read from the ``DB_URL`` environment variable, and the
    database name is appended to it as a path component. The pool is closed
    when the context exits, including on error.

    Args:
        create_db: When true, first connect to the server DSN and create the
            ``gradio_ai_rag`` database if it does not already exist.

    Raises:
        RuntimeError: If ``DB_URL`` is unset or empty.
    """
    database = "gradio_ai_rag"
    server_dsn = os.getenv("DB_URL")
    if not server_dsn:
        # Without this guard the pool DSN becomes "None/gradio_ai_rag",
        # which fails later with a much less helpful message.
        raise RuntimeError("DB_URL environment variable is not set")
    # Avoid producing "...//gradio_ai_rag" when DB_URL ends with a slash.
    server_dsn = server_dsn.rstrip("/")

    if create_db:
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                "SELECT 1 FROM pg_database WHERE datname = $1", database
            )
            if not db_exists:
                # Identifiers cannot be bound as query parameters. ``database``
                # is a fixed constant, so interpolating it is safe here.
                await conn.execute(f"CREATE DATABASE {database}")
        finally:
            await conn.close()

    pool = await asyncpg.create_pool(f"{server_dsn}/{database}")
    try:
        yield pool
    finally:
        await pool.close()


@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    # NOTE(review): pydantic_ai likely builds the model-facing tool description
    # from the docstring above, so edit it with care.
    # Returns a JSON-serialized RetrievalResult holding up to 8 doc_sections
    # rows ordered by ``embedding <-> query`` distance, plus their ids.
    print(f"create embedding for {search_query}")
    response = await context.deps.openai.embeddings.create(
        input=search_query,
        model="text-embedding-3-small",
    )

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(response.data) != 1:
        raise RuntimeError(
            f"Expected 1 embedding, got {len(response.data)}, doc query: {search_query!r}"
        )
    query_vector = response.data[0].embedding
    # The vector is sent as its JSON-array text form. NOTE(review): this
    # presumably relies on pgvector casting '[...]' text to a vector — confirm.
    embedding_json = pydantic_core.to_json(query_vector).decode()
    rows = await context.deps.pool.fetch(
        "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8",
        embedding_json,
    )
    content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows)
    ids = [row["id"] for row in rows]
    return RetrievalResult(content=content, ids=ids).model_dump_json()


def _doc_paths_for(tool_content: str) -> str:
    """Return the doc paths cited by one ``retrieve`` tool result.

    Args:
        tool_content: The JSON string returned by ``retrieve`` (a serialized
            ``RetrievalResult``).

    Returns:
        The ``path`` of every ``DOCS`` entry whose ``id`` is among the result's
        ``ids``, de-duplicated, in ``DOCS`` order, joined with newlines.
    """
    # Parse the tool output once and use a set for O(1) membership tests.
    wanted_ids = set(RetrievalResult.model_validate_json(tool_content).ids)
    # dict.fromkeys drops duplicates but keeps first-seen order, so the list
    # shown to the user is stable between runs (set() order is arbitrary).
    paths = dict.fromkeys(doc["path"] for doc in DOCS if doc["id"] in wanted_ids)
    return "\n".join(paths)


async def stream_from_agent(
    audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list
):
    """Answer one spoken question, streaming UI updates as they happen.

    Transcribes ``audio`` with Groq Whisper and appends the question to
    ``chatbot``. It then runs the agent with ``past_messages`` as history and
    streams the answer into a new assistant message. ``chatbot`` and
    ``past_messages`` are mutated in place.

    Args:
        audio: Captured audio as handed over by ``ReplyOnPause``; presumably
            ``(sample_rate, samples)`` — confirm against gradio_webrtc.
        chatbot: Chat transcript as a list of message dicts with "role",
            "content" and optionally "metadata".
        past_messages: pydantic_ai message history from earlier turns.

    Yields:
        ``AdditionalOutputs(chatbot, gr.skip())`` after each transcript change,
        then a final ``AdditionalOutputs(gr.skip(), past_messages)`` carrying
        the updated history.
    """
    # The Groq client call is synchronous. Running it in a worker thread keeps
    # it from blocking the event loop this generator runs on.
    transcription = await asyncio.to_thread(
        groq_client.audio.transcriptions.create,
        file=("audio-file.mp3", audio_to_bytes(audio)),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    )
    question = transcription.text

    print("text", question)

    chatbot.append({"role": "user", "content": question})
    yield AdditionalOutputs(chatbot, gr.skip())

    async with database_connect(False) as pool:
        deps = Deps(openai=openai, pool=pool)
        async with agent.run_stream(
            question, deps=deps, message_history=past_messages
        ) as result:
            for message in result.new_messages():
                past_messages.append(message)
                if isinstance(message, ModelStructuredResponse):
                    # One placeholder entry per tool call, keyed by tool_id so
                    # the matching ToolReturn below can fill it in.
                    for call in message.calls:
                        chatbot.append(
                            {
                                "role": "assistant",
                                "content": "",
                                "metadata": {
                                    "title": "🔍 Retrieving relevant docs",
                                    "id": call.tool_id,
                                },
                            }
                        )
                if isinstance(message, ToolReturn):
                    for gr_message in chatbot:
                        # Messages round-tripped through the frontend may
                        # carry metadata=None; treat that as "no metadata".
                        metadata = gr_message.get("metadata") or {}
                        if metadata.get("id", "") == message.tool_id:
                            paths = _doc_paths_for(message.content)
                            gr_message["content"] = f"Relevant Context:\n {paths}"
                yield AdditionalOutputs(chatbot, gr.skip())

            # Stream the final answer text into a fresh assistant message.
            chatbot.append({"role": "assistant", "content": ""})
            async for partial_text in result.stream_text():
                chatbot[-1]["content"] = partial_text
                yield AdditionalOutputs(chatbot, gr.skip())
            data = await result.get_data()
            past_messages.append(ModelTextResponse(content=data))
            yield AdditionalOutputs(gr.skip(), past_messages)


# UI: a chatbot transcript plus a WebRTC audio input. Audio is handled by
# stream_from_agent, whose AdditionalOutputs drive the chatbot and history.
# NOTE(review): Gradio presumably lays components out in creation order, so
# the statement order below matters.
with gr.Blocks() as demo:
    # HTML passed as the chatbot's placeholder (presumably shown while the chat
    # is empty). Its <img> loads gradio_logo.png via the /gradio_api/file=
    # route; that file is listed in allowed_paths in the launch call below.
    placeholder = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; padding: 1rem; width: 100%">
    <img src="/gradio_api/file=gradio_logo.png" style="max-width: 200px; height: auto">
    <div>
        <h1 style="margin: 0 0 1rem 0">Chat with Gradio Docs 🗣️</h1>
        <h3 style="margin: 0 0 0.5rem 0">
            Simple RAG agent over Gradio docs built with Pydantic AI.
        </h3>
        <h3 style="margin: 0">
            Ask any question about Gradio with your natural voice and get an answer!
        </h3>
    </div>
</div>
"""
    # pydantic_ai message history carried between turns. stream_from_agent
    # appends to it and emits it as its second additional output.
    past_messages = gr.State([])
    # type="messages": the value is a list of role/content dicts, the same
    # shape stream_from_agent appends.
    chatbot = gr.Chatbot(
        label="Gradio Docs Bot",
        type="messages",
        placeholder=placeholder,
        avatar_images=(None, "gradio_logo.png"),
    )
    # Audio input (mode="send": audio presumably flows browser -> server only).
    # NOTE(review): get_twilio_turn_credentials() runs at import time and
    # presumably needs Twilio credentials in the environment — confirm.
    audio = WebRTC(
        label="Talk with the Agent",
        modality="audio",
        rtc_configuration=get_twilio_turn_credentials(),
        mode="send",
    )
    # Pass the captured audio, transcript and history to stream_from_agent.
    # ReplyOnPause presumably invokes it once the speaker pauses.
    audio.stream(
        ReplyOnPause(stream_from_agent),
        inputs=[audio, chatbot, past_messages],
        outputs=[audio],
    )
    # stream_from_agent yields AdditionalOutputs(chatbot, past_messages) pairs,
    # using gr.skip() for whichever side is unchanged. Forward them as-is.
    audio.on_additional_outputs(
        lambda c, s: (c, s),
        outputs=[chatbot, past_messages],
        queue=False,
        show_progress="hidden",
    )


if __name__ == "__main__":
    # Files the app may serve from disk; the placeholder's <img> presumably
    # relies on this to load the logo through /gradio_api/file=.
    served_files = ["gradio_logo.png"]
    demo.launch(allowed_paths=served_files)