# File: lisa-on-cuda/model/llava/train/llava_trainer.py
# Page header from the hosting site: x-lai, "Release training script",
# commit 3d9fba4, 1.95 kB (viewer links: raw, history, blame).
import os
from typing import Dict, Optional, Sequence
import torch
import torch.nn as nn
from transformers import Trainer
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Strip every wrapper layer from a model and return the innermost object.

    A wrapper is recognised by the presence of a ``module`` attribute (as
    exposed by containers used in distributed training). Unwrapping keeps
    following ``.module`` until the current object no longer has one.

    Args:
        model (`torch.nn.Module`): The possibly wrapped model.

    Returns:
        The first object in the ``.module`` chain that has no ``module``
        attribute; ``model`` itself when it is not wrapped.
    """
    innermost = model
    # Wrappers can be nested several levels deep, so descend until none remain.
    while hasattr(innermost, "module"):
        innermost = innermost.module
    return innermost
class LLaVATrainer(Trainer):
    """
    `Trainer` subclass that writes an extra weights file on every save.

    When ``self.args.tune_mm_mlp_adapter`` is truthy, `_save` additionally
    dumps the subset of the state dict whose parameter names contain
    ``"mm_projector"``, ``"embed_tokens"`` or ``"embed_in"`` before
    delegating to the regular `Trainer._save`.
    """

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Save the model, plus a standalone adapter-weights file when enabled.

        Args:
            output_dir (`str`, *optional*): Target directory. Falls back to
                ``self.args.output_dir`` when ``None``.
            state_dict (*optional*): State dict to save. When ``None`` it is
                taken from the unwrapped ``self.model``.

        The adapter file is written to:
            * ``<parent>/mm_projector/<checkpoint-name>.bin`` when the last
              component of ``output_dir`` starts with ``"checkpoint-"``;
            * ``<output_dir>/mm_projector.bin`` otherwise.
        """
        # Resolve the default up front: the path is inspected below, and the
        # signature explicitly allows ``None``.
        if output_dir is None:
            output_dir = self.args.output_dir

        if getattr(self.args, "tune_mm_mlp_adapter", False):
            _state_dict = state_dict
            if _state_dict is None:
                # Strip distributed-training wrappers so the state dict comes
                # from the underlying model itself.
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            # Keep only the entries whose name contains one of these markers.
            keys_to_match = ("mm_projector", "embed_tokens", "embed_in")
            weight_to_save = {
                k: v
                for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # Normalise first so a trailing separator ("out/checkpoint-5/")
            # does not hide the checkpoint folder name, and use os.path so the
            # split follows the platform's separator rules.
            normalized_dir = os.path.normpath(output_dir)
            current_folder = os.path.basename(normalized_dir)
            parent_folder = os.path.dirname(normalized_dir)

            if current_folder.startswith("checkpoint-"):
                # Intermediate checkpoint: collect all adapter dumps in one
                # sibling folder, one file per checkpoint.
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(
                    weight_to_save,
                    os.path.join(mm_projector_folder, f"{current_folder}.bin"),
                )
            else:
                # The parent ``_save`` runs only afterwards, so the directory
                # may not exist yet when we write into it.
                os.makedirs(output_dir, exist_ok=True)
                torch.save(
                    weight_to_save, os.path.join(output_dir, "mm_projector.bin")
                )

        # Always perform the regular save as well.
        super()._save(output_dir, state_dict)