mudaza committed on
Commit
70eb6a4
·
1 Parent(s): 0f19f2b

modified code and added files

Browse files
app.py CHANGED
@@ -6,10 +6,11 @@ from pydantic import BaseModel
6
  from fastapi.middleware.cors import CORSMiddleware
7
  import torch
8
 
9
- corpus = pickle.load(open("./corpus/all_embeddings.pickle", "rb"))
10
- label_encoder = pickle.load(open("./corpus/label_encoder.pickle", "rb"))
11
- model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
12
- df = pd.DataFrame(data={"label": pickle.load(open("./corpus/y_all.pickle", "rb"))})
 
13
 
14
  app = FastAPI()
15
 
@@ -24,27 +25,32 @@ app.add_middleware(
24
  class Disease(BaseModel):
25
  id: int
26
  name: str
 
27
  score: float
28
 
29
  class Symptoms(BaseModel):
30
  query: str
31
 
32
- @app.get("/")
33
- def greet_json():
34
- return {"Hello": "World!"}
35
 
36
  @app.post("/", response_model=list[Disease])
37
  async def predict(symptoms: Symptoms):
38
  query_embedding = model.encode(symptoms.query).astype('float')
39
  similarity_vectors = model.similarity(query_embedding, corpus)[0]
40
  scores, indicies = torch.topk(similarity_vectors, k=len(corpus))
41
- # print("Similarity Vector Shape: ", similarity_vectors.shape)
42
- # print("Scores Shape: ", scores.shape)
43
- # print("Indicies Shape: ", indicies.shape)
44
- id_ = df.iloc[indicies].reset_index(drop=True)
45
- id_ = id_.drop_duplicates("label")
46
- scores = scores[id_.index]
47
- diseases = label_encoder.inverse_transform(id_.label.values)
48
- id_ = id_.label.values
49
- diseases = [dict({"id": value[0], "name": value[1], "score" : value[2]}) for value in zip(id_, diseases, scores)]
 
 
 
 
50
  return diseases
 
6
  from fastapi.middleware.cors import CORSMiddleware
7
  import torch
8
 
9
+ corpus = pickle.load(open("./corpus/all_embeddings_disease.pickle", "rb"))
10
+ # label_encoder = pickle.load(open("./corpus/label_encoder.pickle", "rb"))
11
+ # model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
12
+ model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
13
+ df = pd.DataFrame(pickle.load(open("./corpus/y_all_disease.pickle", "rb")))
14
 
15
  app = FastAPI()
16
 
 
25
  class Disease(BaseModel):
26
  id: int
27
  name: str
28
+ url: str
29
  score: float
30
 
31
  class Symptoms(BaseModel):
32
  query: str
33
 
34
+ # @app.get("/")
35
+ # def greet_json():
36
+ # return {"Hello": "World!"}
37
 
38
  @app.post("/", response_model=list[Disease])
39
  async def predict(symptoms: Symptoms):
40
  query_embedding = model.encode(symptoms.query).astype('float')
41
  similarity_vectors = model.similarity(query_embedding, corpus)[0]
42
  scores, indicies = torch.topk(similarity_vectors, k=len(corpus))
43
+ # id_ = df.iloc[indicies].reset_index(drop=True)
44
+ df = df.iloc[indicies]
45
+ # id_ = id_.drop_duplicates("label")
46
+ df["scores"] = scores
47
+ # scores = scores[id_.index]
48
+ # diseases = label_encoder.inverse_transform(id_.label.values)
49
+ # id_ = id_.label.values
50
+ diseases = [dict({"id": value[0],
51
+ "name": value[1],
52
+ "score" : value[2],
53
+ "url" : value[3],
54
+ })
55
+ for value in zip(df.index, df["name"], df["scores"], df["url"])]
56
  return diseases
corpus/all_embeddings_disease.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11a003c5bc180aaff3d06b5f64ee28034512937629b635605a2bb56edd267ff9
3
+ size 4045987
corpus/y_all_disease.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11a003c5bc180aaff3d06b5f64ee28034512937629b635605a2bb56edd267ff9
3
+ size 4045987