|
from sentence_transformers import SentenceTransformer |
|
from fastapi import FastAPI |
|
import pickle |
|
import pandas as pd |
|
from pydantic import BaseModel |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import torch |
|
|
|
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been read, instead of being left for the garbage collector.

    NOTE: unpickling can execute arbitrary code; only point this at artifact
    files that are produced and shipped with the service itself.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Precomputed embeddings that incoming queries are compared against.
corpus = _load_pickle("./corpus/all_embeddings.pickle")

# Encoder used to map stored label values back to their names.
label_encoder = _load_pickle("./corpus/label_encoder.pickle")

# Sentence-embedding model used to encode incoming queries.
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# One label per stored row.
# NOTE(review): presumably row i of ``df`` labels row i of ``corpus`` - confirm
# against the script that produced the pickles.
df = pd.DataFrame(data={"label": _load_pickle("./corpus/y_all.pickle")})
|
|
|
# The ASGI application object; route handlers register themselves on it.
app = FastAPI()

# CORS policy: any origin, any method, any header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm whether credentialed cross-origin requests are actually
# needed, and list explicit origins if they are.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
class Disease(BaseModel):
    """A single entry of the list returned by the prediction endpoint."""

    # Stored label value of the disease (the value handed to the label encoder).
    id: int
    # Name obtained by inverse-transforming ``id`` with the label encoder.
    name: str
    # Similarity between the query and the matching stored embedding.
    score: float
|
|
|
@app.get("/")
def greet_json():
    """Answer GET requests on the root path with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
|
|
|
|
|
|
|
|
|
|
|
@app.post("/", response_model=list[Disease])
async def predict(query: str):
    """Rank diseases by similarity between ``query`` and the stored embeddings.

    The query is embedded with the module-level ``model`` and compared against
    every row of ``corpus``. Rows are ordered from most to least similar, and
    only the best-scoring row of each label is kept, so each disease appears
    at most once.

    Args:
        query: Free-text description to match against the corpus.

    Returns:
        A list of ``{"id", "name", "score"}`` dicts, best match first, where
        ``id`` is the stored label value, ``name`` is that label decoded by
        ``label_encoder``, and ``score`` is the similarity of the best row
        carrying that label.

    NOTE(review): encoding and ranking are synchronous work inside an
    ``async def`` handler, so they occupy the event loop for the duration of
    the request - consider a plain ``def`` handler if concurrency matters.
    """
    query_embedding = model.encode(query).astype("float")
    similarity_vectors = model.similarity(query_embedding, corpus)[0]
    print("Similarity Vector Shape: ", similarity_vectors.shape)

    # k == len(corpus) turns top-k into a full descending sort of all rows.
    scores, indices = torch.topk(similarity_vectors, k=len(corpus))
    print("Scores Shape: ", scores.shape)
    print("Indicies Shape: ", indices.shape)

    # Attach each score to its row *before* de-duplicating. ``scores`` is in
    # rank order, while the frame's index holds corpus row numbers, so
    # indexing ``scores`` with that index (as was done previously) pairs a
    # disease with another row's score.
    ranked = df.iloc[indices.tolist()].assign(score=scores.tolist())

    # Rows are already best-first, so keeping the first occurrence of each
    # label keeps that label's highest-scoring row.
    best = ranked.drop_duplicates("label")

    label_ids = best["label"].to_numpy()
    names = label_encoder.inverse_transform(label_ids)

    # Convert NumPy scalars to plain Python values so the payload does not
    # depend on implicit coercion during response validation.
    return [
        {"id": int(label_id), "name": str(name), "score": float(score)}
        for label_id, name, score in zip(label_ids, names, best["score"])
    ]