Update modelling_magi.py
Browse files- modelling_magi.py +2 -2
 
    	
        modelling_magi.py
    CHANGED
    
    | 
         @@ -181,7 +181,7 @@ class MagiModel(PreTrainedModel): 
     | 
|
| 181 | 
         | 
| 182 | 
         
             
                    return crop_embeddings_for_batch
         
     | 
| 183 | 
         | 
| 184 | 
         
            -
                def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
         
     | 
| 185 | 
         
             
                    assert not self.config.disable_ocr
         
     | 
| 186 | 
         
             
                    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
         
     | 
| 187 | 
         | 
| 
         @@ -207,7 +207,7 @@ class MagiModel(PreTrainedModel): 
     | 
|
| 207 | 
         
             
                        pbar = range(0, len(crops_per_image), batch_size)
         
     | 
| 208 | 
         
             
                    for i in pbar:
         
     | 
| 209 | 
         
             
                        crops = crops_per_image[i:i+batch_size]
         
     | 
| 210 | 
         
            -
                        generated_ids = self.ocr_model.generate(crops)
         
     | 
| 211 | 
         
             
                        generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
         
     | 
| 212 | 
         
             
                        all_generated_texts.extend(generated_texts)
         
     | 
| 213 | 
         | 
| 
         | 
|
| 181 | 
         | 
| 182 | 
         
             
                    return crop_embeddings_for_batch
         
     | 
| 183 | 
         | 
| 184 | 
         
            +
                def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32, max_new_tokens=64):
         
     | 
| 185 | 
         
             
                    assert not self.config.disable_ocr
         
     | 
| 186 | 
         
             
                    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
         
     | 
| 187 | 
         | 
| 
         | 
|
| 207 | 
         
             
                        pbar = range(0, len(crops_per_image), batch_size)
         
     | 
| 208 | 
         
             
                    for i in pbar:
         
     | 
| 209 | 
         
             
                        crops = crops_per_image[i:i+batch_size]
         
     | 
| 210 | 
         
            +
                        generated_ids = self.ocr_model.generate(crops, max_new_tokens=max_new_tokens)
         
     | 
| 211 | 
         
             
                        generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
         
     | 
| 212 | 
         
             
                        all_generated_texts.extend(generated_texts)
         
     | 
| 213 | 
         |