|
import os |
|
from typing import Optional |
|
|
|
import torch |
|
from transformers.trainer import unwrap_model |
|
|
|
from .base_engine import TrainerForMMLLM |
|
|
|
|
|
class ShikraTrainer(TrainerForMMLLM):
    """Trainer that also exports the multimodal adapter weights when saving."""

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model and, if enabled, a separate adapter-weights file.

        When ``self.args.tune_mm_mlp_adapter`` is truthy, every state-dict
        entry whose key contains ``mm_projector``, ``embed_tokens`` or
        ``embed_in`` is written with ``torch.save``:

        * if the target directory is named ``checkpoint-*``, to
          ``<parent>/mm_projector/<checkpoint-name>.bin``;
        * otherwise, to ``<target>/mm_projector.bin``.

        The parent class ``_save`` is then always invoked with the original
        arguments.

        Args:
            output_dir: Directory to save into. When ``None``, the adapter
                export falls back to ``self.args.output_dir``.
            state_dict: Optional pre-computed state dict. When ``None``, it is
                taken from the unwrapped ``self.model``.
        """
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # ``output_dir`` is optional; without this fallback the path
            # handling below would fail on ``None``.
            save_dir = output_dir if output_dir is not None else self.args.output_dir

            _state_dict = state_dict
            if _state_dict is None:
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            # Keep only the adapter / embedding tensors.
            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v
                for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # Normalise first so a trailing separator (``.../checkpoint-10/``)
            # does not yield an empty folder name.
            normalized_dir = os.path.normpath(save_dir)
            current_folder = os.path.basename(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                parent_folder = os.path.dirname(normalized_dir)
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                # The directory may not exist yet: the parent ``_save`` only
                # runs after this export.
                os.makedirs(save_dir, exist_ok=True)
                torch.save(weight_to_save, os.path.join(save_dir, 'mm_projector.bin'))

        super(ShikraTrainer, self)._save(output_dir, state_dict)
|
|