import ast
import json
import os
from pathlib import Path

import openai
import pandas as pd
import numpy as np
from tqdm import tqdm

from annoy import AnnoyIndex

from openai_function_utils.openai_function_interface import (
    OPENAI_AVAILABLE_FUNCTIONS,
    OPENAI_FUNCTIONS_DEFINITIONS,
)

DEBUG_PRINT = False
# openai.api_key = OPENAI_KEY
# openai.organization = 'org-dsEkob5KeBBq3lbBLhnCXcJt'


def get_embeddings(text):
    """Return the text-embedding-ada-002 embedding vector for `text`."""
    response = openai.Embedding.create(model="text-embedding-ada-002", input=text)
    return response['data'][0]['embedding']


def debug_print(*args, **kwargs):
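    """Print only when DEBUG_PRINT is enabled."""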
    if DEBUG_PRINT:
        print(*args, **kwargs)


def transform_user_question(question, model):
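    """Ask the chat model to rephrase the user's question into a form that is easier to answer."""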
    messages = [
        {"role": "system",
         "content": "You are a helpful assistant for ChatGPT that will formulate user's input question to a version that is more understandable by ChatGPT for answering questions related to a research lab."},
        {"role": "user",
         "content": f"Formulate this question into a version that is more understandable by ChatGPT: \"{question}\""}
    #     "content": f"Formulate this question into a version that is more understandable by ChatGPT and is more suitable for embedding retrieval (i.e. we will use the embedding of the re-formulated question to retrieve related documents): \"{question}\""}
    ]
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        max_tokens=200
    )
    chatgpt_question = response["choices"][0]["message"]["content"]
    return chatgpt_question


def answer_with_gpt3_with_function_calls(input_text, question, model):
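    """Answer `question` grounded in `input_text`, letting the model call the registered OpenAI functions (e.g. semantic_search) when it needs more context."""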
    question = f"Based on the input text: {input_text}\n Give me answers for this question: {question}"
    messages = [
        {
            "role": "system",
            "content": "".join([
                "You are a helpful assistant for ChatGPT that will answer the user's questions. "
            ])
        },
        {
            "role": "user",
            "content": question
        }
    ]

    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        functions=OPENAI_FUNCTIONS_DEFINITIONS,
        max_tokens=200
    )
    response_message = response["choices"][0]["message"]

    messages.append(
        {
            "role": "assistant",
            "content": response_message.get("content"),
            "function_call": response_message.get("function_call"),
        }
    )

    # Check if GPT wanted to call a function
    if response_message.get("function_call"):
        # Call the function
        # Note: the JSON response may not always be valid; be sure to handle errors
        available_functions = OPENAI_AVAILABLE_FUNCTIONS  # only one function in this example, but you can have multiple
        function_name = response_message["function_call"]["name"]

        # Send the function-call info and the function's response back to GPT
        if function_name == "semantic_search":
            # print("Running semantic search")
            # print(response_message["function_call"]["arguments"])
            function_args = json.loads(response_message["function_call"]["arguments"])
            embedding = get_embeddings(function_args['query'])
            function_response = search_document(embedding, 3)
            messages.append({
                "role": "function",
                "name": "semantic_search",
                "content": function_response
            })
            second_response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
            )  # get a new response from GPT where it can see the function response
            return second_response["choices"][0]["message"]["content"]
        else:
            function_to_call = available_functions[function_name]
            function_args = json.loads(response_message["function_call"]["arguments"])
            function_response = function_to_call(**function_args)
            messages.append(
                {
                    "role": "function",
                    "name": function_name,
                    "content": function_response,
                }
            )  # extend conversation with function response
            print("DEBUG: messages", messages)
            second_response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
            )  # get a new response from GPT where it can see the function response
            return second_response["choices"][0]["message"]["content"]
    else:
        return response["choices"][0]["message"]["content"]


def answer_with_gpt3(input_text, question):
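    """Answer `question` grounded in `input_text` with a single gpt-3.5-turbo call (function calls are passed but not executed)."""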
    messages = [{"role": "system",
                 "content": "You are an intelligent chatbot for answering user's questions related to a research lab."}]
    message = f"Based on the input text: {input_text}\n Give me answers for this question: {question}"
    messages.append({"role": "user", "content": message})
    chat = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        functions=OPENAI_FUNCTIONS_DEFINITIONS,
        max_tokens=200
    )
    reply = chat["choices"][0]["message"]["content"]
    return reply


def search_document(user_question_embed: list, top_k: int = 1):
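    """Rank the indexed documents by cosine similarity to the question embedding and return the contents of the top_k best matches."""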
    csv_filename = 'database/document_name_to_embedding.csv'
    if not os.path.exists(csv_filename):
        print("This won't happen!")
        return

    df = pd.read_csv(csv_filename)
    # Convert the embedding column from string to list/array
    df['embedding'] = df['embedding'].apply(ast.literal_eval).apply(np.array)

    # Calculate cosine similarity
    user_question_norm = np.linalg.norm(user_question_embed)
    similarities = {}
    for _, row in df.iterrows():
        dot_product = np.dot(user_question_embed, row['embedding'])
        embedding_norm = np.linalg.norm(row['embedding'])
        cosine_similarity = dot_product / (user_question_norm * embedding_norm)
        similarities[row['original_filename']] = cosine_similarity

    # Rank documents by similarity
    ranked_documents = sorted(similarities.items(), key=lambda x: x[1], reverse=True)

    debug_print("Ranked documents by similarity:", ranked_documents)

    # Read and concatenate the contents of the top_k most similar documents
    document_content = ""
    for i in range(min(top_k, len(ranked_documents))):
        best_document_filename = ranked_documents[i][0]
        with open(best_document_filename, 'rb') as f:
            document_content += f.read().decode('utf-8')
    debug_print("document_content: ", document_content)
    return document_content


def search_document_annoy(user_question_embed: list, top_k: int, metric):
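    """Approximate-nearest-neighbour variant of search_document that builds an Annoy index over the stored embeddings."""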
    csv_filename = 'database/document_name_to_embedding.csv'
    if not os.path.exists(csv_filename):
        print("This won't happen!")
        return

    df = pd.read_csv(csv_filename)
    # Convert the embedding column from string to list/array
    df['embedding'] = df['embedding'].apply(ast.literal_eval).apply(np.array)

    dim = len(df['embedding'][0])  # Length of the item vectors that will be indexed

    t = AnnoyIndex(dim, metric)
    for i in range(len(df)):
        t.add_item(i, df['embedding'][i])

    t.build(10)  # 10 trees
    t.save('test.ann')

    u = AnnoyIndex(dim, metric)
    u.load('test.ann')  # will just mmap the file
    ret = u.get_nns_by_vector(user_question_embed, top_k)  # indices of the top_k nearest neighbours
    debug_print(df['original_filename'][ret[0]])
    document_content = ""
    for name in ret:
        best_document_filename = df['original_filename'][name]
        with open(best_document_filename, 'rb') as f:
            document_content += f.read().decode('utf-8')
    debug_print("document_content: ", document_content)
    return document_content


def get_document_embeddings(path: str, all_fns: list):
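    """Embed each file in `all_fns` under `path` with ada-002, save every embedding as JSON, and return a filename-to-embedding DataFrame."""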
    all_embeddings = []
    all_embedding_fns = []
    all_original_filename = []

    output_sub_dir = path.split('database/original_documents/')
    output_sub_dir = '' if len(output_sub_dir) == 1 else output_sub_dir[1]

    output_dir = os.path.join('database/embeddings', output_sub_dir)

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    for fn in tqdm(all_fns):
        document_name = os.path.splitext(fn)[0]
        original_filename = os.path.join(path, fn)
        try:
            with open(original_filename, 'rb') as fin:
                tmp_file = fin.read().decode('utf-8')
                embedding = get_embeddings(tmp_file)
                if embedding is not None:
                    embedding_fn = os.path.join(output_dir, document_name + '.json')
                    with open(embedding_fn, 'w') as fout:
                        json.dump(embedding, fout)
                    all_original_filename.append(original_filename)
                    all_embedding_fns.append(embedding_fn)
                    all_embeddings.append(embedding)
        except Exception as e:
            print(
                f"Error when obtaining embedding vector for {original_filename}: {e}. "
                "The model's maximum context length is 8192 tokens; make sure the file is valid and not too long.")

    return pd.DataFrame({
        'original_filename': all_original_filename,
        'embedding_filename': all_embedding_fns,
        'embedding': all_embeddings
    })


def util():
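    """Standalone demo of the function-calling loop: the model may call semantic_search repeatedly until it produces a final answer."""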
    model = "gpt-3.5-turbo"
    question = "Can you give me a paper about graph neural networks?"

    functions = [
        {
            "name": "semantic_search",
            "description": "does a semantic search over the documents based on query",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query to search for",
                    }
                },
                "required": ["query"],
            }
        },
    ]

    messages = [
        {
            "role": "system",
            "content": "".join([
                "You are a helpful assistant for ChatGPT that will answer the user's questions. ",
                "In order to do so, you may use semantic_search to find relevant documents. ",
            ])
        },
        {
            "role": "user",
            "content": question
        }
    ]

    while True:
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            max_tokens=200,
            functions=functions
        )
        response_message = response["choices"][0]["message"]
        messages.append(
            {
                "role": "assistant",
                "content": response_message.get("content"),
                "function_call": response_message.get("function_call"),
            }
        )

        if response_message.get("function_call"):
            function_args = json.loads(response_message["function_call"]["arguments"])
            embedding = get_embeddings(function_args['query'])
            function_response = search_document(embedding)
            messages.append({
                "role": "function",
                "name": "semantic_search",
                "content": function_response
            })
        else:
            print("Answering question")
            print(response_message["content"])
            return

def main():
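    """Walk database/original_documents, embed every document, and write the combined index to database/document_name_to_embedding.csv."""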
    final_df = pd.DataFrame({})
    all_fn_list = os.walk('database/original_documents')

    for path, _, fn_list in all_fn_list:
        filename_to_embedding_df = get_document_embeddings(path, fn_list)
        final_df = pd.concat([final_df, filename_to_embedding_df], axis=0, ignore_index=True)

    final_df.to_csv('database/document_name_to_embedding.csv', index=False)


if __name__ == "__main__":
    main()