Add supports_gradient_checkpointing
Browse files
 - configuration_internvl_chat.py +2 -0
 - modeling_intern_vit.py +1 -0
 - modeling_internvl_chat.py +11 -0
 
    	
        configuration_internvl_chat.py
    CHANGED
    
    | 
         @@ -61,6 +61,8 @@ class InternVLChatConfig(PretrainedConfig): 
     | 
|
| 61 | 
         
             
                    self.ps_version = ps_version  # pixel shuffle version
         
     | 
| 62 | 
         
             
                    self.min_dynamic_patch = min_dynamic_patch
         
     | 
| 63 | 
         
             
                    self.max_dynamic_patch = max_dynamic_patch
         
     | 
| 
         | 
|
| 
         | 
|
| 64 | 
         | 
| 65 | 
         
             
                    logger.info(f'vision_select_layer: {self.select_layer}')
         
     | 
| 66 | 
         
             
                    logger.info(f'ps_version: {self.ps_version}')
         
     | 
| 
         | 
|
| 61 | 
         
             
                    self.ps_version = ps_version  # pixel shuffle version
         
     | 
| 62 | 
         
             
                    self.min_dynamic_patch = min_dynamic_patch
         
     | 
| 63 | 
         
             
                    self.max_dynamic_patch = max_dynamic_patch
         
     | 
| 64 | 
         
            +
                    # By default, we use tie_word_embeddings=False for models of all sizes.
         
     | 
| 65 | 
         
            +
                    self.tie_word_embeddings = self.llm_config.tie_word_embeddings
         
     | 
| 66 | 
         | 
| 67 | 
         
             
                    logger.info(f'vision_select_layer: {self.select_layer}')
         
     | 
| 68 | 
         
             
                    logger.info(f'ps_version: {self.ps_version}')
         
     | 
    	
        modeling_intern_vit.py
    CHANGED
    
    | 
         @@ -364,6 +364,7 @@ class InternVisionEncoder(nn.Module): 
     | 
|
| 364 | 
         
             
            class InternVisionModel(PreTrainedModel):
         
     | 
| 365 | 
         
             
                main_input_name = 'pixel_values'
         
     | 
| 366 | 
         
             
                _supports_flash_attn_2 = True
         
     | 
| 
         | 
|
| 367 | 
         
             
                config_class = InternVisionConfig
         
     | 
| 368 | 
         
             
                _no_split_modules = ['InternVisionEncoderLayer']
         
     | 
| 369 | 
         | 
| 
         | 
|
| 364 | 
         
             
            class InternVisionModel(PreTrainedModel):
         
     | 
| 365 | 
         
             
                main_input_name = 'pixel_values'
         
     | 
| 366 | 
         
             
                _supports_flash_attn_2 = True
         
     | 
| 367 | 
         
            +
                supports_gradient_checkpointing = True
         
     | 
| 368 | 
         
             
                config_class = InternVisionConfig
         
     | 
| 369 | 
         
             
                _no_split_modules = ['InternVisionEncoderLayer']
         
     | 
| 370 | 
         | 
    	
        modeling_internvl_chat.py
    CHANGED
    
    | 
         @@ -36,6 +36,7 @@ class InternVLChatModel(PreTrainedModel): 
     | 
|
| 36 | 
         
             
                main_input_name = 'pixel_values'
         
     | 
| 37 | 
         
             
                base_model_prefix = 'language_model'
         
     | 
| 38 | 
         
             
                _supports_flash_attn_2 = True
         
     | 
| 
         | 
|
| 39 | 
         
             
                _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']
         
     | 
| 40 | 
         | 
| 41 | 
         
             
                def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
         
     | 
| 
         @@ -343,3 +344,13 @@ class InternVLChatModel(PreTrainedModel): 
     | 
|
| 343 | 
         
             
                    )
         
     | 
| 344 | 
         | 
| 345 | 
         
             
                    return outputs
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 36 | 
         
             
                main_input_name = 'pixel_values'
         
     | 
| 37 | 
         
             
                base_model_prefix = 'language_model'
         
     | 
| 38 | 
         
             
                _supports_flash_attn_2 = True
         
     | 
| 39 | 
         
            +
                supports_gradient_checkpointing = True
         
     | 
| 40 | 
         
             
                _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']
         
     | 
| 41 | 
         | 
| 42 | 
         
             
                def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
         
     | 
| 
         | 
|
| 344 | 
         
             
                    )
         
     | 
| 345 | 
         | 
| 346 | 
         
             
                    return outputs
         
     | 
| 347 | 
         
            +
             
     | 
| 348 | 
         
            +
                @property
         
     | 
| 349 | 
         
            +
                def lm_head(self):
         
     | 
| 350 | 
         
            +
                    return self.language_model.get_output_embeddings()
         
     | 
| 351 | 
         
            +
             
     | 
| 352 | 
         
            +
                def get_input_embeddings(self):
         
     | 
| 353 | 
         
            +
                    return self.language_model.get_input_embeddings()
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
                def get_output_embeddings(self):
         
     | 
| 356 | 
         
            +
                    return self.language_model.get_output_embeddings()
         
     |