mudaza commited on
Commit
94ba614
·
1 Parent(s): 90f2a6c

update code

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -37,7 +37,7 @@ def greet_json():
37
  @app.post("/", response_model=list[Disease])
38
  async def predict(query: str):
39
  query_embedding = model.encode(query).astype('float')
40
- similarity_vectors = model.similarity(query_embedding, corpus)
41
  print("Similarity Vector Shape: ", similarity_vectors.shape)
42
  scores, indicies = torch.topk(similarity_vectors, k=len(corpus))
43
  print("Scores Shape: ", scores.shape)
 
37
  @app.post("/", response_model=list[Disease])
38
  async def predict(query: str):
39
  query_embedding = model.encode(query).astype('float')
40
+ similarity_vectors = model.similarity(query_embedding, corpus)[0]
41
  print("Similarity Vector Shape: ", similarity_vectors.shape)
42
  scores, indicies = torch.topk(similarity_vectors, k=len(corpus))
43
  print("Scores Shape: ", scores.shape)