# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import tempfile
from typing import TYPE_CHECKING, Dict

from huggingface_hub import DDUFEntry
from tqdm import tqdm

from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_safetensors_available():
    import safetensors.torch
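
# Note on DDUF layout (context for the helpers below): a DDUF archive stores its
# entries under "component/file" style names, e.g. "tokenizer/tokenizer.json" or
# "text_encoder/model.safetensors". Both loaders rely on that convention, selecting
# one component's files via the `name + "/"` prefix.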


def _load_tokenizer_from_dduf(
    cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedTokenizer":
    """
    Load a tokenizer from a DDUF archive.

    In practice, `transformers` does not provide a way to load a tokenizer directly from a DDUF archive. This
    function works around that by extracting the tokenizer files to a temporary directory and loading the tokenizer
    from the extracted files. Extraction adds some overhead, but the impact is limited since tokenizer files are
    usually small.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        for entry_name, entry in dduf_entries.items():
            if entry_name.startswith(name + "/"):
                tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
                # Create intermediate directories if they don't exist.
                os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
                # Write the entry to disk through its memory-mapped view.
                with open(tmp_entry_path, "wb") as f:
                    with entry.as_mmap() as mm:
                        f.write(mm)
        return cls.from_pretrained(os.path.dirname(tmp_entry_path), **kwargs)
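
# Illustrative usage sketch (a minimal example, not exercised by this module):
# `read_dduf_file` is the `huggingface_hub` helper returning a {entry_name: DDUFEntry}
# mapping, and "tokenizer" is a hypothetical component name inside the archive.
#
#     from huggingface_hub import read_dduf_file
#     from transformers import AutoTokenizer
#
#     dduf_entries = read_dduf_file("pipeline.dduf")
#     tokenizer = _load_tokenizer_from_dduf(AutoTokenizer, "tokenizer", dduf_entries)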


def _load_transformers_model_from_dduf(
    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedModel":
    """
    Load a transformers model from a DDUF archive.

    In practice, `transformers` does not provide a way to load a model directly from a DDUF archive. This function
    works around that by instantiating the model from its config file and loading the weights straight from the
    DDUF archive.
    """
    config_file = dduf_entries.get(f"{name}/config.json")
    if config_file is None:
        raise EnvironmentError(
            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    generation_config = dduf_entries.get(f"{name}/generation_config.json", None)

    weight_files = [
        entry
        for entry_name, entry in dduf_entries.items()
        if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
    ]
    if not weight_files:
        raise EnvironmentError(
            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    if not is_safetensors_available():
        raise EnvironmentError(
            "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
        )
    if is_transformers_version("<", "4.47.0"):
        raise ImportError(
            "You need to install `transformers>=4.47.0` in order to load a transformers model from a DDUF file. "
            "You can install it with: `pip install --upgrade transformers`"
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        from transformers import AutoConfig, GenerationConfig

        # `AutoConfig`/`GenerationConfig` only load from disk, so write the JSON
        # entries to a temporary directory first.
        tmp_config_file = os.path.join(tmp_dir, "config.json")
        with open(tmp_config_file, "w") as f:
            f.write(config_file.read_text())
        config = AutoConfig.from_pretrained(tmp_config_file)
        if generation_config is not None:
            tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
            with open(tmp_generation_config_file, "w") as f:
                f.write(generation_config.read_text())
            generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)

        state_dict = {}
        with contextlib.ExitStack() as stack:
            for entry in tqdm(weight_files, desc="Loading state_dict"):  # Loop over safetensors files
                # Memory-map the safetensors file
                mmap = stack.enter_context(entry.as_mmap())
                # Load tensors from the memory-mapped file
                tensors = safetensors.torch.load(mmap)
                # Update the state dictionary with tensors
                state_dict.update(tensors)

            # Keep the memory maps open while `from_pretrained` copies the tensors.
            return cls.from_pretrained(
                pretrained_model_name_or_path=None,
                config=config,
                generation_config=generation_config,
                state_dict=state_dict,
                **kwargs,
            )
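
# Illustrative usage sketch (a minimal example, not exercised by this module):
# "text_encoder" is a hypothetical component name, and `CLIPTextModel` stands in for
# whatever `transformers` class the archive's config actually corresponds to.
#
#     from huggingface_hub import read_dduf_file
#     from transformers import CLIPTextModel
#
#     dduf_entries = read_dduf_file("pipeline.dduf")
#     text_encoder = _load_transformers_model_from_dduf(
#         CLIPTextModel, "text_encoder", dduf_entries
#     )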