update code
Browse files
app.py
CHANGED
@@ -37,7 +37,7 @@ def greet_json():
|
|
37 |
@app.post("/", response_model=list[Disease])
|
38 |
async def predict(query: str):
|
39 |
query_embedding = model.encode(query).astype('float')
|
40 |
-
similarity_vectors = model.similarity(query_embedding, corpus)
|
41 |
print("Similarity Vector Shape: ", similarity_vectors.shape)
|
42 |
scores, indicies = torch.topk(similarity_vectors, k=len(corpus))
|
43 |
print("Scores Shape: ", scores.shape)
|
|
|
37 |
@app.post("/", response_model=list[Disease])
|
38 |
async def predict(query: str):
|
39 |
query_embedding = model.encode(query).astype('float')
|
40 |
+
similarity_vectors = model.similarity(query_embedding, corpus)[0]
|
41 |
print("Similarity Vector Shape: ", similarity_vectors.shape)
|
42 |
scores, indicies = torch.topk(similarity_vectors, k=len(corpus))
|
43 |
print("Scores Shape: ", scores.shape)
|