import time

from langchain.agents import AgentType
from langchain.agents import initialize_agent
from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI

from nodes.Worker import *
from utils.CustomDocstoreExplorer import CustomDocstoreExplorer
from utils.util import *


class ReactBase:
    """ReAct docstore agent wrapper that reports timing, token and tool usage.

    Wraps a LangChain ``AgentType.REACT_DOCSTORE`` agent that is given a
    ``Search`` and a ``Lookup`` tool, and whose prompt template is replaced
    with the caller-supplied few-shot text.

    NOTE(review): ``OpenAI``, ``Tool``, ``Wikipedia``,
    ``OPENAI_COMPLETION_MODELS`` and ``OPENAI_CHAT_MODELS`` are not imported
    by name in this module; presumably they come from the wildcard imports.
    """

    def __init__(self, fewshot, model_name="text-davinci-002", max_iter=8, verbose=True):
        """Store the configuration and build the agent.

        Args:
            fewshot: Text installed as the agent's prompt template.
            model_name: Model identifier; must be listed in
                ``OPENAI_COMPLETION_MODELS`` or ``OPENAI_CHAT_MODELS``.
            max_iter: Value passed to the agent as ``max_iterations``.
            verbose: Passed through to the agent.

        Raises:
            ValueError: If ``model_name`` is in neither model list.
        """
        self.model_name = model_name
        self.max_iter = max_iter
        self.verbose = verbose
        self.fewshot = fewshot
        self._build_agent()

    def run(self, prompt):
        """Run the agent on ``prompt`` with a freshly rebuilt agent.

        Returns:
            A dict with the keys ``wall_time``, ``input``, ``output``,
            ``intermediate_steps``, ``tool_usage``, ``total_tokens``,
            ``prompt_tokens``, ``completion_tokens``, ``total_cost``,
            ``steps``, ``token_cost`` and ``tool_cost``.
        """
        self.reset()
        with get_openai_callback() as cb:
            # perf_counter is monotonic, so the elapsed time cannot be skewed
            # by system clock adjustments during the call.
            started = time.perf_counter()
            response = self.agent(prompt)
            elapsed = time.perf_counter() - started
            steps = response["intermediate_steps"]
            result = {
                "wall_time": elapsed,
                "input": response["input"],
                "output": response["output"],
                "intermediate_steps": steps,
                "tool_usage": self._parse_tool(steps),
                "total_tokens": cb.total_tokens,
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_cost": cb.total_cost,
                # One more than the number of recorded intermediate steps.
                "steps": len(steps) + 1,
                "token_cost": cb.total_cost,
                "tool_cost": 0,
            }
        return result

    def _load_tools(self):
        """Return new ``Search`` and ``Lookup`` tools over a new docstore explorer."""
        docstore = CustomDocstoreExplorer(Wikipedia())
        return [
            Tool(
                name="Search",
                func=docstore.search,
                description="useful for when you need to ask with search"
            ),
            Tool(
                name="Lookup",
                func=docstore.lookup,
                description="useful for when you need to ask with lookup"
            )
        ]

    def reset(self):
        """Rebuild the tools and the agent from the stored configuration.

        Raises:
            ValueError: If ``self.model_name`` is in neither model list.
        """
        self._build_agent()

    def _build_llm(self):
        """Return the LLM wrapper matching ``self.model_name``.

        Raises:
            ValueError: If the model is neither a known completion model nor
                a known chat model. Without this check no agent would be
                created and the failure would surface later as an unrelated
                ``AttributeError`` (or a stale agent would be reused).
        """
        if self.model_name in OPENAI_COMPLETION_MODELS:
            return OpenAI(temperature=0, model_name=self.model_name)
        if self.model_name in OPENAI_CHAT_MODELS:
            return ChatOpenAI(temperature=0, model_name=self.model_name)
        raise ValueError(
            f"Unsupported model_name {self.model_name!r}: expected a model listed in "
            "OPENAI_COMPLETION_MODELS or OPENAI_CHAT_MODELS"
        )

    def _build_agent(self):
        """Create ``self.tools`` and ``self.agent`` and install the few-shot prompt."""
        self.tools = self._load_tools()
        llm = self._build_llm()
        self.agent = initialize_agent(self.tools,
                                      llm,
                                      agent=AgentType.REACT_DOCSTORE,
                                      verbose=self.verbose,
                                      return_intermediate_steps=True,
                                      max_iterations=self.max_iter)
        self.agent.agent.llm_chain.prompt.template = self.fewshot

    def _parse_tool(self, intermediate_steps):
        """Count ``Search`` and ``Lookup`` actions in ``intermediate_steps``.

        Each step is indexed as ``step[0]`` for the action, whose ``tool``
        attribute holds the tool name. Other tool names are ignored.

        Returns:
            ``{"search": <count>, "lookup": <count>}``.
        """
        counter_key = {"Search": "search", "Lookup": "lookup"}
        tool_usage = {"search": 0, "lookup": 0}
        for step in intermediate_steps:
            key = counter_key.get(step[0].tool)
            if key is not None:
                tool_usage[key] += 1
        return tool_usage


