Spaces:
Running
Running
File size: 2,051 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import contextlib
from typing import Optional, Sequence
import torch
from torch._custom_op.impl import custom_op
from torch.utils._content_store import ContentStoreReader
LOAD_TENSOR_READER: Optional[ContentStoreReader] = None
@contextlib.contextmanager
def load_tensor_reader(loc):
    """Install a global ContentStoreReader rooted at ``loc`` for the duration
    of the ``with`` block, so debugprims::load_tensor can resolve saved
    tensors by name.

    Nesting is not supported: the reader slot must be empty on entry, and it
    is always cleared again on exit (even if the body raises).

    NOTE: load_tensor is an "op", and we will play merry hell on Inductor's
    memory planning if we return a tensor that aliases another tensor we
    previously returned from an operator.  So unlike standard
    ContentStoreReader use, the storage cache is disabled so every read
    produces a fresh storage (no aliasing for you!).
    """
    global LOAD_TENSOR_READER
    assert LOAD_TENSOR_READER is None
    LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
    try:
        yield
    finally:
        # Always tear the reader back down so a later context manager
        # (or a failure in the body) doesn't leak the global.
        LOAD_TENSOR_READER = None
def register_debug_prims():
    """Register the ``debugprims::load_tensor`` custom op.

    The op returns a tensor of the requested name/size/stride/dtype/device.
    When a ``load_tensor_reader`` context is active, the tensor contents are
    read back from the content store; otherwise random data of the right
    layout is fabricated.  Called for its registration side effect only.
    """

    # Empty-body stub: custom_op derives the op's schema from this signature
    # and its annotations, so the parameter list here is load-bearing.
    @custom_op("debugprims::load_tensor")
    def load_tensor(  # type: ignore[empty-body]
        name: str,
        size: Sequence[int],
        stride: Sequence[int],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        ...

    # Single factory implementation used for all backends/keys.
    @load_tensor.impl_factory()
    def load_tensor_factory(name, size, stride, dtype, device):
        if LOAD_TENSOR_READER is None:
            # No reader installed: fabricate data with the requested layout.
            from torch._dynamo.testing import rand_strided

            return rand_strided(size, stride, dtype, device)
        else:
            from torch._dynamo.utils import clone_input

            # device argument here takes care of coercion
            r = LOAD_TENSOR_READER.read_tensor(name, device=device)
            # Layout/device must match the recorded metadata exactly.
            assert list(r.size()) == size, f"{r.size()} != {size}"
            assert list(r.stride()) == stride, f"{r.stride()} != {stride}"
            assert r.device == device, f"{r.device} != {device}"

            # Unlike the other properties, we will do coercions for dtype
            # mismatch
            if r.dtype != dtype:
                r = clone_input(r, dtype=dtype)
            return r
|