|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import copy |
|
from dataclasses import dataclass, field |
|
import json |
|
import logging |
|
import pathlib |
|
from typing import Dict, Optional, Sequence, List |
|
from webbrowser import get |
|
|
|
import torch |
|
|
|
import transformers |
|
import tokenizers |
|
|
|
|
|
|
|
from llava.train.llava_trainer import LLaVATrainer |
|
from llava.train.arguments import ModelArguments, TrainingArguments, DataArguments |
|
from llava.datasets.super_dataset import make_supervised_data_module, make_supervised_data_module_concatdataset |
|
from llava import conversation as conversation_lib |
|
from llava.model import * |
|
|
|
|
|
|
|
|
|
|
|
# Local rank of this process; stays None until train() assigns it.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only when the module-level ``local_rank`` equals 0.

    Nothing is printed while ``local_rank`` is still ``None`` or holds any
    other value.
    """
    if local_rank != 0:
        return
    print(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if needed.

    A parameter is treated as DeepSpeed-managed when it carries a ``ds_id``
    attribute; in that case it is copied inside ``zero.GatheredParameters``.
    Any other tensor is simply detached, moved to CPU and cloned.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When true, suppress the warning that is logged for a
            DeepSpeed parameter whose ``ds_status`` is
            ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Optional parameter name; used only in the warning text.

    Returns:
        A detached clone of the parameter data located on the CPU.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: no DeepSpeed involvement, so DeepSpeed is not imported
        # and does not have to be installed for this path.
        return param.detach().cpu().clone()

    # Imported lazily so that only DeepSpeed-managed parameters require it.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        # The message now states the condition that actually triggered it.
        logging.warning(
            f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
        )
    with zero.GatheredParameters([param]):
        # The clone is taken while the gathering context is still active.
        return param.data.detach().cpu().clone()
|
|
|
|
|
|
|
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA parameters (and optionally biases) as CPU copies.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        bias: Which bias parameters to include next to the ``"lora_"`` ones:
            ``"none"`` keeps only names containing ``"lora_"``;
            ``"all"`` also keeps every name containing ``"bias"``;
            ``"lora_only"`` also keeps a bias whose name equals the prefix of
            some LoRA parameter (the text before ``"lora_"``) plus ``"bias"``.

    Returns:
        Dict mapping each selected name to ``maybe_zero_3(param,
        ignore_status=True)``.

    Raises:
        NotImplementedError: If ``bias`` is not one of the three modes above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name a bias would have if it sat beside this LoRA weight.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate (name, tensor) pairs and test each candidate's own name;
        # the previous loop unpacked dict keys and re-tested a stale name.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
|
|
|
|
|
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect every non-LoRA parameter as a CPU copy.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        require_grad_only: When true, keep only parameters whose
            ``requires_grad`` is truthy.

    Returns:
        Dict mapping each kept name (one that does not contain ``"lora_"``)
        to ``maybe_zero_3(param, ignore_status=True).cpu()``.
    """
    candidates = {}
    for param_name, tensor in named_params:
        if "lora_" in param_name:
            continue
        candidates[param_name] = tensor

    if require_grad_only:
        candidates = {
            param_name: tensor
            for param_name, tensor in candidates.items()
            if tensor.requires_grad
        }

    gathered = {}
    for param_name, tensor in candidates.items():
        gathered[param_name] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered
