File size: 1,508 Bytes
a2d4cca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3b53c5
a2d4cca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
from haystack import Document
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever
from config import MODEL_DIM, MODEL_NAME

class MultiModelSearch:
    """Text-to-image search over a directory of image files.

    On construction, every image file found in the data directory is wrapped
    in a Haystack ``Document`` (``content`` is the file path, ``content_type``
    is ``"image"``), written to an ``InMemoryDocumentStore`` and embedded via
    a ``MultiModalRetriever``. The retriever is then exposed through a
    single-node ``Pipeline`` that :meth:`search` queries.
    """

    # Lower-case file extensions (dot included) that are indexed as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, document_directory=None):
        """Index the images in *document_directory* and build the pipeline.

        Args:
            document_directory: Directory to scan for images. Defaults to the
                ``data`` folder inside the current working directory, which
                is resolved at call time.

        Raises:
            OSError: Propagated from ``os.listdir`` when the directory cannot
                be listed (e.g. it does not exist).
        """
        if document_directory is None:
            document_directory = os.path.join(os.getcwd(), "data")

        self.document_stores = InMemoryDocumentStore(embedding_dim=MODEL_DIM)

        # Sorted so that the indexing order does not depend on the
        # arbitrary ordering returned by os.listdir.
        image_paths = [
            os.path.join(document_directory, filename)
            for filename in sorted(os.listdir(document_directory))
            if self._is_image(filename)
        ]
        # Skip directories whose name merely looks like an image file.
        images = [
            Document(content=path, content_type="image")
            for path in image_paths
            if os.path.isfile(path)
        ]
        self.document_stores.write_documents(images)

        # The same model (MODEL_NAME) embeds both the text queries and the
        # image documents.
        self.retriever_text_to_image = MultiModalRetriever(
            document_store=self.document_stores,
            query_embedding_model=MODEL_NAME,
            query_type="text",
            document_embedding_models={"image": MODEL_NAME},
        )
        self.document_stores.update_embeddings(
            retriever=self.retriever_text_to_image
        )

        self.pipeline = Pipeline()
        self.pipeline.add_node(
            component=self.retriever_text_to_image,
            name="retriever_text_to_image",
            inputs=["Query"],
        )

    @classmethod
    def _is_image(cls, filename):
        """Return True if *filename* ends with a known image extension.

        The comparison is case-insensitive and requires the leading dot, so
        ``"photo.JPG"`` matches while ``"notajpg"`` does not.
        """
        return filename.lower().endswith(cls.IMAGE_EXTENSIONS)

    def search(self, query, top_k=3):
        """Run *query* through the pipeline and return the matching documents.

        Args:
            query: Query passed to the pipeline as ``query``.
            top_k: Value forwarded to the retriever node as its ``top_k``
                parameter.

        Returns:
            The pipeline's ``"documents"`` as a new list, ordered by
            descending ``score``.
        """
        results = self.pipeline.run(
            query=query,
            params={"retriever_text_to_image": {"top_k": top_k}},
        )
        return sorted(results["documents"], key=lambda d: d.score, reverse=True)