from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import pickle
import pandas as pd
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import torch

def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code, so the files read through this
    helper must come from a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Pre-computed embeddings that predict() compares the query embedding against.
# Cast with astype("float") (float64 if this is a NumPy array — TODO confirm),
# the same cast predict() applies to the query embedding.
corpus = _load_pickle("./corpus/all_embeddings_disease.pickle").astype("float")
# label_encoder = pickle.load(open("./corpus/label_encoder.pickle", "rb"))
# model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
# Per-entry records; predict() reads the "name" and "url" columns and the row
# index, and looks rows up by position, so presumably row order matches
# `corpus` — verify against the script that produced both pickles.
df = pd.DataFrame(_load_pickle("./corpus/y_all_disease.pickle"))

# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()

# Enable CORS for every origin, HTTP method and request header.
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS documentation advises listing origins, methods and headers
# explicitly when credentials are enabled — confirm whether credentialed
# cross-origin requests are actually needed, and restrict origins if so.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One ranked entry in the response of the POST "/" endpoint (used as its
# response_model element type).
class Disease(BaseModel):
    # Index label of the matching row in the module-level `df`.
    id: int
    # Values of that row's "name" and "url" columns.
    name: str
    url: str
    # Similarity between the query and this entry, as returned by model.similarity.
    score: float

# Request body accepted by the POST "/" endpoint.
class Symptoms(BaseModel):
    # Free-text input; predict() passes it to model.encode unchanged.
    query: str

@app.get("/")
def home():
    # Landing route: echoes the first row of the module-level `df` to stdout,
    # then answers with a fixed greeting payload.
    first_row = df.iloc[0]
    print(first_row)
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/", response_model=list[Disease])
async def predict(symptoms: Symptoms):
    """Rank every corpus entry by similarity to the submitted query.

    The query text is embedded with the loaded sentence-transformer model and
    scored against all pre-computed corpus embeddings. The whole corpus is
    returned, ordered from the highest to the lowest similarity score.

    Args:
        symptoms: Request body carrying the free-text ``query``.

    Returns:
        A list of dicts with the keys ``id``, ``name``, ``url`` and ``score``,
        one per corpus entry, best match first.
    """
    # NOTE(review): model.encode is a blocking call inside an async handler, so
    # requests are processed one at a time on the event loop.
    # Same astype("float") cast as applied to `corpus` at load time.
    query_embedding = model.encode(symptoms.query).astype("float")

    # similarity() yields one row per query; there is a single query, so take row 0.
    similarities = model.similarity(query_embedding, corpus)[0]

    # topk with k == corpus size is a full descending sort of all scores.
    scores, indices = torch.topk(similarities, k=len(corpus))

    # Hand pandas plain Python lists instead of torch tensors so the lookup
    # does not rely on implicit tensor-to-array coercion.
    ranked = df.iloc[indices.tolist()]

    return [
        {"id": row_id, "name": name, "url": url, "score": score}
        for row_id, name, url, score in zip(
            ranked.index, ranked["name"], ranked["url"], scores.tolist()
        )
    ]