Spaces:
Running
Running
File size: 6,138 Bytes
abd20d0 3bd446c abd20d0 7df75ff 353a440 3bd446c 353a440 d197e7f 4ea2b30 d197e7f 7df75ff 21537b7 d197e7f 21537b7 e0aff18 d197e7f e0aff18 d197e7f 8f9e28d f673b7f 8f9e28d f673b7f 8f9e28d f673b7f d197e7f 21537b7 d197e7f 21537b7 d197e7f e0aff18 d197e7f 8f9e28d f673b7f 8f9e28d f673b7f 8f9e28d f673b7f d197e7f 21537b7 e0aff18 d197e7f c063304 d197e7f e0aff18 7df75ff 3bd446c d197e7f abd20d0 d197e7f 3bd446c d197e7f 7df75ff 3bd446c d197e7f c063304 d197e7f 3bd446c d197e7f abd20d0 7df75ff d197e7f abd20d0 d197e7f abd20d0 d197e7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import os
from typing import TYPE_CHECKING, Any, Optional
import weave
if TYPE_CHECKING:
from byaldi import RAGMultiModalModel
from PIL import Image
import wandb
from ..utils import get_wandb_artifact
class CalPaliRetriever(weave.Model):
    """
    CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali.

    This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
    It can be initialized with a pre-trained model or from a specified W&B artifact. The class
    also provides methods to index new data and to predict/retrieve documents based on a query.

    !!! example "Indexing Data"
        First you need to install `Byaldi` library by Answer.ai.

        ```bash
        uv pip install "Byaldi>=0.0.5"
        ```

        Next, you can index the data by running the following code:

        ```python
        import wandb
        from medrag_multi_modal.retrieval import CalPaliRetriever

        wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
        retriever = CalPaliRetriever()
        retriever.index(
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
            weave_dataset_name="grays-anatomy-images:v0",
            index_name="grays-anatomy",
        )
        ```

    !!! example "Retrieving Documents"
        First you need to install `Byaldi` library by Answer.ai.

        ```bash
        uv pip install "Byaldi>=0.0.5"
        ```

        Next, you can retrieve the documents by running the following code:

        ```python
        import weave

        import wandb
        from medrag_multi_modal.retrieval import CalPaliRetriever

        weave.init(project_name="ml-colabs/medrag-multi-modal")
        retriever = CalPaliRetriever.from_wandb_artifact(
            index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
            metadata_dataset_name="grays-anatomy-images:v0",
            data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
        )
        retriever.predict(
            query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
            top_k=3,
        )
        ```

    Attributes:
        model_name (str): The name of the model to be used for retrieval.
    """

    model_name: str
    # Retrieval backend; set in `__init__` to either the model passed in or
    # one loaded with `RAGMultiModalModel.from_pretrained(model_name)`.
    _docs_retrieval_model: Optional["RAGMultiModalModel"] = None
    # Rows of the Weave metadata dataset, one dict per row; None when no
    # `metadata_dataset_name` was given.
    _metadata: Optional[list[dict[str, Any]]] = None
    # Local directory that `predict` reads `<doc_id>.png` page images from.
    _data_artifact_dir: Optional[str] = None

    def __init__(
        self,
        model_name: str = "vidore/colpali-v1.2",
        docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
        data_artifact_dir: Optional[str] = None,
        metadata_dataset_name: Optional[str] = None,
    ):
        """
        Initialize the retriever.

        Args:
            model_name (str): Name of the model passed to
                `RAGMultiModalModel.from_pretrained` when no
                `docs_retrieval_model` is supplied.
            docs_retrieval_model (Optional[RAGMultiModalModel]): An already
                constructed retrieval model to use instead of loading one.
            data_artifact_dir (Optional[str]): Local directory containing the
                page images that `predict` opens.
            metadata_dataset_name (Optional[str]): Weave reference of the
                metadata dataset; when given, its rows are loaded as dicts.
        """
        super().__init__(model_name=model_name)
        # byaldi is imported here rather than at module level, where it is
        # only imported under TYPE_CHECKING.
        from byaldi import RAGMultiModalModel

        # Explicit None check so a supplied model is never discarded because
        # of how its truthiness happens to evaluate.
        if docs_retrieval_model is None:
            docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
        self._docs_retrieval_model = docs_retrieval_model
        self._data_artifact_dir = data_artifact_dir
        self._metadata = (
            [dict(row) for row in weave.ref(metadata_dataset_name).get().rows]
            if metadata_dataset_name
            else None
        )

    @classmethod
    def from_wandb_artifact(
        cls,
        index_artifact_name: str,
        metadata_dataset_name: str,
        data_artifact_name: str,
    ) -> "CalPaliRetriever":
        """
        Build a retriever from a previously saved index artifact.

        Both artifacts are resolved to local directories with
        `get_wandb_artifact`, and the retrieval model is loaded from the
        `index` sub-directory of the index artifact.

        Args:
            index_artifact_name (str): W&B artifact of type `colpali-index`
                holding the index.
            metadata_dataset_name (str): Weave reference of the metadata
                dataset whose rows are loaded into the retriever.
            data_artifact_name (str): W&B artifact of type `dataset` holding
                the page images.

        Returns:
            CalPaliRetriever: A retriever backed by the loaded index.
        """
        from byaldi import RAGMultiModalModel

        index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        docs_retrieval_model = RAGMultiModalModel.from_index(
            index_path=os.path.join(index_artifact_dir, "index")
        )
        return cls(
            docs_retrieval_model=docs_retrieval_model,
            metadata_dataset_name=metadata_dataset_name,
            data_artifact_dir=data_artifact_dir,
        )

    def index(
        self, data_artifact_name: str, weave_dataset_name: str, index_name: str
    ) -> None:
        """
        Index the images of a dataset artifact and, if a W&B run is active,
        save the resulting index as a `colpali-index` artifact.

        Args:
            data_artifact_name (str): W&B artifact of type `dataset` whose
                local directory is indexed.
            weave_dataset_name (str): Name recorded in the index artifact's
                metadata under `weave_dataset_name`.
            index_name (str): Name of the index; also used as the name of the
                saved W&B artifact.
        """
        data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
        self._docs_retrieval_model.index(
            input_path=data_artifact_dir,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=True,
        )
        if wandb.run:
            artifact = wandb.Artifact(
                name=index_name,
                type="colpali-index",
                metadata={"weave_dataset_name": weave_dataset_name},
            )
            # The index is picked up from `.byaldi/<index_name>` relative to
            # the working directory -- presumably Byaldi's default index
            # root; confirm if a custom root is ever configured.
            artifact.add_dir(
                local_path=os.path.join(".byaldi", index_name), name="index"
            )
            artifact.save()

    @weave.op()
    def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
        """
        Predicts and retrieves the top-k most relevant documents/images for a given query
        using ColPali.

        This function uses the document retrieval model to search for the most relevant
        documents based on the provided query. It returns a list of dictionaries, each
        containing the document image, document ID, and the relevance score.

        Args:
            query (str): The search query string.
            top_k (int, optional): The number of top results to retrieve. Defaults to 3.

        Returns:
            list[dict[str, Any]]: A list of dictionaries where each dictionary contains:
                - "doc_image" (PIL.Image.Image): The image of the document, opened
                  from `<data_artifact_dir>/<doc_id>.png`.
                - "doc_id": The ID of the document as reported by the retrieval model.
                - "score": The relevance score as reported by the retrieval model.

        Raises:
            ValueError: If the retriever has no data artifact directory to load
                page images from.
        """
        # Fail before running the search: without a directory the image path
        # cannot be built and os.path.join would raise an opaque TypeError.
        if self._data_artifact_dir is None:
            raise ValueError(
                "No data artifact directory is set; pass `data_artifact_dir` to "
                "the constructor or create the retriever with "
                "`CalPaliRetriever.from_wandb_artifact`."
            )
        results = self._docs_retrieval_model.search(query=query, k=top_k)
        return [
            {
                "doc_image": Image.open(
                    os.path.join(self._data_artifact_dir, f"{result['doc_id']}.png")
                ),
                "doc_id": result["doc_id"],
                "score": result["score"],
            }
            for result in results
        ]
|