File size: 2,051 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import contextlib
from typing import Optional, Sequence

import torch
from torch._custom_op.impl import custom_op
from torch.utils._content_store import ContentStoreReader

# Reader installed by ``load_tensor_reader`` for the duration of its ``with``
# block and reset to ``None`` when the block exits.  The
# ``debugprims::load_tensor`` implementation checks it: when ``None`` it falls
# back to ``rand_strided``, otherwise it reads the named tensor from here.
LOAD_TENSOR_READER: Optional[ContentStoreReader] = None


@contextlib.contextmanager
def load_tensor_reader(loc):
    """Serve ``debugprims::load_tensor`` from the content store at ``loc``.

    For the duration of the ``with`` block the module-level
    ``LOAD_TENSOR_READER`` holds a ``ContentStoreReader`` for ``loc``.  It is
    reset to ``None`` on exit, even if the block raises.  Nesting is not
    supported: entering while a reader is already installed trips the
    assertion below (before anything is modified).
    """
    global LOAD_TENSOR_READER
    assert LOAD_TENSOR_READER is None

    # load_tensor behaves like an operator, and Inductor's memory planning
    # goes badly wrong if an operator returns a tensor aliasing one that an
    # earlier operator call already returned.  So, unlike typical
    # ContentStoreReader usage, caching is turned off: every read produces a
    # fresh storage and nothing is ever aliased.
    fresh_reader = ContentStoreReader(loc, cache=False)
    LOAD_TENSOR_READER = fresh_reader
    try:
        yield
    finally:
        # Always uninstall so a failure inside the block cannot leak the
        # reader into later, unrelated load_tensor calls.
        LOAD_TENSOR_READER = None


def register_debug_prims():
    """Define the ``debugprims::load_tensor`` custom op and its factory impl.

    ``load_tensor(name, size, stride, *, dtype, device)`` yields a tensor with
    the requested layout:

    * With no reader installed (``LOAD_TENSOR_READER is None``) it returns
      ``rand_strided(size, stride, dtype, device)``.
    * Otherwise it reads the tensor stored under ``name`` from the active
      reader on ``device``.  It asserts that size, stride and device match
      exactly, and casts (via ``clone_input``) when only the dtype differs.

    NOTE(review): this defines a new operator, so presumably it must run at
    most once per process -- confirm with callers.
    """

    @custom_op("debugprims::load_tensor")
    def load_tensor(  # type: ignore[empty-body]
        name: str,
        size: Sequence[int],
        stride: Sequence[int],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        # Schema-only stub; the real work happens in the impl_factory below.
        # NOTE(review): custom_op presumably derives the op schema from these
        # annotations, so treat them as part of the op's interface.
        ...

    @load_tensor.impl_factory()
    def load_tensor_factory(name, size, stride, dtype, device):
        if LOAD_TENSOR_READER is None:
            # No content store active: synthesize a tensor of the right layout.
            from torch._dynamo.testing import rand_strided

            return rand_strided(size, stride, dtype, device)

        from torch._dynamo.utils import clone_input

        # device argument here takes care of coercion
        r = LOAD_TENSOR_READER.read_tensor(name, device=device)
        # size/stride are declared as Sequence[int]; normalise both sides to
        # lists so a tuple argument does not spuriously fail the comparison.
        assert list(r.size()) == list(size), f"{r.size()} != {size}"
        assert list(r.stride()) == list(stride), f"{r.stride()} != {stride}"
        assert r.device == device, f"{r.device} != {device}"

        # Unlike the other properties, we will do coercions for dtype
        # mismatch
        if r.dtype != dtype:
            r = clone_input(r, dtype=dtype)
        return r