import streamlit as st
import pandas as pd
import pandas as pd
from tqdm import tqdm

import torch
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)

import streamlit_scrollable_textbox as stx


@st.experimental_singleton
def get_data():
    """Load the sentence-level earnings-call table from ``earnings_calls_sentencewise.csv``.

    Cached with ``st.experimental_singleton``, so the CSV is parsed once and
    the same DataFrame object is handed back on later calls.
    NOTE(review): because the object is shared, callers presumably must not
    mutate it in place — confirm no caller does.
    """
    sentences = pd.read_csv("earnings_calls_sentencewise.csv")
    return sentences


# Initialize models from HuggingFace


@st.experimental_singleton
def get_t5_model():
    """Build the ``t5-small`` summarization pipeline (model and tokenizer).

    Wrapped in ``st.experimental_singleton`` so the pipeline is constructed
    once and reused instead of being reloaded on every rerun.
    """
    checkpoint = "t5-small"
    summarizer = pipeline("summarization", model=checkpoint, tokenizer=checkpoint)
    return summarizer


@st.experimental_singleton
def get_flan_t5_model():
    """Build the ``google/flan-t5-small`` summarization pipeline.

    The same checkpoint supplies both model weights and tokenizer; the
    pipeline is cached via ``st.experimental_singleton``.
    """
    checkpoint = "google/flan-t5-small"
    return pipeline("summarization", model=checkpoint, tokenizer=checkpoint)


@st.experimental_singleton
def get_mpnet_embedding_model():
    """Load the ``all-mpnet-base-v2`` sentence-embedding model.

    Runs on CUDA when ``torch`` reports a GPU, otherwise on CPU. The input
    length cap is set to 512 tokens; per the sentence-transformers docs,
    longer inputs are truncated. Cached via ``st.experimental_singleton``.
    """
    encoder = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    encoder.max_seq_length = 512
    return encoder


