"""FastAPI service that ranks disease labels by similarity to a text query.

At import time the module loads a precomputed embedding corpus, the label
encoder whose ``inverse_transform`` maps encoded labels back to disease names,
and the per-row label vector, then exposes a POST endpoint that embeds the
query with a SentenceTransformer model and returns every distinct label
ordered by its best-matching corpus row.
"""
import logging
import pickle

import pandas as pd
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code; these artifacts must
    come from a trusted source.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Precomputed corpus embeddings, one row per labelled example.
corpus = _load_pickle("./corpus/all_embeddings.pickle")
# Encoder used to turn encoded labels back into disease names.
label_encoder = _load_pickle("./corpus/label_encoder.pickle")
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# Encoded label of each corpus row; row i of df corresponds to row i of corpus.
df = pd.DataFrame(data={"label": _load_pickle("./corpus/y_all.pickle")})

app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins if this service ever
# relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Disease(BaseModel):
    """One ranked result returned by ``predict``."""

    id: int  # encoded label
    name: str  # decoded disease name
    score: float  # similarity of the label's best-matching corpus row


@app.get("/")
def greet_json():
    """Health-check style greeting."""
    return {"Hello": "World!"}


@app.post("/", response_model=list[Disease])
async def predict(query: str):
    """Rank every distinct disease label by similarity to *query*.

    Args:
        query: Free-text query. As a bare ``str`` argument it is read from the
            query string, not the request body.

    Returns:
        One dict per distinct label, ordered by descending score, with keys
        ``id`` (encoded label), ``name`` (decoded via the label encoder) and
        ``score`` (the label's best value from ``model.similarity``).
    """
    # NOTE(review): encode() is CPU-bound and runs on the event loop here;
    # consider offloading it to a threadpool if latency under load matters.
    # astype('float') yields float64, presumably to match the corpus dtype
    # expected by model.similarity -- TODO confirm.
    query_embedding = model.encode(query).astype("float")
    similarity = model.similarity(query_embedding, corpus)[0]
    logger.debug("Similarity vector shape: %s", tuple(similarity.shape))

    # Full descending sort of all corpus rows; scores stay aligned with rows.
    ranked_scores, ranked_rows = torch.topk(similarity, k=similarity.shape[0])
    ranked_rows = ranked_rows.cpu().numpy()
    ranked_scores = ranked_scores.cpu().numpy()

    # Labels in rank order; the first occurrence of each label is its best row.
    ranked_labels = df["label"].to_numpy()[ranked_rows]
    first_seen = ~pd.Series(ranked_labels).duplicated(keep="first").to_numpy()

    # Filter labels and scores with the same rank-position mask so each label
    # is paired with its own score. The original indexed the sorted score
    # vector with corpus row positions, mismatching labels and scores.
    label_ids = ranked_labels[first_seen]
    best_scores = ranked_scores[first_seen]
    names = label_encoder.inverse_transform(label_ids)

    # Convert numpy/torch scalars to plain Python values for the response model.
    return [
        {"id": int(label_id), "name": str(name), "score": float(score)}
        for label_id, name, score in zip(label_ids, names, best_scores)
    ]