# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType
from typing import TYPE_CHECKING, Any, Dict

import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled

from ..extras import logging
from ..extras.misc import infer_optim_dtype, is_env_enabled
from ..extras.packages import is_transformers_version_greater_than
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import (
    autocast_projector_dtype,
    configure_visual_model,
    get_image_seqlen,
    get_patch_size,
    get_vision_feature_select_strategy,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import ModelArguments


logger = logging.get_logger(__name__)


def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
    """Patch the tokenizer in place: fix `_pad`, apply `model_max_length` and register new special tokens."""
    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
        tokenizer.model_max_length = model_args.model_max_length

    if model_args.new_special_tokens is not None:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=model_args.new_special_tokens),
            replace_additional_special_tokens=False,
        )
        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
        if num_added_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
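

# A minimal usage sketch (comments only, not executed here): `patch_tokenizer` is meant to run right after
# the tokenizer is instantiated and before any data preprocessing. The loading call below is an illustrative
# assumption; the actual caller may pass additional arguments (e.g. trust_remote_code, revision).
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
#     patch_tokenizer(tokenizer, model_args)

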
def patch_processor(
    processor: "ProcessorMixin",
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
) -> None:
    """Attach the tokenizer and the multimodal preprocessing attributes to the processor."""
    setattr(processor, "tokenizer", tokenizer)
    if getattr(config, "vision_config", None) is not None:  # visual models
        setattr(processor, "image_seqlen", get_image_seqlen(config))
        setattr(processor, "patch_size", get_patch_size(config, processor))
        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
        setattr(processor, "video_fps", model_args.video_fps)
        setattr(processor, "video_maxlen", model_args.video_maxlen)
        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))


def patch_config(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
    is_trainable: bool,
) -> None:
    """Patch the model config and the `from_pretrained` kwargs in place before the model is instantiated."""
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        if model_args.infer_dtype != "auto" and not is_trainable:
            model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
        else:
            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

    if is_torch_npu_available():
        torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))

    configure_attn_implementation(config, model_args, is_trainable)
    configure_rope(config, model_args, is_trainable)
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)
    configure_packing(model_args, is_trainable)

    if model_args.use_cache and not is_trainable:
        setattr(config, "use_cache", True)
        logger.info_rank0("Using KV cache for faster generation.")

    if getattr(config, "model_type", None) == "qwen":
        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
            setattr(config, dtype_name, model_args.compute_dtype == dtype)

    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

    if getattr(config, "model_type", None) == "minicpmo":
        setattr(config, "init_audio", True)
        setattr(config, "init_tts", False)

    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
        raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")

    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    # cast data type of the model if:
    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
    # 2. quantization_bit is not None (qlora)
    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
        init_kwargs["torch_dtype"] = model_args.compute_dtype

        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
            if "device_map" not in init_kwargs and model_args.device_map:
                init_kwargs["device_map"] = model_args.device_map

            if init_kwargs.get("device_map", None) == "auto":
                init_kwargs["offload_folder"] = model_args.offload_folder
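

# Usage sketch (comments only, not executed here): `patch_config` mutates both `config` and `init_kwargs`
# in place, and the updated kwargs are expected to be forwarded to `from_pretrained`. The snippet below is
# an illustrative assumption, not the exact loader code.
#
#     init_kwargs = {}
#     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
#     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)

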
def patch_model(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    is_trainable: bool,
    add_valuehead: bool,
) -> None:
    """Patch the loaded model in place: fix generation config, resize embeddings, prepare for training."""
    gen_config = model.generation_config  # check and fix generation config
    if not gen_config.do_sample and (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    ):
        gen_config.do_sample = True

    if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
        model.generate.__func__
    ):
        model.generate = MethodType(PreTrainedModel.generate, model)

    if add_valuehead:
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(model, tokenizer)

    if is_trainable:
        prepare_model_for_training(model, model_args)
        autocast_projector_dtype(model, model_args)
        add_z3_leaf_module(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)

    try:
        model.add_model_tags(["llama-factory"])
    except Exception:
        logger.warning_rank0("Cannot properly tag the model.")


def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    """Patch a value-head model to delegate weight tying, embedding access and model card creation to the wrapped model."""

    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        if isinstance(self.pretrained_model, PreTrainedModel):
            self.pretrained_model.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_output_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        if isinstance(self.pretrained_model, PeftModel):
            self.pretrained_model.create_or_update_model_card(output_dir)

    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
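

# Usage sketch for the model-side patches (comments only, not executed here; an illustrative assumption
# about the caller, not the exact loader code): `patch_model` runs on the freshly instantiated model, while
# `patch_valuehead_model` is applied only after the model has been wrapped with trl's value-head class.
#
#     patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
#     if add_valuehead:
#         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
#         patch_valuehead_model(model)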