Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,7 @@ from colpali_engine.models.paligemma_colbert_architecture import ColPali
|
|
15 |
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
|
16 |
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
|
17 |
|
18 |
-
def search(query: str, ds, images):
|
19 |
qs = []
|
20 |
with torch.no_grad():
|
21 |
batch_query = process_queries(processor, [query], mock_image)
|
@@ -26,8 +26,14 @@ def search(query: str, ds, images):
|
|
26 |
# run evaluation
|
27 |
retriever_evaluator = CustomEvaluator(is_multi_vector=True)
|
28 |
scores = retriever_evaluator.evaluate(qs, ds)
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
|
33 |
def index(file, ds):
|
@@ -83,9 +89,9 @@ with gr.Blocks() as demo:
|
|
83 |
search_button = gr.Button("🔍 Search")
|
84 |
message2 = gr.Textbox("Query not yet set")
|
85 |
output_img = gr.Image()
|
|
|
86 |
|
87 |
-
search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
|
88 |
-
|
89 |
|
90 |
if __name__ == "__main__":
|
91 |
demo.queue(max_size=10).launch(debug=True)
|
|
|
15 |
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
|
16 |
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
|
17 |
|
18 |
+
def search(query: str, ds, images, k):
|
19 |
qs = []
|
20 |
with torch.no_grad():
|
21 |
batch_query = process_queries(processor, [query], mock_image)
|
|
|
26 |
# run evaluation
|
27 |
retriever_evaluator = CustomEvaluator(is_multi_vector=True)
|
28 |
scores = retriever_evaluator.evaluate(qs, ds)
|
29 |
+
|
30 |
+
top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
|
31 |
+
|
32 |
+
results = []
|
33 |
+
for idx in top_k_indices:
|
34 |
+
results.append((images[idx], f"Page {idx}"))
|
35 |
+
|
36 |
+
return results
|
37 |
|
38 |
|
39 |
def index(file, ds):
|
|
|
89 |
search_button = gr.Button("🔍 Search")
|
90 |
message2 = gr.Textbox("Query not yet set")
|
91 |
output_img = gr.Image()
|
92 |
+
k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
|
93 |
|
94 |
+
search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[message2, output_img])
|
|
|
95 |
|
96 |
if __name__ == "__main__":
|
97 |
demo.queue(max_size=10).launch(debug=True)
|