Spaces:
Sleeping
Sleeping
File size: 1,508 Bytes
a2d4cca a3b53c5 a2d4cca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import os
from haystack import Document
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever
from config import MODEL_DIM, MODEL_NAME
class MultiModelSearch:
    """Text-to-image search over a directory of image files.

    On construction, every image file found in the document directory is
    written to an in-memory Haystack document store as an image ``Document``
    (its content is the file path), the store's embeddings are updated with a
    ``MultiModalRetriever`` built from ``MODEL_NAME``, and that retriever is
    wired into a single-node ``Pipeline`` used by :meth:`search`.
    """

    # Filename suffixes treated as images (matched case-insensitively).
    # The leading dot is required: without it, names that merely end in the
    # letters "jpg"/"png" (e.g. "archive_png") would be indexed as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, document_directory=None):
        """Index the images and build the retrieval pipeline.

        Args:
            document_directory: Directory containing the images to index.
                When omitted, ``<current working directory>/data`` is used.

        Raises:
            OSError: If the directory cannot be listed (for example
                ``FileNotFoundError`` when it does not exist).
        """
        if document_directory is None:
            document_directory = os.path.join(os.getcwd(), "data")

        self.document_stores = InMemoryDocumentStore(embedding_dim=MODEL_DIM)

        # Register every image as a document whose content is the file path.
        images = [
            Document(content=image_path, content_type="image")
            for image_path in self._list_image_paths(document_directory)
        ]
        self.document_stores.write_documents(images)

        # The same model embeds both the text query and the image documents.
        self.retriever_text_to_image = MultiModalRetriever(
            document_store=self.document_stores,
            query_embedding_model=MODEL_NAME,
            query_type="text",
            document_embedding_models={"image": MODEL_NAME},
        )
        self.document_stores.update_embeddings(
            retriever=self.retriever_text_to_image
        )

        self.pipeline = Pipeline()
        self.pipeline.add_node(
            component=self.retriever_text_to_image,
            name="retriever_text_to_image",
            inputs=["Query"],
        )

    @staticmethod
    def _list_image_paths(directory):
        """Return the paths of the image files in *directory*, sorted by name.

        Only direct entries of *directory* are considered; an entry counts as
        an image when its lower-cased name ends with one of
        ``IMAGE_EXTENSIONS``.
        """
        # os.listdir() order is arbitrary, so sort for reproducible indexing.
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if filename.lower().endswith(MultiModelSearch.IMAGE_EXTENSIONS)
        ]

    def search(self, query, top_k=3):
        """Run *query* through the pipeline and return the ranked documents.

        Args:
            query: Query passed to the pipeline (the retriever is configured
                with ``query_type="text"``).
            top_k: Value forwarded to the retriever node as its ``top_k``
                parameter.

        Returns:
            The pipeline's ``"documents"`` sorted by ``score``, highest first.
        """
        results = self.pipeline.run(
            query=query,
            params={"retriever_text_to_image": {"top_k": top_k}},
        )
        return sorted(results["documents"], key=lambda d: d.score, reverse=True)
|