class ReactExtraTool(ReactBase):
    """Zero-shot ReAct agent whose tools are looked up in ``WORKER_REGISTRY``.

    NOTE(review): ``OpenAI``, ``Tool`` and ``WORKER_REGISTRY`` are not
    imported by name in this module; presumably they come from the wildcard
    imports.
    """

    # Cost factors used by ``run`` when adding tool usage to the total cost.
    LLM_MATH_COST_PER_TOKEN = 0.000002
    SERPAPI_COST_PER_CALL = 0.01  # Developer Plan

    # Action names counted as one search call. Tools are registered under
    # their ``available_tools`` name, so with the default configuration the
    # search action is called "Google"; "Search" is kept for compatibility.
    # NOTE(review): assumes the "Google" worker is the SerpAPI-backed one -- confirm.
    _SEARCH_TOOL_NAMES = ("Search", "Google")

    def __init__(self, model_name="text-davinci-003", available_tools=None, fewshot="\n",
                 verbose=True):
        """Store the configuration and build the agent.

        Args:
            model_name: Model identifier passed to ``OpenAI``.
            available_tools: Names of ``WORKER_REGISTRY`` entries to expose
                as tools. ``None`` selects ``["Google", "Calculator"]``; a
                new list is created per instance so instances never share
                (and cannot accidentally mutate) a common default.
            fewshot: Text appended to the prompt template by ``reset``.
            verbose: Passed through to the agent.
        """
        self.model_name = model_name
        self.verbose = verbose
        self.fewshot = fewshot
        if available_tools is None:
            available_tools = ["Google", "Calculator"]
        self.available_tools = available_tools
        self.tools = self._load_tools()
        self.agent = self._init_zero_shot_agent()

    def run(self, prompt):
        """Run the agent on ``prompt`` with a freshly rebuilt agent.

        Returns:
            A dict with the keys ``wall_time``, ``input``, ``output``,
            ``intermediate_steps``, ``tool_usage``, ``total_tokens``,
            ``prompt_tokens``, ``completion_tokens``, ``total_cost``,
            ``steps``, ``token_cost`` and ``tool_cost``. ``total_tokens`` and
            ``total_cost`` include the estimated calculator tokens, and
            ``total_cost`` also includes the per-call search cost.
        """
        self.reset()
        with get_openai_callback() as cb:
            # perf_counter is monotonic, so the elapsed time cannot be skewed
            # by system clock adjustments during the call.
            started = time.perf_counter()
            response = self.agent(prompt)
            elapsed = time.perf_counter() - started
            steps = response["intermediate_steps"]
            usage = self._parse_tool(steps)
            total_cost = cb.total_cost + usage["llm-math_token"] * self.LLM_MATH_COST_PER_TOKEN + \
                usage["serpapi"] * self.SERPAPI_COST_PER_CALL
            result = {
                "wall_time": elapsed,
                "input": response["input"],
                "output": response["output"],
                "intermediate_steps": steps,
                "tool_usage": usage,
                "total_tokens": cb.total_tokens + usage["llm-math_token"],
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_cost": total_cost,
                # One more than the number of recorded intermediate steps.
                "steps": len(steps) + 1,
                "token_cost": total_cost,
                "tool_cost": 0,
            }

        return result

    def _load_tools(self):
        """Build one ``Tool`` per name in ``self.available_tools``.

        Raises:
            KeyError: If a name is missing from ``WORKER_REGISTRY``.
        """
        tools = []
        for tool_name in self.available_tools:
            worker = WORKER_REGISTRY[tool_name]
            tools.append(Tool(name=tool_name,
                              func=worker.run,
                              description=worker.description))
        return tools

    def reset(self):
        """Rebuild tools and agent, then install the tool-aware prompt template."""
        self.tools = self._load_tools()
        self.agent = self._init_zero_shot_agent()
        self.agent.agent.llm_chain.prompt.template = PREFIX + self._generate_tool_prompt() + "\n" + self.fewshot

    def _init_zero_shot_agent(self):
        """Return a new zero-shot ReAct agent over ``self.tools``."""
        return initialize_agent(self.tools,
                                OpenAI(temperature=0, model_name=self.model_name),
                                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                                verbose=self.verbose,
                                return_intermediate_steps=True)

    def _parse_tool(self, intermediate_steps):
        """Summarise search calls and estimated calculator tokens.

        Each step is indexed as ``step[0]`` for the action (with ``tool`` and
        ``tool_input`` attributes) and ``step[1]`` for its observation.

        Returns:
            ``{"serpapi": <search call count>, "llm-math_token": <estimate>}``
            where the estimate is the combined length of calculator input and
            observation divided by four.
        """
        tool_usage = {"serpapi": 0, "llm-math_token": 0}
        for step in intermediate_steps:
            action = step[0]
            if action.tool in self._SEARCH_TOOL_NAMES:
                tool_usage["serpapi"] += 1
            if action.tool == "Calculator":
                tool_usage["llm-math_token"] += len(action.tool_input + step[1]) // 4  # 4 chars per token
        return tool_usage

    def _get_worker(self, name):
        """Return the ``WORKER_REGISTRY`` entry for ``name``.

        Raises:
            ValueError: If ``name`` is not registered.
        """
        if name not in WORKER_REGISTRY:
            raise ValueError("Worker not found")
        return WORKER_REGISTRY[name]

    def _generate_tool_prompt(self):
        """Return the prompt section listing each available tool and its description."""
        workers = (self._get_worker(name) for name in self.available_tools)
        entries = "".join(f"{worker.name}[input]: {worker.description}\n" for worker in workers)
        return "Tools can be one of the following:\n" + entries + "\n"


# Opening instruction of the prompt template assembled in ReactExtraTool.reset;
# the generated tool list and the few-shot text are appended after it.
PREFIX = "Answer the following questions as best you can. You have access to the following tools:\n"