HUANG-Stephanie commited on
Commit
2b7ded4
·
verified ·
1 Parent(s): a59e0f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -15,7 +15,7 @@ from colpali_engine.models.paligemma_colbert_architecture import ColPali
15
  from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
16
  from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
17
 
18
- def search(query: str, ds, images):
19
  qs = []
20
  with torch.no_grad():
21
  batch_query = process_queries(processor, [query], mock_image)
@@ -26,8 +26,14 @@ def search(query: str, ds, images):
26
  # run evaluation
27
  retriever_evaluator = CustomEvaluator(is_multi_vector=True)
28
  scores = retriever_evaluator.evaluate(qs, ds)
29
- best_page = int(scores.argmax(axis=1).item())
30
- return f"The most relevant page is {best_page}", images[best_page]
 
 
 
 
 
 
31
 
32
 
33
  def index(file, ds):
@@ -83,9 +89,9 @@ with gr.Blocks() as demo:
83
  search_button = gr.Button("🔍 Search")
84
  message2 = gr.Textbox("Query not yet set")
85
  output_img = gr.Image()
 
86
 
87
- search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
88
-
89
 
90
  if __name__ == "__main__":
91
  demo.queue(max_size=10).launch(debug=True)
 
15
  from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
16
  from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
17
 
18
+ def search(query: str, ds, images, k):
19
  qs = []
20
  with torch.no_grad():
21
  batch_query = process_queries(processor, [query], mock_image)
 
26
  # run evaluation
27
  retriever_evaluator = CustomEvaluator(is_multi_vector=True)
28
  scores = retriever_evaluator.evaluate(qs, ds)
29
+
30
+ top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
31
+
32
+ results = []
33
+ for idx in top_k_indices:
34
+ results.append((images[idx], f"Page {idx}"))
35
+
36
+ return results
37
 
38
 
39
  def index(file, ds):
 
89
  search_button = gr.Button("🔍 Search")
90
  message2 = gr.Textbox("Query not yet set")
91
  output_img = gr.Image()
92
+ k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
93
 
94
+ search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[message2, output_img])
 
95
 
96
  if __name__ == "__main__":
97
  demo.queue(max_size=10).launch(debug=True)