import json
import re

import openai
import pandas as pd
import requests
import spacy
import spacy_transformers
import streamlit_scrollable_textbox as stx
import torch
from InstructorEmbedding import INSTRUCTOR
from sentence_transformers import SentenceTransformer
from gradio_client import Client
from tqdm import tqdm
from transformers import (
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)

import pinecone
import streamlit as st


@st.experimental_singleton
def get_data():
    """Read ``earnings_calls_cleaned_metadata.csv`` and return it as a DataFrame.

    The path is relative, so it is resolved against the process's current
    working directory.
    """
    return pd.read_csv("earnings_calls_cleaned_metadata.csv")


# Initialize Spacy Model


@st.experimental_singleton
def get_spacy_model():
    """Load and return the ``en_core_web_trf`` spaCy pipeline."""
    nlp = spacy.load("en_core_web_trf")
    return nlp


@st.experimental_singleton
def get_flan_alpaca_xl_model():
    """Load the Flan-Alpaca-XL seq2seq checkpoint from the local models folder.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the same directory.
    """
    checkpoint_dir = "/home/user/app/models/flan-alpaca-xl/"
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    return seq2seq_model, seq2seq_tokenizer


# Initialize models from HuggingFace


@st.experimental_singleton
def get_t5_model():
    """Build and return a ``summarization`` pipeline backed by ``t5-small``."""
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        tokenizer="t5-small",
    )
    return summarizer


@st.experimental_singleton
def get_flan_t5_model():
    """Load ``google/flan-t5-large`` and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "google/flan-t5-large"
    flan_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    flan_model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    return flan_model, flan_tokenizer


@st.experimental_singleton
def get_mpnet_embedding_model():
    """Load the ``all-mpnet-base-v2`` sentence-transformers model.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise on the CPU, and its ``max_seq_length`` is set to 512.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    encoder = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2",
        device=target_device,
    )
    encoder.max_seq_length = 512
    return encoder


@st.experimental_singleton
def get_splade_sparse_embedding_model():
    """Load the SPLADE masked-LM checkpoint used for sparse embeddings.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to CUDA when
        available, otherwise it stays on the CPU.
    """
    checkpoint = "naver/splade-cocondenser-ensembledistil"
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    splade_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    splade_model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    # Relocate the weights onto the selected device before handing them out.
    splade_model.to(target_device)
    return splade_model, splade_tokenizer


@st.experimental_singleton
def get_sgpt_embedding_model():
    """Load the ``SGPT-125M-weightedmean-nli-bitfit`` sentence-transformers model.

    The model is placed on CUDA when available, otherwise on the CPU, and its
    ``max_seq_length`` is set to 512.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    encoder = SentenceTransformer(
        "Muennighoff/SGPT-125M-weightedmean-nli-bitfit",
        device=target_device,
    )
    encoder.max_seq_length = 512
    return encoder


@st.experimental_singleton
def get_instructor_embedding_model():
    """Load the ``hkunlp/instructor-xl`` embedding model.

    The device is chosen the same way as in the other embedding loaders in
    this module (CUDA when available, otherwise CPU) and is passed explicitly
    to the constructor; previously it was computed but never used.

    Returns:
        The loaded ``INSTRUCTOR`` model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): INSTRUCTOR is expected to accept the same ``device``
    # keyword as SentenceTransformer - confirm against the installed version.
    model = INSTRUCTOR("hkunlp/instructor-xl", device=device)
    return model


@st.experimental_singleton
def get_alpaca_model():
    """Return a gradio ``Client`` bound to the ``awinml-alpaca-cpp`` Space URL."""
    return Client("https://awinml-alpaca-cpp.hf.space")


@st.experimental_memo
def save_key(api_key):
    """Return ``api_key`` unchanged.

    The function body is an identity; any persistence comes solely from the
    ``st.experimental_memo`` decorator applied to it.
    """
    stored_key = api_key
    return stored_key


# Text Generation


def gpt_turbo_model(prompt):
    """Send ``prompt`` as a single user message to ``gpt-3.5-turbo``.

    Args:
        prompt: Text placed in the ``content`` field of the user message.

    Returns:
        The ``content`` of the first choice's message in the response.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=chat_messages,
        temperature=0.01,
        max_tokens=1024,
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]


def generate_text_flan_t5(model, tokenizer, input_text):
    """Generate text for ``input_text`` with the given model and tokenizer.

    Args:
        model: Object exposing ``generate(input_ids, ...)``.
        tokenizer: Callable returning an object with ``input_ids``; also used
            to decode the generated sequence.
        input_text: Prompt to encode and feed to the model.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    # NOTE(review): ``do_sample`` is not set here - confirm that
    # ``temperature`` has the intended effect.
    generated = model.generate(
        encoded.input_ids,
        temperature=0.5,
        max_length=512,
    )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


# Entity Extraction


def generate_entities_flan_alpaca_inference_api(prompt):
    """Query the hosted ``declare-lab/flan-alpaca-xl`` model for ``prompt``.

    The Hugging Face token is read from ``st.secrets["hg_key"]`` and sent as
    a Bearer ``Authorization`` header. This is a best-effort call: transport
    failures and malformed or unexpected responses yield an empty string
    instead of raising.

    Args:
        prompt: Text sent as the ``inputs`` field of the request payload.

    Returns:
        The ``generated_text`` of the first item in the response, or ``""``
        when the request fails or the response has an unexpected shape.

    Raises:
        KeyError: Presumably, if ``hg_key`` is missing from ``st.secrets``;
            the lookup happens outside the ``try`` block.
    """
    API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
    API_TOKEN = st.secrets["hg_key"]
    headers = {"Authorization": f"Bearer {API_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "do_sample": True,
            "temperature": 0.1,
            "max_length": 80,
        },
        "options": {"use_cache": False, "wait_for_model": True},
    }
    try:
        data = json.dumps(payload)
        # The headers were previously built but never sent, so the token was
        # silently ignored; pass them so the request is authenticated.
        response = requests.request(
            "POST", API_URL, headers=headers, data=data
        )
        output = json.loads(response.content.decode("utf-8"))[0][
            "generated_text"
        ]
    except (
        requests.RequestException,  # network / HTTP transport failure
        ValueError,  # body is not valid UTF-8 or not valid JSON
        KeyError,  # e.g. an error object instead of a result list
        IndexError,  # empty result list
        TypeError,  # unexpected response shape or unserialisable prompt
    ):
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        output = ""
    print(output)
    return output


def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
    """Run ``prompt`` through a locally loaded model and decode the result.

    Args:
        model: Object exposing ``generate(input_ids=..., ...)``.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry;
            also used to decode the generated sequence.
        prompt: Text to encode and feed to the model.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    encoded_prompt = tokenizer(prompt, return_tensors="pt")
    # NOTE(review): ``do_sample`` is not set here - confirm that
    # ``temperature`` and ``top_p`` have the intended effect.
    generation_kwargs = {
        "input_ids": encoded_prompt["input_ids"],
        "temperature": 0.1,
        "top_p": 0.5,
        "max_new_tokens": 1024,
    }
    sequences = model.generate(**generation_kwargs)
    return tokenizer.decode(sequences[0], skip_special_tokens=True)