|
|
|
|
|
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect parameters whose name contains any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        keys_to_match: Substrings; a parameter is selected when at least one
            of them occurs in its name.

    Returns:
        Dict mapping each selected name to
        ``maybe_zero_3(param, ignore_status=True).cpu()``.
    """
    selected = {}
    for param_name, tensor in named_params:
        if any(fragment in param_name for fragment in keys_to_match):
            selected[param_name] = tensor

    return {
        param_name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for param_name, tensor in selected.items()
    }
|
|
|
|
|
def find_all_linear_names(model):
    """Return the leaf names of the model's ``torch.nn.Linear`` submodules.

    Modules whose qualified name contains ``mm_projector``, ``vision_tower``
    or ``vision_resampler`` are skipped, and ``lm_head`` is dropped from the
    result. Only the last dotted component of each name is kept, so the
    result holds each distinct leaf name once, in arbitrary (set) order.
    """
    linear_type = torch.nn.Linear
    skipped_fragments = ['mm_projector', 'vision_tower', 'vision_resampler']
    leaf_names = set()

    for qualified_name, submodule in model.named_modules():
        is_multimodal = any(fragment in qualified_name for fragment in skipped_fragments)
        if is_multimodal or not isinstance(submodule, linear_type):
            continue
        # Last dotted component; a name without dots is kept as-is.
        leaf_names.add(qualified_name.rsplit('.', 1)[-1])

    leaf_names.discard('lm_head')
    return list(leaf_names)
|
|
|
|
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collect the state dict and dump it to disk.

    Three paths, checked in this order:

    1. ``trainer.args.tune_mm_mlp_adapter`` is truthy: only parameters whose
       name matches the adapter key list are saved. The model config is
       written to ``output_dir`` by every process; the weights are written
       only when ``local_rank`` is 0 or -1. For a ``checkpoint-*`` directory
       they go to ``<parent>/mm_projector/<checkpoint-name>.bin``, otherwise
       to ``<output_dir>/mm_projector.bin``.
    2. ``trainer.deepspeed`` is truthy: CUDA is synchronized and saving is
       delegated to ``trainer.save_model``.
    3. Otherwise the model state dict is moved to CPU and handed to
       ``trainer._save`` when ``trainer.args.should_save`` is set.

    Args:
        trainer: Trainer holding the model and the training arguments.
        output_dir: Directory to save into.
    """
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        keys_to_match = ['mm_projector', 'frame_position_encoding', 'adapter_module']
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(['embed_tokens', 'embed_in', 'wte'])
        if not getattr(trainer.args, 'freeze_qformer', True):
            keys_to_match.extend(['Qformer', 'query_tokens'])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        # Normalize first: a trailing separator would otherwise produce an
        # empty folder name and hide a "checkpoint-*" directory.
        normalized_dir = os.path.normpath(output_dir)
        current_folder = os.path.basename(normalized_dir)
        parent_folder = os.path.dirname(normalized_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        # Drop the reference to the original state dict before saving.
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)
|
|
|
|
|
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens and resize the model embeddings to match.

    When tokens were actually added, the new trailing rows of both the input
    and the output embedding matrices are set to the mean of the rows that
    existed before.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    embedding_layers = (model.get_input_embeddings(), model.get_output_embeddings())
    for layer in embedding_layers:
        weights = layer.weight.data
        # Mean over the pre-existing rows only, broadcast into the new rows.
        weights[-added:] = weights[:-added].mean(dim=0, keepdim=True)
