File size: 5,579 Bytes
abd20d0
d197e7f
abd20d0
7df75ff
 
d197e7f
 
4ea2b30
 
d197e7f
7df75ff
 
21537b7
d197e7f
21537b7
e0aff18
d197e7f
 
 
e0aff18
d197e7f
 
 
21537b7
d197e7f
 
21537b7
d197e7f
 
 
 
 
 
e0aff18
d197e7f
 
 
 
 
21537b7
e0aff18
d197e7f
21537b7
d197e7f
 
 
 
 
 
 
 
 
 
 
 
 
e0aff18
7df75ff
d197e7f
 
 
abd20d0
d197e7f
 
 
 
 
 
 
7df75ff
d197e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abd20d0
7df75ff
d197e7f
abd20d0
d197e7f
abd20d0
 
 
 
 
 
 
 
 
 
 
 
 
 
d197e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
from typing import Any, Optional

import weave
from byaldi import RAGMultiModalModel
from PIL import Image

import wandb

from ..utils import get_wandb_artifact


class CalPaliRetriever(weave.Model):
    """
    CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali.

    This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
    It can be initialized with a pre-trained model or from a specified W&B artifact. The class
    also provides methods to index new data and to predict/retrieve documents based on a query.

    Retrieved page images are read from disk as `<data_artifact_dir>/<doc_id>.png`, so the
    data artifact directory must be known before `predict` is called. It is set by
    `from_artifact`, by passing `data_artifact_dir` to the constructor, or by `index`.

    !!! example "Indexing Data"
        ```python
        import wandb
        from medrag_multi_modal.retrieval import CalPaliRetriever

        wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
        retriever = CalPaliRetriever()
        retriever.index(
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
            weave_dataset_name="grays-anatomy-images:v0",
            index_name="grays-anatomy",
        )
        ```

    !!! example "Retrieving Documents"
        ```python
        import weave

        import wandb
        from medrag_multi_modal.retrieval import CalPaliRetriever

        weave.init(project_name="ml-colabs/medrag-multi-modal")
        retriever = CalPaliRetriever.from_artifact(
            index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
            metadata_dataset_name="grays-anatomy-images:v0",
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
        )
        retriever.predict(
            query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
            top_k=3,
        )
        ```

    Attributes:
        model_name (str): The name of the model to be used for retrieval.
    """

    model_name: str
    # byaldi model used for both indexing and searching.
    _docs_retrieval_model: Optional[RAGMultiModalModel] = None
    # Rows of the Weave metadata dataset, each converted to a plain dict.
    _metadata: Optional[list[dict[str, Any]]] = None
    # Local directory containing the page images, named `<doc_id>.png`.
    _data_artifact_dir: Optional[str] = None

    def __init__(
        self,
        model_name: str = "vidore/colpali-v1.2",
        docs_retrieval_model: Optional[RAGMultiModalModel] = None,
        data_artifact_dir: Optional[str] = None,
        metadata_dataset_name: Optional[str] = None,
    ):
        """
        Create a retriever.

        Args:
            model_name (str): Name of the pre-trained ColPali model. It is used to load a
                model when `docs_retrieval_model` is not given.
            docs_retrieval_model (Optional[RAGMultiModalModel]): An already-constructed
                byaldi model, e.g. one loaded from an index. If omitted, one is created
                with `RAGMultiModalModel.from_pretrained(model_name)`.
            data_artifact_dir (Optional[str]): Local directory holding the page images
                that `predict` returns.
            metadata_dataset_name (Optional[str]): Weave reference of a dataset whose
                rows are loaded (as dicts) into the retriever's metadata.
        """
        super().__init__(model_name=model_name)
        # Explicit None check: only fall back to loading a model when none was supplied.
        if docs_retrieval_model is None:
            docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
        self._docs_retrieval_model = docs_retrieval_model
        self._data_artifact_dir = data_artifact_dir
        self._metadata = (
            [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
            if metadata_dataset_name
            else None
        )

    @classmethod
    def from_artifact(
        cls,
        index_artifact_name: str,
        metadata_dataset_name: str,
        data_artifact_name: str,
    ) -> "CalPaliRetriever":
        """
        Build a retriever from a previously logged ColPali index artifact.

        Args:
            index_artifact_name (str): Name of the W&B artifact of type `colpali-index`.
                The `index` subdirectory of that artifact is loaded with
                `RAGMultiModalModel.from_index`.
            metadata_dataset_name (str): Weave reference of the metadata dataset.
            data_artifact_name (str): Name of the W&B artifact of type `dataset` that
                holds the page images.

        Returns:
            CalPaliRetriever: A retriever backed by the loaded index and pointing at the
                local copy of the image directory.
        """
        index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        docs_retrieval_model = RAGMultiModalModel.from_index(
            index_path=os.path.join(index_artifact_dir, "index")
        )
        return cls(
            docs_retrieval_model=docs_retrieval_model,
            metadata_dataset_name=metadata_dataset_name,
            data_artifact_dir=data_artifact_dir,
        )

    def index(
        self, data_artifact_name: str, weave_dataset_name: str, index_name: str
    ) -> None:
        """
        Index the page images of a W&B dataset artifact.

        If a W&B run is active, the resulting index is also logged as a `colpali-index`
        artifact. That artifact is built from the local `.byaldi/<index_name>` directory.

        Args:
            data_artifact_name (str): Name of the W&B artifact of type `dataset` that
                contains the images to index.
            weave_dataset_name (str): Name of the Weave dataset. It is recorded in the
                logged artifact's metadata.
            index_name (str): Name of the index. It is also used as the artifact name.
        """
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        self._docs_retrieval_model.index(
            input_path=data_artifact_dir,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=True,
        )
        # The index does not store the images (store_collection_with_index=False), so
        # remember where they live; this lets `predict` work right after indexing.
        self._data_artifact_dir = data_artifact_dir
        if wandb.run:
            artifact = wandb.Artifact(
                name=index_name,
                type="colpali-index",
                metadata={"weave_dataset_name": weave_dataset_name},
            )
            artifact.add_dir(
                local_path=os.path.join(".byaldi", index_name), name="index"
            )
            artifact.save()

    @weave.op()
    def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
        """
        Predicts and retrieves the top-k most relevant documents/images for a given query
        using ColPali.

        This function uses the document retrieval model to search for the most relevant
        documents based on the provided query. It returns a list of dictionaries, each
        containing the document image, document ID, and the relevance score.

        Args:
            query (str): The search query string.
            top_k (int, optional): The number of top results to retrieve. Defaults to 3.

        Returns:
            list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
                - "doc_image" (PIL.Image.Image): The image of the document, fully loaded
                  into memory.
                - "doc_id": The document ID reported by the retrieval model.
                - "score" (float): The relevance score of the document.

        Raises:
            ValueError: If no data artifact directory is set, so images cannot be located.
        """
        if self._data_artifact_dir is None:
            raise ValueError(
                "No data artifact directory is set; create the retriever with "
                "`from_artifact`, pass `data_artifact_dir`, or call `index` first."
            )
        results = self._docs_retrieval_model.search(query=query, k=top_k)
        return [
            {
                "doc_image": self._load_page_image(result["doc_id"]),
                "doc_id": result["doc_id"],
                "score": result["score"],
            }
            for result in results
        ]

    def _load_page_image(self, doc_id: Any) -> Image.Image:
        """Load `<data_artifact_dir>/<doc_id>.png` fully into memory and return it."""
        image_path = os.path.join(self._data_artifact_dir, f"{doc_id}.png")
        # Read the pixel data eagerly so the file handle is closed here, instead of
        # staying open until the lazily-loaded image object is garbage collected.
        with Image.open(image_path) as image:
            image.load()
        return image