Spaces:
Sleeping
Sleeping
# Mypy will not try inferring the types of any 3rd party libraries installed. | |
# mypy: ignore-errors | |
import io | |
import os | |
from contextlib import contextmanager | |
from pathlib import Path | |
from typing import Generator, Optional, Union | |
import fsspec | |
from fsspec import AbstractFileSystem | |
from fsspec.core import url_to_fs | |
from torch.distributed.checkpoint.filesystem import ( | |
FileSystemBase, | |
FileSystemReader, | |
FileSystemWriter, | |
) | |
__all__ = [ | |
"FsspecWriter", | |
"FsspecReader", | |
] | |
class FileSystem(FileSystemBase): | |
def __init__(self) -> None: | |
self.fs: Optional[AbstractFileSystem] = None | |
def create_stream( | |
self, path: Union[str, os.PathLike], mode: str | |
) -> Generator[io.IOBase, None, None]: | |
assert self.fs is not None | |
with self.fs.transaction: | |
with fsspec.open(str(path), mode) as stream: | |
yield stream | |
def concat_path( | |
self, path: Union[str, os.PathLike], suffix: str | |
) -> Union[str, os.PathLike]: | |
return os.path.join(path, suffix) | |
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: | |
self.fs, _ = url_to_fs(path) | |
return path | |
def rename( | |
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] | |
) -> None: | |
self.fs.rename(path, new_path) | |
def mkdir(self, path: [str, os.PathLike]) -> None: | |
self.fs.makedirs(path, exist_ok=True) | |
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: | |
if isinstance(checkpoint_id, Path): | |
return False | |
try: | |
url_to_fs(checkpoint_id) | |
except ValueError as e: | |
return False | |
return True | |
class FsspecWriter(FileSystemWriter):
    """
    Basic implementation of StorageWriter using fsspec.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10_000_000 (10 MB).

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
        )
        # Assigned after the base initializer so the fsspec-backed filesystem
        # and the path it returns are the values left on the instance.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return whether `checkpoint_id` is accepted; delegates to `FileSystem.validate_checkpoint_id`."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)
class FsspecReader(FileSystemReader):
    """``FileSystemReader`` variant that accesses the checkpoint through fsspec."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """Initialize the reader pointing to ``path``.

        Args:
            path: location of the checkpoint to read.
        """
        super().__init__(path)
        # Assigned after the base initializer so the fsspec-backed filesystem
        # and the path it returns are the values left on the instance.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return whether ``checkpoint_id`` is accepted; delegates to ``FileSystem.validate_checkpoint_id``."""
        # FileSystem defines no ``check`` attribute; the validation entry point
        # is ``validate_checkpoint_id`` (the same call FsspecWriter makes).
        return FileSystem.validate_checkpoint_id(checkpoint_id)