|
|
|
|
|
def train(attn_implementation=None):
    """Parse arguments, build model/tokenizer/data, train, and save.

    Steps, in order:

    1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``
       and publish ``local_rank`` to the module-level global.
    2. Build optional 4/8-bit loading arguments.
    3. Load the model; the class is chosen by substrings of
       ``model_args.model_name_or_path`` when a vision tower is configured,
       otherwise ``transformers.LlamaForCausalLM`` is used.
    4. Apply backbone freezing, k-bit preparation, gradient-checkpointing
       input hooks and LoRA wrapping as requested.
    5. Load the tokenizer, set the pad token and the default conversation
       template.
    6. Initialise vision modules and set ``requires_grad`` flags.
    7. Build the data module and ``LLaVATrainer``; resume when
       ``output_dir`` already contains ``checkpoint-*`` entries.
    8. Save trainer state and weights (LoRA and non-LoRA separately when LoRA
       is enabled).

    Args:
        attn_implementation: Forwarded to ``from_pretrained`` for every model
            class except the MPT variant.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    # Case-insensitive, matching the check used below to pick the model class,
    # so the tokenizer and pad-token handling agree with the loaded model.
    is_thoth = 'thoth' in model_args.model_name_or_path.lower()

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["mm_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type
            )
        ))

    if model_args.vision_tower is not None:
        if 'mpt' in model_args.model_name_or_path:
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
            model = LlavaMptForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
        elif 'mistral' in model_args.model_name_or_path.lower():
            model = LlavaMistralForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif 'gemma' in model_args.model_name_or_path.lower():
            model = LlavaGemmaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        elif 'thoth' in model_args.model_name_or_path.lower():
            model = LlavaThothForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
            **bnb_model_from_pretrained_args
        )
    # Turned off for training; re-enabled after training finishes below.
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback: mark the embedding output as requiring grad via a hook.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    if 'mpt' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right"
        )
    elif is_thoth:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            use_fast=True
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token="[PAD]"),
                tokenizer=tokenizer,
                model=model,
            )
    elif model_args.version == "v0.5":
        tokenizer.pad_token = tokenizer.unk_token
    else:
        # The thoth tokenizer keeps its own pad token.
        if not is_thoth:
            tokenizer.pad_token = tokenizer.unk_token
        if model_args.version in conversation_lib.conv_templates:
            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
        else:
            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    model_args.max_num_segments = data_args.num_segments
    if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        vision_tower = model.get_vision_tower()

        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        if model_args.tune_mm_mlp_adapter:
            # Freeze everything, then re-enable only the projector.
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True

        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        if model.get_model().get_frame_position_encoding():
            model.get_frame_position_encoding().weight.requires_grad = True

        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

        model.config.mm_use_start_end = data_args.mm_use_start_end = model_args.mm_use_start_end
        model.config.mm_projector_lr = training_args.mm_projector_lr
        model.config.lora_lr = training_args.lora_lr
        training_args.use_im_start_end = model_args.mm_use_start_end
        model.config.mm_use_patch_token = model_args.mm_use_patch_token
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

        # NOTE(review): the freeze/unfreeze blocks below use accessors that
        # the plain LlamaForCausalLM path is not shown to provide, so they
        # are kept inside the vision branch - confirm against intended nesting.
        if getattr(training_args, "freeze_vision_encoder", True):
            for p in model.get_vision_tower().parameters():
                p.requires_grad = False
        else:
            for p in model.get_vision_tower().parameters():
                p.requires_grad = True

        if getattr(model_args, 'qformer_model_path', None):
            if getattr(training_args, "freeze_qformer", True):
                for p in model.get_qformer().parameters():
                    p.requires_grad = False
                for p in model.get_ln_vision().parameters():
                    p.requires_grad = False
                model.get_query_tokens().requires_grad = False
            else:
                for p in model.get_qformer().parameters():
                    p.requires_grad = True
                for p in model.get_ln_vision().parameters():
                    p.requires_grad = True
                model.get_query_tokens().requires_grad = True

        if getattr(model_args, 'adapter_module_name', None):
            model.get_adapter_module().freeze_adapter_module(getattr(training_args, "freeze_adapter", False))

    trainable_params = [name for (name, param) in model.named_parameters() if param.requires_grad]

    rank0_print(f"==> Trainable parameters: {trainable_params}")

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_args.image_grid_pinpoints = model_args.image_grid_pinpoints
    if not training_args.group_by_modality_length:
        data_module = make_supervised_data_module(tokenizer=tokenizer,
                                                  data_args=data_args,
                                                  num_workers=training_args.dataloader_num_workers)
    else:
        data_module = make_supervised_data_module_concatdataset(tokenizer=tokenizer,
                                                                data_args=data_args,
                                                                num_workers=training_args.dataloader_num_workers)

    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
                           **data_module)

    # Resume when the output directory already holds checkpoint-* entries.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        # Both state dicts are collected on every process; only rank 0 / -1 writes.
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
|
|
|
|
|
# Script entry point: run training with the default attention implementation
# (train() is called without arguments, so attn_implementation stays None).
if __name__ == "__main__":
    train()
|
|