import difflib
import os
import io
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from ._utils import _import_dotted_name
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
import copyreg
import pickle
import torch._weights_only_unpickler as _weights_only_unpickler

DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','

FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
]


class SourceChangeWarning(Warning):
    pass


# `mkdtemp` is used as a context manager below (and exported in `__all__`),
# so it needs the `@contextmanager` decorator.
@contextmanager
def mkdtemp():
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)


_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []


class LoadEndianness(Enum):
    NATIVE = 1
    LITTLE = 2
    BIG = 3

_default_load_endian: Optional[LoadEndianness] = None


def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    '''
    return _default_load_endian


def set_default_load_endianness(endianness):
    '''
    Set fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Args:
        endianness: the new fallback byte order
    '''
    global _default_load_endian
    if not isinstance(endianness, LoadEndianness) and endianness is not None:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness
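
# Illustrative usage sketch (not executed at import time): forcing
# little-endian interpretation for checkpoints saved without a byteorder
# record, then restoring the default.
#
#   >>> torch.serialization.set_default_load_endianness(
#   ...     torch.serialization.LoadEndianness.LITTLE)
#   >>> # ... torch.load(...) ...
#   >>> torch.serialization.set_default_load_endianness(None)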


def _is_zipfile(f) -> bool:
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to check only the start bytes; this avoids
    # collisions with files that merely contain the magic number somewhere
    # inside, and we can assume the zip has only 1 file.
    # See bugs.python.org/issue28494.
    start = f.tell()
    # Read the first few bytes and match against the ZIP file signature
    local_header_magic_number = b'PK\x03\x04'
    read_bytes = f.read(len(local_header_magic_number))
    f.seek(start)
    return read_bytes == local_header_magic_number
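
# A quick illustration of the check above (the bytes are arbitrary test data,
# not a real archive):
#
#   >>> import io
#   >>> _is_zipfile(io.BytesIO(b'PK\x03\x04' + b'\x00' * 16))
#   True
#   >>> _is_zipfile(io.BytesIO(b'not a zip'))
#   False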


def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in a storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, an
    error is raised or a warning is emitted.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether to raise an error if the module version string is malformed

    Returns:
        requirement_is_met: bool
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple
    except Exception as e:
        message = (
            f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True
    return requirement_is_met
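
# Illustrative sketch (the required version numbers here are arbitrary):
#
#   >>> check_module_version_greater_or_equal(torch, (1, 0))
#   True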


def _cpu_tag(obj):
    if obj.device.type == 'cpu':
        return 'cpu'


def _cuda_tag(obj):
    if obj.device.type == 'cuda':
        return 'cuda:' + str(obj.device.index)


def _hpu_tag(obj):
    if obj.device.type == 'hpu':
        return 'hpu:' + str(obj.device.index)


def _mps_tag(obj):
    if obj.device.type == 'mps':
        return 'mps'


def _meta_tag(obj):
    if obj.device.type == 'meta':
        return 'meta'


def _privateuse1_tag(obj):
    backend_name = torch._C._get_privateuse1_backend_name()
    if obj.device.type == backend_name:
        if obj.device.index is None:
            return backend_name
        else:
            return backend_name + ':' + str(obj.device.index)


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def validate_cuda_device(location):
    device = torch.cuda._utils._get_device_index(location, True)

    if not torch.cuda.is_available():
        raise RuntimeError('Attempting to deserialize object on a CUDA '
                           'device but torch.cuda.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_count = torch.cuda.device_count()
    if device >= device_count:
        raise RuntimeError('Attempting to deserialize object on CUDA device '
                           f'{device} but torch.cuda.device_count() is {device_count}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return device


def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        device = validate_cuda_device(location)
        if getattr(obj, "_torch_load_uninitialized", False):
            with torch.cuda.device(device):
                return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
        else:
            return obj.cuda(device)


def validate_hpu_device(location):
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    device = hpu._utils._get_device_index(location, optional=True)

    if not hpu.is_available():
        raise RuntimeError('Attempting to deserialize object on an HPU '
                           'device but torch.hpu.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_count = hpu.device_count()
    if device >= device_count:
        raise RuntimeError('Attempting to deserialize object on HPU device '
                           f'{device} but torch.hpu.device_count() is {device_count}. Please use '
                           'torch.load with map_location to map your storages '
                           'to an existing device.')
    return device


def _hpu_deserialize(obj, location):
    if location.startswith('hpu'):
        hpu = getattr(torch, "hpu", None)
        assert hpu is not None, "HPU device module is not loaded"
        device = validate_hpu_device(location)
        if getattr(obj, "_torch_load_uninitialized", False):
            with hpu.device(device):
                return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
        else:
            return obj.hpu(device)


def _mps_deserialize(obj, location):
    if location.startswith('mps'):
        return obj.mps()


def _meta_deserialize(obj, location):
    if location == 'meta':
        return torch.UntypedStorage(obj.nbytes(), device='meta')


def _validate_privateuse1_device(location, backend_name):
    '''
    Check whether the device index of privateuse1 is valid

    Register a device_module for privateuse1 with torch._register_device_module,
    and implement the following methods in the device_module, as CUDA does:
    device_module._utils._get_device_index(location, True),
    device_module.device_count().

    Args:
        location: string of device
        backend_name: the name of privateuse1, which can be renamed

    Returns:
        device_index: int
    '''
    if not hasattr(torch, backend_name):
        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_module = getattr(torch, backend_name)
    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
        device_index = device_module._utils._get_device_index(location, True)
    else:
        device = torch.device(location)
        device_index = device.index if device.index else 0
    if hasattr(device_module, 'is_available') and not device_module.is_available():
        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
                           f'device but torch.{backend_name}.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    if hasattr(device_module, 'device_count'):
        device_count = device_module.device_count()
        if device_index >= device_count:
            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
                               'Please use torch.load with map_location to map your storages '
                               'to an existing device.')
    return device_index


def _privateuse1_deserialize(obj, location):
    backend_name = torch._C._get_privateuse1_backend_name()
    if location.startswith(backend_name):
        if not hasattr(obj, backend_name):
            raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device '
                               f'but torch.storage._StorageBase.{backend_name}() or '
                               f'torch.storage.TypedStorage.{backend_name}() is not generated. '
                               'Please use torch.utils.generate_methods_for_privateuse1_backend '
                               f'to generate storage.{backend_name}() method first.')
        device_index = _validate_privateuse1_device(location, backend_name)
        return getattr(obj, backend_name)(device_index)
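
# A minimal sketch of the device-module contract probed above (assumptions:
# a backend renamed to 'foo'; a real module would wrap actual device hooks
# rather than hard-coded values, and would usually be an actual module):
#
#   >>> class _FooModule:
#   ...     @staticmethod
#   ...     def is_available():
#   ...         return True
#   ...     @staticmethod
#   ...     def device_count():
#   ...         return 1
#   >>> torch.utils.rename_privateuse1_backend('foo')
#   >>> torch._register_device_module('foo', _FooModule)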


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, _privateuse1_tag, _privateuse1_deserialize)
register_package(24, _hpu_tag, _hpu_deserialize)


def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))


def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")


def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    return isinstance(name_or_buffer, (str, os.PathLike))


class _opener:
    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_file(_opener):
    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()


class _open_buffer_reader(_opener):
    def __init__(self, buffer):
        super().__init__(buffer)
        _check_seekable(buffer)


class _open_buffer_writer(_opener):
    def __exit__(self, *args):
        self.file_like.flush()


def _open_file_like(name_or_buffer, mode):
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if 'w' in mode:
            return _open_buffer_writer(name_or_buffer)
        elif 'r' in mode:
            return _open_buffer_reader(name_or_buffer)
        else:
            raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


class _open_zipfile_reader(_opener):
    def __init__(self, name_or_buffer) -> None:
        super().__init__(torch._C.PyTorchFileReader(name_or_buffer))


class _open_zipfile_writer_file(_opener):
    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        if self.file_stream is not None:
            self.file_stream.close()


class _open_zipfile_writer_buffer(_opener):
    def __init__(self, buffer) -> None:
        if not callable(getattr(buffer, "write", None)):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            if not hasattr(buffer, "write"):
                raise AttributeError(msg)
            raise TypeError(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()


def _open_zipfile_writer(name_or_buffer):
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)


def _is_compressed_file(f) -> bool:
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False


def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        return False
    except AttributeError:
        return False


def _check_seekable(f) -> bool:

    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only torch.load from a file that is seekable."
                       + " Please pre-load the data into a buffer like io.BytesIO and"
                       + " try to load from it instead.")
                raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    return False
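
# The workaround suggested by the error message above, sketched out
# (`some_stream` is a hypothetical non-seekable source such as a socket or
# an HTTP response body):
#
#   >>> import io
#   >>> buffer = io.BytesIO(some_stream.read())
#   >>> obj = torch.load(buffer, weights_only=True)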


def _check_dill_version(pickle_module) -> None:
    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
    If dill version is lower than 0.3.1, a ValueError is raised.

    Args:
        pickle_module: module used for pickling metadata and objects
    '''
    if pickle_module is not None and pickle_module.__name__ == 'dill':
        required_dill_version = (0, 3, 1)
        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
            raise ValueError((
                "'torch' supports dill >= {}, but you have dill {}."
                " Please upgrade dill or switch to 'pickle'"
            ).format(
                '.'.join([str(num) for num in required_dill_version]),
                pickle_module.__version__
            ))


def _check_save_filelike(f):
    if not _is_path(f) and not hasattr(f, 'write'):
        raise AttributeError(
            "expected 'f' to be string, path, or a file-like object with "
            "a 'write' attribute")


def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`).
    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using the .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if _use_new_zipfile_serialization:
        with _open_zipfile_writer(f) as opened_zipfile:
            _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
            return
    else:
        with _open_file_like(f, 'wb') as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)


def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch returns tuples here. This works only in the binary protocol;
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f'type not recognized: {type(obj)}')

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype
            view_metadata: Optional[Tuple[str, int, int]]

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as an UntypedStorage, and then the
            # FloatTensor will be loaded and the UntypedStorage will be assigned
            # to it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # an UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # Always False, since storage views are no longer supported; the
            # comparison is kept so the view_metadata slot stays in the legacy
            # record format.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = ('storage',
                   storage_type,
                   storage_key,
                   location,
                   storage_numel,
                   view_metadata)
            return res
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
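
# The legacy stream written above is a sequence of pickles followed by raw
# storage bytes. A minimal sketch of probing just its header (assumes a file
# saved with `_use_new_zipfile_serialization=False`; 'legacy.pt' is a
# hypothetical path):
#
#   >>> with open('legacy.pt', 'rb') as g:
#   ...     magic = pickle.load(g)        # MAGIC_NUMBER
#   ...     protocol = pickle.load(g)     # PROTOCOL_VERSION
#   ...     sys_info = pickle.load(g)     # endianness and type sizes
#   >>> assert magic == MAGIC_NUMBER and protocol == PROTOCOL_VERSION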


def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch returns tuples here. This works only in the binary protocol;
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ['little', 'big']:
            raise ValueError('Unknown endianness type: ' + sys.byteorder)

        zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))

    # Write each storage to a record named data/{storage_key} in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might as well use
        # storage.cpu(); this means that to get storages serialized, you need
        # to implement .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage, num_bytes)
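
# Because the new format is a standard ZIP archive, a saved checkpoint can be
# inspected with Python's zipfile module (illustrative sketch; 'example.pt' is
# a hypothetical path, and the record prefix is derived from the archive name,
# so the exact entry names may differ):
#
#   >>> import zipfile
#   >>> torch.save(torch.ones(2), 'example.pt')
#   >>> zipfile.ZipFile('example.pt').namelist()  # doctest: +SKIP
#   ['example/data.pkl', 'example/byteorder', 'example/data/0', ...]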


def load(
    f: FILE_LIKE,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = False,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any
) -> Any:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`).
    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)

    Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize file)
        weights_only: Indicates whether unpickler should be restricted to
            loading only tensors, primitive types and dictionaries
        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.

    .. warning::
        Unless the :attr:`weights_only` parameter is set to `True`,
        :func:`torch.load()` implicitly uses the ``pickle`` module, which is
        known to be insecure. It is possible to construct malicious pickle
        data which will execute arbitrary code during unpickling. Never load
        data that could have come from an untrusted source in an unsafe mode,
        or that could have been tampered with. **Only load data you trust**.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``. This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3. If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.load('tensors.pt', weights_only=True)
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
        # Load tensor from io.BytesIO object
        # Loading from a buffer with weights_only=False; warning: this can be unsafe
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer, weights_only=False)
        # Load a module with 'ascii' encoding for unpickling
        # Loading a module with weights_only=False; warning: this can be unsafe
        >>> torch.load('module.pt', encoding='ascii', weights_only=False)
    """
    torch._C._log_api_usage_once("torch.load")
    UNSAFE_MESSAGE = (
        "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
        " will likely succeed, but it can result in arbitrary code execution."
        " Do it only if you get the file from a trusted source. WeightsUnpickler error: "
    )
    # Add ability to force safe only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
        weights_only = True

    if weights_only:
        if pickle_module is not None:
            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
    else:
        if pickle_module is None:
            pickle_module = pickle

    # make flipping default BC-compatible
    if mmap is None:
        mmap = False

    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                  " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    if not _is_path(f):
                        raise ValueError("f must be a file path in order to use the mmap argument")
                    size = os.path.getsize(f)
                    overall_storage = torch.UntypedStorage.from_file(os.fspath(f), False, size)
                if weights_only:
                    try:
                        return _load(opened_zipfile,
                                     map_location,
                                     _weights_only_unpickler,
                                     overall_storage=overall_storage,
                                     **pickle_load_args)
                    except RuntimeError as e:
                        raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
                return _load(opened_zipfile,
                             map_location,
                             pickle_module,
                             overall_storage=overall_storage,
                             **pickle_load_args)
        if mmap:
            f_name = "" if not isinstance(f, str) else f"{f}, "
            raise RuntimeError("mmap can only be used with files saved with "
                               f"`torch.save({f_name}_use_new_zipfile_serialization=True)`, "
                               "please torch.save your checkpoint with this option in order to use mmap.")
        if weights_only:
            try:
                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
            except RuntimeError as e:
                raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)


# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
    """Get layout extension object from its string representation.
    """
    cache = _get_layout.cache  # type: ignore[attr-defined]
    if not cache:
        for v in torch.__dict__.values():
            if isinstance(v, torch.layout):
                cache[str(v)] = v
    return cache[name]


# There is not yet a good way to type-annotate function attributes https://github.com/python/mypy/issues/2087
_get_layout.cache = {}  # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
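
# With the copyreg hook above, layout singletons round-trip through pickle by
# name (a small sanity sketch):
#
#   >>> import pickle
#   >>> pickle.loads(pickle.dumps(torch.sparse_coo)) is torch.sparse_coo
#   True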


def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]

        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise OSError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except OSError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype,
                        _internal=True)

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype,
                        _internal=True)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                if torch._guards.active_fake_mode() is not None:
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj,
                    dtype=dtype,
                    _internal=True)
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True)

            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype,
                        _internal=True)
                res = deserialized_objects[view_key]
            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    if torch._guards.active_fake_mode() is None:
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f, offset, f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype))
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result


def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # ascii encoding, one that Py3 uses by default.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    # `location` in `persistent_load` below)!
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str


def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, (str, bytes)):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result
    return restore_location
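
# A sketch of how the dict form above behaves (tags and devices here are
# illustrative): with map_location={'cuda:1': 'cuda:0'}, storages tagged
# 'cuda:1' are restored to 'cuda:0', while unmatched tags such as 'cpu'
# fall through unchanged:
#
#   >>> restore = _get_restore_location({'cuda:1': 'cuda:0'})
#   >>> # restore(storage, 'cuda:1') -> storage moved to 'cuda:0'
#   >>> # restore(storage, 'cpu')    -> storage left on CPU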


class StorageType:
    def __init__(self, name):
        self._dtype = _get_dtype_from_pickle_storage_type(name)

    # `dtype` is accessed as an attribute (e.g. `storage_type.dtype` in
    # `persistent_load`), so it must be a property.
    @property
    def dtype(self):
        return self._dtype

    def __str__(self):
        return f'StorageType(dtype={self.dtype})'


def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    # check if byteswapping is needed
    byteordername = 'byteorder'
    byteorderdata = None
    if zip_file.has_record(byteordername):
        byteorderdata = zip_file.get_record(byteordername)
        if byteorderdata not in [b'little', b'big']:
            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
            get_default_load_endianness() is None:
        byteorderdata = b'little'
    elif get_default_load_endianness() == LoadEndianness.BIG:
        byteorderdata = b'big'
    elif get_default_load_endianness() == LoadEndianness.NATIVE:
        pass
    else:
        raise ValueError('Invalid load endianness type')

    if not zip_file.has_record(byteordername) and \
            get_default_load_endianness() is None and \
            sys.byteorder == 'big':
        # Default behaviour was changed
        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
                      "on big endian machines was changed from 'native' to 'little' endian, "
                      "to avoid this behavior please use "
                      "torch.serialization.set_default_load_endianness to set "
                      "the desired default load endianness",
                      UserWarning)

    def load_tensor(dtype, numel, key, location):
        # note: `persistent_load` below passes a byte count for `numel`
        name = f'data/{key}'
        if torch._guards.detect_fake_mode(None) is not None:
            nbytes = numel * torch._utils._element_size(dtype)
            storage = torch.UntypedStorage(nbytes, device='meta')
        elif overall_storage is not None:
            storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset:storage_offset + numel]
        else:
            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
        # swap here if byteswapping is needed
        if byteorderdata is not None:
            if byteorderdata.decode() != sys.byteorder:
                storage.byteswap(dtype)

        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with TypedStorage
        typed_storage = torch.storage.TypedStorage(
            wrap_storage=restore_location(storage, location),
            dtype=dtype,
            _internal=True)

        if typed_storage._data_ptr() != 0:
            loaded_storages[key] = typed_storage

        return typed_storage

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, numel = data
        if storage_type is torch.UntypedStorage:
            dtype = torch.uint8
        else:
            dtype = storage_type.dtype

        if key in loaded_storages:
            typed_storage = loaded_storages[key]
        else:
            nbytes = numel * torch._utils._element_size(dtype)
            typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

        return typed_storage

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    torch._utils._validate_loaded_sparse_tensors()
    torch._C._log_api_usage_metadata(
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
    )
    return result


def _is_torchscript_zip(zip_file):
    return 'constants.pkl' in zip_file.get_all_records()