Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						abd20d0
	
1
								Parent(s):
							
							9a6c015
								
update: colpali index syncs with wandb artifact
Browse files
    	
        .gitignore
    CHANGED
    
    | @@ -7,4 +7,6 @@ cursor_prompt.txt | |
| 7 | 
             
            test.py
         | 
| 8 | 
             
            **.pdf
         | 
| 9 | 
             
            images/
         | 
| 10 | 
            -
            wandb/
         | 
|  | |
|  | 
|  | |
| 7 | 
             
            test.py
         | 
| 8 | 
             
            **.pdf
         | 
| 9 | 
             
            images/
         | 
| 10 | 
            +
            wandb/
         | 
| 11 | 
            +
            .byaldi/
         | 
| 12 | 
            +
            artifacts/
         | 
    	
        medrag_multi_modal/document_loader/load_image.py
    CHANGED
    
    | @@ -3,11 +3,11 @@ import os | |
| 3 | 
             
            from typing import Optional
         | 
| 4 |  | 
| 5 | 
             
            import rich
         | 
|  | |
| 6 | 
             
            import weave
         | 
| 7 | 
             
            from pdf2image.pdf2image import convert_from_path
         | 
| 8 | 
             
            from PIL import Image
         | 
| 9 |  | 
| 10 | 
            -
            import wandb
         | 
| 11 | 
             
            from medrag_multi_modal.document_loader.load_text import TextLoader
         | 
| 12 |  | 
| 13 |  | 
|  | |
| 3 | 
             
            from typing import Optional
         | 
| 4 |  | 
| 5 | 
             
            import rich
         | 
| 6 | 
            +
            import wandb
         | 
| 7 | 
             
            import weave
         | 
| 8 | 
             
            from pdf2image.pdf2image import convert_from_path
         | 
| 9 | 
             
            from PIL import Image
         | 
| 10 |  | 
|  | |
| 11 | 
             
            from medrag_multi_modal.document_loader.load_text import TextLoader
         | 
| 12 |  | 
| 13 |  | 
    	
        medrag_multi_modal/retrieval/__init__.py
    CHANGED
    
    | @@ -1,3 +1,3 @@ | |
| 1 | 
             
            from .multi_modal_retrieval import MultiModalRetriever
         | 
| 2 |  | 
| 3 | 
            -
            __all__ = ["MultiModalRetriever"]
         | 
|  | |
| 1 | 
             
            from .multi_modal_retrieval import MultiModalRetriever
         | 
| 2 |  | 
| 3 | 
            +
            __all__ = ["MultiModalRetriever"]
         | 
    	
        medrag_multi_modal/retrieval/multi_modal_retrieval.py
    CHANGED
    
    | @@ -1,22 +1,39 @@ | |
|  | |
|  | |
|  | |
| 1 | 
             
            import weave
         | 
| 2 | 
             
            from byaldi import RAGMultiModalModel
         | 
| 3 | 
            -
            import wandb
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
             
            class MultiModalRetriever(weave.Model):
         | 
| 7 | 
             
                model_name: str
         | 
| 8 | 
             
                _docs_retrieval_model: RAGMultiModalModel
         | 
| 9 | 
            -
             | 
| 10 | 
             
                def __init__(self, model_name: str = "vidore/colpali-v1.2"):
         | 
| 11 | 
             
                    super().__init__(model_name=model_name)
         | 
| 12 | 
             
                    self._docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
         | 
| 13 | 
            -
             | 
| 14 | 
             
                def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
         | 
| 15 | 
             
                    if wandb.run:
         | 
| 16 | 
            -
                        artifact = wandb.use_artifact(data_artifact_name, type= | 
| 17 | 
             
                        artifact_dir = artifact.download()
         | 
| 18 | 
             
                    else:
         | 
| 19 | 
             
                        api = wandb.Api()
         | 
| 20 | 
             
                        artifact = api.artifact(data_artifact_name)
         | 
| 21 | 
             
                        artifact_dir = artifact.download()
         | 
| 22 | 
            -
                    self._docs_retrieval_model.index( | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import wandb
         | 
| 4 | 
             
            import weave
         | 
| 5 | 
             
            from byaldi import RAGMultiModalModel
         | 
|  | |
| 6 |  | 
| 7 |  | 
| 8 | 
             
            class MultiModalRetriever(weave.Model):
         | 
| 9 | 
             
                model_name: str
         | 
| 10 | 
             
                _docs_retrieval_model: RAGMultiModalModel
         | 
| 11 | 
            +
             | 
| 12 | 
             
                def __init__(self, model_name: str = "vidore/colpali-v1.2"):
         | 
| 13 | 
             
                    super().__init__(model_name=model_name)
         | 
| 14 | 
             
                    self._docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
         | 
| 15 | 
            +
             | 
| 16 | 
             
                def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
         | 
| 17 | 
             
                    if wandb.run:
         | 
| 18 | 
            +
                        artifact = wandb.use_artifact(data_artifact_name, type="dataset")
         | 
| 19 | 
             
                        artifact_dir = artifact.download()
         | 
| 20 | 
             
                    else:
         | 
| 21 | 
             
                        api = wandb.Api()
         | 
| 22 | 
             
                        artifact = api.artifact(data_artifact_name)
         | 
| 23 | 
             
                        artifact_dir = artifact.download()
         | 
| 24 | 
            +
                    self._docs_retrieval_model.index(
         | 
| 25 | 
            +
                        input_path=artifact_dir,
         | 
| 26 | 
            +
                        index_name=index_name,
         | 
| 27 | 
            +
                        store_collection_with_index=False,
         | 
| 28 | 
            +
                        overwrite=True,
         | 
| 29 | 
            +
                    )
         | 
| 30 | 
            +
                    if wandb.run:
         | 
| 31 | 
            +
                        artifact = wandb.Artifact(
         | 
| 32 | 
            +
                            name=index_name,
         | 
| 33 | 
            +
                            type="colpali-index",
         | 
| 34 | 
            +
                            metadata={"weave_dataset_name": weave_dataset_name},
         | 
| 35 | 
            +
                        )
         | 
| 36 | 
            +
                        artifact.add_dir(
         | 
| 37 | 
            +
                            local_path=os.path.join(".byaldi", index_name), name="index"
         | 
| 38 | 
            +
                        )
         | 
| 39 | 
            +
                        artifact.save()
         | 
