Spaces:
Sleeping
Sleeping
File size: 1,660 Bytes
b2f9c3c 0fc39f5 b2f9c3c ade21fe c72cbc0 b2f9c3c 874d7b1 b2f9c3c 105b9cf 0fc39f5 ffa5f5d b2f9c3c 0fc39f5 b2f9c3c 1060e26 94ba614 90f2a6c 59eb01b 90f2a6c 1060e26 a29a79d 1060e26 a29a79d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI
import pickle
import pandas as pd
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import torch
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file.

    NOTE(review): pickle can execute arbitrary code while loading, so these
    files must come from a trusted source (presumably artifacts shipped with
    the app, never user uploads) — confirm.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Precomputed embeddings that `predict` compares each query against.
corpus = _load_pickle("./corpus/all_embeddings.pickle")
# Encoder whose inverse_transform maps label ids back to disease names.
label_encoder = _load_pickle("./corpus/label_encoder.pickle")
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# One label id per corpus row; presumably row-aligned with `corpus` — verify.
df = pd.DataFrame(data={"label": _load_pickle("./corpus/y_all.pickle")})

app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# website send credentialed requests; restrict allow_origins if auth/cookies
# are ever introduced.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Disease(BaseModel):
    """One ranked match in the response of the POST / prediction endpoint."""

    id: int  # encoded label id, as stored in df.label
    name: str  # label name decoded via label_encoder.inverse_transform
    score: float  # similarity between the query and this label's best corpus row
@app.get("/")
def greet_json():
    """Answer GET / with a constant JSON greeting."""
    payload = dict(Hello="World!")
    return payload
# @app.post("/")
# async def greet_post():
# return {"Hello": "Post World!"}
@app.post("/", response_model=list[Disease])
def predict(query: str):
    """Rank every disease label by semantic similarity to ``query``.

    The query is embedded with the module-level SentenceTransformer ``model``
    and compared against every row of the precomputed ``corpus`` embeddings.
    Several corpus rows may share one label (``df.label``); only the
    best-scoring row of each label is kept.

    Declared as a plain ``def`` rather than ``async def``: ``model.encode`` is
    blocking, CPU-bound work, and FastAPI runs sync endpoints in a worker
    thread instead of on the event loop.

    Args:
        query: Free-text description to match (sent as a query parameter).

    Returns:
        One ``{"id", "name", "score"}`` dict per distinct label, ordered by
        descending similarity score.
    """
    query_embedding = model.encode(query).astype('float')
    # [0] selects the single query's row of the (1, len(corpus)) matrix.
    similarity_vector = model.similarity(query_embedding, corpus)[0]
    print("Similarity Vector Shape: ", similarity_vector.shape)
    # k == len(corpus): a full descending sort of all corpus rows.
    scores, indices = torch.topk(similarity_vector, k=len(corpus))
    print("Scores Shape: ", scores.shape)
    print("Indicies Shape: ", indices.shape)

    ranked = df.iloc[indices.cpu().numpy()]
    # Rows are best-first, so each label's first occurrence is its
    # highest-scoring row. A *positional* mask filters labels and scores
    # (both in ranked order) identically; indexing the sorted scores with
    # original corpus row numbers would attach the wrong score to a label.
    keep = ~ranked["label"].duplicated().to_numpy()
    label_ids = ranked["label"].to_numpy()[keep]
    best_scores = scores.detach().cpu().numpy()[keep]
    names = label_encoder.inverse_transform(label_ids)
    # Cast numpy/torch scalars to builtins for response validation and JSON.
    return [
        {"id": int(label_id), "name": str(name), "score": float(score)}
        for label_id, name, score in zip(label_ids, names, best_scores)
    ]