Spaces:
Runtime error
Runtime error
File size: 5,230 Bytes
20076b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
# These names are only needed by static type checkers: in the code visible here
# they are referenced exclusively inside quoted (string) annotations, so the
# imports are skipped at runtime.
if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from ..hparams import FinetuningArguments, ModelArguments
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Dependency version checks, executed once when this module is imported.
# The second argument of each call is the installation hint text supplied with
# the requirement (presumably shown when the check fails).
require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def load_model_and_tokenizer(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: Optional[bool] = False,
    add_valuehead: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    r"""
    Loads pretrained model and tokenizer.

    Support both training and inference.

    Args:
        model_args: Model options; read for the model path, cache/revision/token
            settings, tokenizer flags, compute dtype, quantization bit width,
            unsloth switch and adapter paths. Note that ``use_unsloth`` and
            ``adapter_name_or_path`` may be modified in place.
        finetuning_args: Fine-tuning options, forwarded to ``init_adapter``.
        is_trainable: When true the model is returned in train mode; otherwise
            gradients are disabled and the model is returned in eval mode.
        add_valuehead: When true the model is wrapped in
            ``AutoModelForCausalLMWithValueHead`` and value head weights are
            loaded if ``load_valuehead_params`` finds any.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    try_download_model_from_ms(model_args)

    # Hub/cache options shared by the tokenizer, config and model loaders.
    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.hf_hub_token,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        split_special_tokens=model_args.split_special_tokens,
        padding_side="right",
        **config_kwargs,
    )
    patch_tokenizer(tokenizer)

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)

    # `model` stays None unless the optional unsloth fast path produces one.
    model = None
    if is_trainable and model_args.use_unsloth:
        require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
        # Imported lazily so that unsloth is only required when it is requested.
        from unsloth import FastLlamaModel, FastMistralModel  # type: ignore

        unsloth_kwargs = {
            "model_name": model_args.model_name_or_path,
            "max_seq_length": model_args.model_max_length,
            "dtype": model_args.compute_dtype,
            "load_in_4bit": model_args.quantization_bit == 4,
            "token": model_args.hf_hub_token,
            "device_map": get_current_device(),
            "rope_scaling": getattr(config, "rope_scaling", None),
        }

        model_type = getattr(config, "model_type", None)
        if model_type == "llama":
            model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
        elif model_type == "mistral":
            model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
        else:
            logger.warning("Unsloth does not support model type {}.".format(model_type))
            model_args.use_unsloth = False

        # Discard adapters only when unsloth really loaded the model. If we fall
        # back to the regular loader below, the unsloth limitation does not
        # apply and the user's adapters must be kept.
        if model is not None and model_args.adapter_name_or_path:
            model_args.adapter_name_or_path = None
            logger.warning("Unsloth does not support loading adapters.")

    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            torch_dtype=model_args.compute_dtype,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
            **config_kwargs,
        )

    patch_model(model, tokenizer, model_args, is_trainable)
    register_autoclass(config, model, tokenizer)

    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if add_valuehead:
        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_valuehead_model(model)

        # Look for value head weights next to the last adapter if there is one,
        # otherwise next to the base model. Truthiness (rather than an
        # `is not None` test) keeps an empty adapter list from raising
        # IndexError on the `[-1]` lookup.
        if model_args.adapter_name_or_path:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        vhead_params = load_valuehead_params(vhead_path, model_args)
        if vhead_params is not None:
            # strict=False: the loaded dict is applied on top of the full model,
            # so keys it does not contain are not treated as errors.
            model.load_state_dict(vhead_params, strict=False)
            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))

    if not is_trainable:
        model.requires_grad_(False)
        # Leave models that report a quantization method in their current dtype.
        if not getattr(model, "quantization_method", None):
            model = model.to(model_args.compute_dtype)
        model.eval()
    else:
        model.train()

    trainable_params, all_param = count_parameters(model)
    logger.info(
        "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
            trainable_params, all_param, 100 * trainable_params / all_param
        )
    )

    if not is_trainable:
        logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

    return model, tokenizer
|