ccm commited on
Commit
1f79ef9
·
verified ·
1 Parent(s): e550b0a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +16 -11
main.py CHANGED
@@ -25,24 +25,27 @@ data.reset_index(inplace=True)
25
 
26
  # Create a FAISS index for fast similarity search
27
  indices = []
28
- metrics = [faiss.METRIC_INNER_PRODUCT]
 
29
  vectors = numpy.stack(data["embedding"].tolist(), axis=0)
30
  for metric in metrics:
31
- index = faiss.IndexFlatL2(len(data["embedding"][0]))
32
- index.metric_type = metric
33
- faiss.normalize_L2(vectors)
34
- index.train(vectors)
35
- index.add(vectors)
36
- indices.append(index)
 
 
37
 
38
  # Load the model for later use in embeddings
39
  model = sentence_transformers.SentenceTransformer("allenai-specter")
40
 
41
  # Define the search function
42
- def search(query: str, k: int):
43
  query = numpy.expand_dims(model.encode(query), axis=0)
44
  faiss.normalize_L2(query)
45
- D, I = indices[0].search(query, k)
46
  top_five = data.loc[I[0]]
47
  search_results = ""
48
 
@@ -88,8 +91,10 @@ with gradio.Blocks() as demo:
88
  )
89
  with gradio.Accordion("Settings", open=False):
90
  k = gradio.Number(10.0, label="Number of results", precision=0)
 
91
  results = gradio.Markdown()
92
- query.change(fn=search, inputs=[query, k], outputs=results)
93
- k.change(fn=search, inputs=[query, k], outputs=results)
 
94
 
95
  demo.launch(debug=True)
 
25
 
26
  # Create a FAISS index for fast similarity search
27
  indices = []
28
+ metrics = [faiss.METRIC_INNER_PRODUCT ,faiss.METRIC_L2]
29
+ normalization = [True, False]
30
  vectors = numpy.stack(data["embedding"].tolist(), axis=0)
31
  for metric in metrics:
32
+ for normal in normalization
33
+ index = faiss.IndexFlatL2(len(data["embedding"][0]))
34
+ index.metric_type = metric
35
+ if normal:
36
+ faiss.normalize_L2(vectors)
37
+ index.train(vectors)
38
+ index.add(vectors)
39
+ indices.append(index)
40
 
41
  # Load the model for later use in embeddings
42
  model = sentence_transformers.SentenceTransformer("allenai-specter")
43
 
44
  # Define the search function
45
+ def search(query: str, k: int, n: int):
46
  query = numpy.expand_dims(model.encode(query), axis=0)
47
  faiss.normalize_L2(query)
48
+ D, I = indices[n].search(query, k)
49
  top_five = data.loc[I[0]]
50
  search_results = ""
51
 
 
91
  )
92
  with gradio.Accordion("Settings", open=False):
93
  k = gradio.Number(10.0, label="Number of results", precision=0)
94
+ k = gradio.Radio([True, False], label="Number of results", precision=0)
95
  results = gradio.Markdown()
96
+ query.change(fn=search, inputs=[query, k, n], outputs=results)
97
+ k.change(fn=search, inputs=[query, k, n], outputs=results)
98
+ n.change(fn=search, inputs=[query, k, n], outputs=results)
99
 
100
  demo.launch(debug=True)