Spaces:
Sleeping
Sleeping
from sentence_transformers import SentenceTransformer | |
from fastapi import FastAPI | |
import pickle | |
import pandas as pd | |
from pydantic import BaseModel | |
from fastapi.middleware.cors import CORSMiddleware | |
import torch | |
# Module-level state shared by the request handlers below.
#
# NOTE(review): pickle can execute arbitrary code while loading; these files
# must only ever come from a trusted source.

# Precomputed embeddings that `predict` compares each query against, cast to
# float. Opened via `with` so the file handle is closed right after loading.
with open("./corpus/all_embeddings_disease.pickle", "rb") as embeddings_file:
    corpus = pickle.load(embeddings_file).astype("float")

# Sentence encoder used to embed incoming queries.
# NOTE(review): presumably the same model that produced `corpus` -- confirm.
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

# Per-entry metadata; `predict` reads its `name` and `url` columns and uses
# the row labels as ids.
with open("./corpus/y_all_disease.pickle", "rb") as metadata_file:
    df = pd.DataFrame(pickle.load(metadata_file))

app = FastAPI()

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. Restrict `allow_origins` if cookies or auth headers are used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Disease(BaseModel):
    """A single scored disease entry.

    The field names match the keys of the dictionaries built in `predict`.
    """

    # Integer identifier of the entry.
    id: int
    # Display name of the disease.
    name: str
    # Link associated with the disease.
    url: str
    # Similarity score for the entry.
    score: float
class Symptoms(BaseModel):
    """Request payload holding the text that `predict` encodes and ranks."""

    # Free-text query passed to the sentence encoder.
    query: str
def home():
    """Print the first row of `df` to stdout and return a fixed greeting.

    Returns:
        The dictionary ``{"Hello": "World!"}``.
    """
    # NOTE(review): no route decorator is visible in this view -- confirm how
    # this handler is registered on `app`.
    first_row = df.iloc[0]
    print(first_row)
    greeting = {"Hello": "World!"}
    return greeting
async def predict(symptoms: Symptoms):
    """Rank every corpus entry by similarity to the submitted query.

    The query text is embedded with `model`, compared against all rows of
    `corpus`, and the entries are returned best match first.

    Args:
        symptoms: Request payload; only its `query` string is used.

    Returns:
        A list with one dictionary per corpus entry, each holding the keys
        ``id`` (row label in `df`), ``name``, ``url`` and ``score``.
    """
    # NOTE(review): no route decorator is visible in this view -- confirm how
    # this handler is registered on `app`.
    # NOTE(review): `model.encode` is synchronous, so this coroutine occupies
    # the event loop while the query is being embedded.
    query_embedding = model.encode(symptoms.query).astype("float")

    # `similarity` yields one row per query; a single query means row 0.
    similarities = model.similarity(query_embedding, corpus)[0]

    # k == len(corpus) turns top-k into a full descending ranking.
    scores, indices = torch.topk(similarities, k=len(corpus))

    # Convert tensors to plain Python lists explicitly rather than relying on
    # pandas to coerce torch tensors implicitly.
    ranked = df.iloc[indices.tolist()]

    return [
        {"id": row_id, "name": name, "url": url, "score": score}
        for row_id, name, url, score in zip(
            ranked.index, ranked["name"], ranked["url"], scores.tolist()
        )
    ]