mudaza's picture
update code
a29a79d
raw
history blame
1.66 kB
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import pickle
import pandas as pd
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import torch
def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # come from a trusted source — confirm how ./corpus is produced/deployed.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Artefacts used by the prediction endpoint, loaded once at import time:
# ``corpus`` is compared against query embeddings via ``model.similarity``,
# ``label_encoder`` maps label values back to names, and ``df`` holds one
# label per corpus row (presumably aligned with ``corpus`` — TODO confirm).
corpus = _load_pickle("./corpus/all_embeddings.pickle")
label_encoder = _load_pickle("./corpus/label_encoder.pickle")
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
df = pd.DataFrame(data={"label": _load_pickle("./corpus/y_all.pickle")})
# The ASGI application. CORS is fully open: every origin, method and header
# is accepted, and credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
app = FastAPI()
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class Disease(BaseModel):
    """One ranked match returned by the ``POST /`` prediction endpoint."""

    # Label value taken from the ``label`` column of ``df``.
    id: int
    # Human-readable name for the label, from ``label_encoder.inverse_transform``.
    name: str
    # Similarity score for this label as computed in ``predict``.
    score: float
@app.get("/")
def greet_json():
    """Answer ``GET /`` with a constant JSON greeting."""
    greeting = dict(Hello="World!")
    return greeting
# @app.post("/")
# async def greet_post():
# return {"Hello": "Post World!"}
@app.post("/", response_model=list[Disease])
async def predict(query: str):
    """Rank diseases by semantic similarity to ``query``.

    ``query`` is embedded with the module-level ``model`` and compared with
    every row of ``corpus`` via ``model.similarity``. Rows are sorted by
    similarity (highest first) and, for each distinct label in ``df``, only
    its best-scoring row is kept.

    Args:
        query: Free-text description. Declared as a bare ``str``, so FastAPI
            reads it from the URL query string rather than the request body.

    Returns:
        One ``{"id", "name", "score"}`` dict per distinct label, ordered by
        descending score. ``id`` is the label value from ``df``, ``name`` is
        ``label_encoder.inverse_transform`` of it, and ``score`` is the
        similarity of that label's best-matching corpus row.
    """
    query_embedding = model.encode(query).astype('float')
    similarity_vectors = model.similarity(query_embedding, corpus)[0]
    print("Similarity Vector Shape: ", similarity_vectors.shape)
    # k == number of corpus rows, so this is a full descending sort.
    scores, indices = torch.topk(similarity_vectors, k=len(corpus))
    print("Scores Shape: ", scores.shape)
    print("Indicies Shape: ", indices.shape)

    # Position i of ``ranked`` and ``scores[i]`` describe the same corpus row,
    # so one positional mask keeps labels and scores aligned. (Indexing the
    # *sorted* scores with original row labels, as before, attached scores
    # from unrelated rows to each disease.)
    ranked = df.iloc[indices.cpu().numpy()]
    # duplicated(keep="first") on best-first rows => keep each label's top row.
    is_best = ~ranked["label"].duplicated().to_numpy()
    labels = ranked["label"].to_numpy()[is_best]
    best_scores = scores.cpu().numpy()[is_best]
    names = label_encoder.inverse_transform(labels)

    # .tolist() yields plain Python ints/floats for validation against Disease.
    return [
        {"id": label, "name": name, "score": score}
        for label, name, score in zip(
            labels.tolist(), list(names), best_scores.tolist()
        )
    ]