"""FastAPI service that ranks diseases by semantic similarity to a symptom query.

At import time the module loads precomputed corpus embeddings and the matching
disease metadata from ``./corpus``, plus a multilingual SentenceTransformer
model. ``POST /`` encodes the query text and returns every corpus entry
ordered from most to least similar.
"""

import pickle

import pandas as pd
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE(review): pickle can execute arbitrary code on load; these files are
    # assumed to be trusted local artifacts -- never point this at user data.
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Precomputed embeddings of the disease corpus. Row i is assumed to describe the
# same disease as row i of ``df`` below -- TODO confirm against the build script.
corpus = _load_pickle("./corpus/all_embeddings_disease.pickle")

# Lighter alternative model: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# (the corpus embeddings must be regenerated with the same model if switched).
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Disease metadata; ``predict`` reads the "name" and "url" columns and uses the
# frame's index as the disease id.
df = pd.DataFrame(_load_pickle("./corpus/y_all_disease.pickle"))

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Disease(BaseModel):
    """One ranked result returned by the prediction endpoint."""

    id: int  # index label of the disease row in ``df``
    name: str
    url: str
    score: float  # similarity between the query and this disease's embedding


class Symptoms(BaseModel):
    """Request body: free-text description of the symptoms."""

    query: str


@app.post("/", response_model=list[Disease])
async def predict(symptoms: Symptoms):
    """Rank every disease in the corpus by similarity to ``symptoms.query``.

    Returns a list of dicts with keys ``id``, ``name``, ``score`` and ``url``,
    ordered from most to least similar (all ``len(corpus)`` entries).
    """
    # NOTE(review): encode/similarity are compute-bound and run directly on the
    # event loop inside this ``async def``; consider a plain ``def`` endpoint or
    # a threadpool if concurrent requests stall.
    query_embedding = model.encode(symptoms.query).astype("float")
    # One query against the corpus yields a single similarity row.
    similarity_vector = model.similarity(query_embedding, corpus)[0]
    scores, indices = torch.topk(similarity_vector, k=len(corpus))

    # Bind the reordered frame to a new name. Assigning to ``df`` here would make
    # ``df`` local to the whole function and raise UnboundLocalError on the read.
    # ``.iloc`` returns a new frame, so the module-level ``df`` shared across
    # requests is never mutated.
    ranked = df.iloc[indices.tolist()]

    # ``.tolist()`` converts numpy/torch scalars to native Python values for the
    # pydantic response model.
    return [
        {"id": disease_id, "name": name, "score": score, "url": url}
        for disease_id, name, score, url in zip(
            ranked.index.tolist(),
            ranked["name"].tolist(),
            scores.tolist(),
            ranked["url"].tolist(),
        )
    ]