medrag / medrag_multi_modal /retrieval /multi_modal_retrieval.py
mratanusarkar's picture
chore: improve doc + code formatting
f37090a
raw
history blame
5.6 kB
import os
from typing import Any, Optional
import weave
from byaldi import RAGMultiModalModel
from PIL import Image
import wandb
from ..utils import get_wandb_artifact
class MultiModalRetriever(weave.Model):
    """
    MultiModalRetriever is a class that facilitates the retrieval of page images using ColPali.

    This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
    It can be initialized with a pre-trained model or from a specified W&B artifact. The class
    also provides methods to index new data and to predict/retrieve documents based on a query.

    !!! example "Indexing Data"
        ```python
        import wandb
        from medrag_multi_modal.retrieval import MultiModalRetriever

        wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
        retriever = MultiModalRetriever()
        retriever.index(
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
            weave_dataset_name="grays-anatomy-images:v0",
            index_name="grays-anatomy",
        )
        ```

    !!! example "Retrieving Documents"
        ```python
        import weave

        import wandb
        from medrag_multi_modal.retrieval import MultiModalRetriever

        weave.init(project_name="ml-colabs/medrag-multi-modal")
        retriever = MultiModalRetriever.from_artifact(
            index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
            metadata_dataset_name="grays-anatomy-images:v0",
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
        )
        retriever.predict(
            query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
            top_k=3,
        )
        ```

    Attributes:
        model_name (str): The name of the model to be used for retrieval.
    """

    model_name: str
    # byaldi model used for both indexing and searching.
    _docs_retrieval_model: Optional[RAGMultiModalModel] = None
    # Rows of the Weave metadata dataset, one dict per row; None when no
    # dataset name was supplied.
    _metadata: Optional[list[dict[str, Any]]] = None
    # Local directory from which `predict` opens "<doc_id>.png" page images.
    _data_artifact_dir: Optional[str] = None

    def __init__(
        self,
        model_name: str = "vidore/colpali-v1.2",
        docs_retrieval_model: Optional[RAGMultiModalModel] = None,
        data_artifact_dir: Optional[str] = None,
        metadata_dataset_name: Optional[str] = None,
    ):
        """
        Initializes the retriever.

        Args:
            model_name (str): Name of the pre-trained model to load when no
                `docs_retrieval_model` is supplied.
            docs_retrieval_model (Optional[RAGMultiModalModel]): An already
                constructed retrieval model. When omitted, one is created with
                `RAGMultiModalModel.from_pretrained(model_name)`.
            data_artifact_dir (Optional[str]): Local directory containing the
                page images that `predict` opens.
            metadata_dataset_name (Optional[str]): Name of a Weave dataset whose
                rows are loaded as metadata. Skipped when empty or None.
        """
        super().__init__(model_name=model_name)
        # Explicit None check: do not depend on the truthiness of a model object.
        if docs_retrieval_model is None:
            docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
        self._docs_retrieval_model = docs_retrieval_model
        self._data_artifact_dir = data_artifact_dir
        self._metadata = (
            [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
            if metadata_dataset_name
            else None
        )

    @classmethod
    def from_artifact(
        cls,
        index_artifact_name: str,
        metadata_dataset_name: str,
        data_artifact_name: str,
    ):
        """
        Builds a retriever from a previously created index.

        Both artifacts are resolved through `get_wandb_artifact`, whose return
        value is used here as a local directory path (presumably the download
        location of the artifact -- confirm against `..utils`).

        Args:
            index_artifact_name (str): Name of the artifact of type
                "colpali-index"; the index is read from its "index" sub-directory.
            metadata_dataset_name (str): Name of the Weave dataset holding the
                metadata rows.
            data_artifact_name (str): Name of the artifact of type "dataset"
                that holds the page images.

        Returns:
            MultiModalRetriever: A retriever backed by the loaded index.
        """
        index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        docs_retrieval_model = RAGMultiModalModel.from_index(
            index_path=os.path.join(index_artifact_dir, "index")
        )
        return cls(
            docs_retrieval_model=docs_retrieval_model,
            metadata_dataset_name=metadata_dataset_name,
            data_artifact_dir=data_artifact_dir,
        )

    def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
        """
        Indexes the contents of a data artifact and, if a W&B run is active,
        saves the resulting index as a "colpali-index" artifact.

        Any existing index with the same name is overwritten
        (`overwrite=True`).

        Args:
            data_artifact_name (str): Name of the artifact of type "dataset"
                whose contents are indexed.
            weave_dataset_name (str): Name of the associated Weave dataset. It
                is only recorded in the metadata of the saved index artifact.
            index_name (str): Name of the index; also used as the name of the
                saved artifact.
        """
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        self._docs_retrieval_model.index(
            input_path=data_artifact_dir,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=True,
        )
        # Remember where the freshly indexed images live so that `predict`
        # resolves "<doc_id>.png" against the data this index was built from,
        # rather than a missing or stale directory.
        self._data_artifact_dir = data_artifact_dir
        if wandb.run:
            artifact = wandb.Artifact(
                name=index_name,
                type="colpali-index",
                metadata={"weave_dataset_name": weave_dataset_name},
            )
            # The index is expected on disk under ".byaldi/<index_name>".
            artifact.add_dir(
                local_path=os.path.join(".byaldi", index_name), name="index"
            )
            artifact.save()

    @weave.op()
    def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
        """
        Predicts and retrieves the top-k most relevant documents/images for a given query
        using ColPali.

        This function uses the document retrieval model to search for the most relevant
        documents based on the provided query. It returns a list of dictionaries, each
        containing the document image, document ID, and the relevance score.

        Args:
            query (str): The search query string.
            top_k (int, optional): The number of top results to retrieve. Defaults to 3.

        Returns:
            list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
                - "doc_image" (PIL.Image.Image): The image of the document, opened from
                    "<doc_id>.png" inside the data artifact directory.
                - "doc_id": The ID of the document as reported by the retrieval model.
                - "score": The relevance score of the document.

        Raises:
            ValueError: If no data artifact directory is known, i.e. the retriever was
                neither given `data_artifact_dir`, nor built with `from_artifact`, nor
                used to `index` data.
        """
        # Fail fast with a clear message instead of running the search and then
        # hitting an opaque TypeError inside os.path.join.
        if self._data_artifact_dir is None:
            raise ValueError(
                "No data artifact directory is set; create the retriever with "
                "`from_artifact`, pass `data_artifact_dir`, or call `index` first."
            )
        results = self._docs_retrieval_model.search(query=query, k=top_k)
        return [
            {
                "doc_image": Image.open(
                    os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
                ),
                "doc_id": result["doc_id"],
                "score": result["score"],
            }
            for result in results
        ]