@st.experimental_singleton
def get_sgpt_embedding_model():
    """Load the ``SGPT-125M-weightedmean-nli-bitfit`` sentence-embedding model.

    Placed on CUDA if available, else CPU, with a 512-token input cap.
    Cached via ``st.experimental_singleton`` so it is loaded only once.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    sgpt = SentenceTransformer(
        "Muennighoff/SGPT-125M-weightedmean-nli-bitfit", device=target_device
    )
    sgpt.max_seq_length = 512
    return sgpt


@st.experimental_memo
def save_key(api_key):
    """Return ``api_key`` unchanged.

    The body is an identity function; any "saving" is done by the
    ``st.experimental_memo`` decorator, which caches the return value keyed
    on the argument.
    NOTE(review): Streamlit caches are process-wide rather than per-session,
    so secrets stored this way may be visible across sessions — confirm this
    is intended.
    """
    remembered = api_key
    return remembered


def query_pinecone(query, top_k, model, index, year, quarter, ticker, threshold=0.5):
    """Retrieve the passages most similar to ``query`` for one earnings call.

    The query is embedded with ``model`` and sent to ``index`` with a metadata
    filter on Year/Quarter/Ticker; matches scoring below ``threshold`` are
    then dropped.

    Args:
        query: natural-language question to embed.
        top_k: maximum number of matches requested from the index.
        model: embedding model whose ``encode([query])`` result has ``.tolist()``.
        index: Pinecone index handle exposing ``query``.
        year: call year; converted with ``int()`` for the filter.
        quarter: value required for the ``Quarter`` metadata field.
        ticker: value required for the ``Ticker`` metadata field.
        threshold: minimum score (inclusive) a match must have to be kept.

    Returns:
        The index response, with its ``"matches"`` entry replaced in place by
        the score-filtered list.
    """
    embedded_query = model.encode([query]).tolist()
    response = index.query(
        embedded_query,
        top_k=top_k,
        filter={
            "Year": int(year),
            "Quarter": {"$eq": quarter},
            "Ticker": {"$eq": ticker},
        },
        include_metadata=True,
    )
    response["matches"] = [
        candidate
        for candidate in response["matches"]
        if candidate["score"] >= threshold
    ]
    return response


def format_query(query_results):
    """Return the ``metadata["Text"]`` of every match, in the order returned."""
    matches = query_results["matches"]
    passages = [match["metadata"]["Text"] for match in matches]
    return passages


def sentence_id_combine(data, query_results, lag=2):
    """Expand each matched sentence into a passage of its neighbouring sentences.

    For every match, take the rows from ``Sentence_id - lag`` to
    ``Sentence_id + lag`` (inclusive), clipped to the rows that exist in
    ``data``. Windows that overlap are merged so no sentence appears twice,
    and each resulting window is joined with ". " into one passage.

    Args:
        data: DataFrame with a ``Text`` column. ``Sentence_id`` values are used
            as positional (``iloc``) indices into it.
            NOTE(review): assumes Sentence_id equals the row position in
            ``data`` — confirm. Windows may also cross into an adjacent
            call's sentences if those rows are contiguous.
        query_results: Pinecone-style response with
            ``matches[*]["metadata"]["Sentence_id"]``.
        lag: number of neighbouring sentences to include on each side.

    Returns:
        list[str]: one passage per merged window, ordered by row position.
        Returns an empty list when there are no matches.
    """
    n_rows = len(data)
    windows = []
    for match in query_results["matches"]:
        # Numeric metadata may come back from Pinecone as floats; iloc needs ints.
        sentence_id = int(match["metadata"]["Sentence_id"])
        # Clip to valid positions: negative iloc indices would wrap to the end
        # of the table, and indices past the end would raise IndexError.
        start = max(sentence_id - lag, 0)
        end = min(sentence_id + lag, n_rows - 1)
        if start <= end:
            windows.append((start, end))

    windows.sort()
    merged = []
    for start, end in windows:
        if merged and start <= merged[-1][1]:
            # Overlaps the previous window: extend it rather than emitting the
            # shared sentences twice or splitting the run at an arbitrary point.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    return [
        ". ".join(data.Text.iloc[start : end + 1].to_list()) for start, end in merged
    ]


def text_lookup(data, sentence_ids):
    """Join the entries of ``data`` at positions ``sentence_ids`` with ". ".

    ``data`` must be a pandas Series (for example a ``Text`` column): the
    positional selection is converted with ``.to_list()``, which a DataFrame
    does not provide. ``sentence_ids`` should be a list of positions.
    """
    selected = data.iloc[sentence_ids]
    sentences = selected.to_list()
    return ". ".join(sentences)


def gpt3(query, result):
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=f"""Context information is below. \n"
    "---------------------\n"
    "{result}"
    "\n---------------------\n"
    "Given the context information and prior knowledge, answer this question: {query}. \n"
    "Try to include as many key details as possible and format the answer in points. \n" """,
        temperature=0.1,
        max_tokens=512,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=1,
    )
    return response.choices[0].text


# Transcript Retrieval


def retrieve_transcript(data, year, quarter, ticker):
    """Read the raw transcript text for one earnings call.

    Selects the Year/Month/Date/Ticker values of the rows in ``data`` matching
    ``year`` (converted with ``int()``), ``quarter`` and ``ticker``. It takes
    the first distinct combination, joins those values with "-" and appends
    ".txt" to form the file name, then reads
    ``Transcripts/<ticker>/<file name>``.

    Raises:
        IndexError: if ``data`` has no row for the requested call.
        OSError: if the transcript file cannot be opened (e.g. FileNotFoundError).
    """
    matches_call = (
        (data.Year == int(year))
        & (data.Quarter == quarter)
        & (data.Ticker == ticker)
    )
    row = (
        data.loc[matches_call, ["Year", "Month", "Date", "Ticker"]]
        .drop_duplicates()
        .iloc[0]
    )
    file_name = "-".join(row.astype(str)) + ".txt"
    # Context manager guarantees the handle is closed, even if read() raises.
    # NOTE(review): no encoding is given, so the platform default is used —
    # confirm the transcripts' encoding and pin it (e.g. "utf-8") if known.
    with open(f"Transcripts/{ticker}/{file_name}", "r") as transcript_file:
        return transcript_file.read()