import types
from contextlib import contextmanager
# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
# test suite, where it's very easy to forget to undo the change
# later.
__allow_nonbracketed_mutation_flag = True
def disable_global_flags():
    global __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = False


def flags_frozen():
    return not __allow_nonbracketed_mutation_flag
@contextmanager
def __allow_nonbracketed_mutation():
    global __allow_nonbracketed_mutation_flag
    old = __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = True
    try:
        yield
    finally:
        __allow_nonbracketed_mutation_flag = old
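
# Illustrative sketch (not part of the original file): after the test
# harness calls disable_global_flags(), bare assignment to a guarded flag
# raises, while the bracketed context manager temporarily lifts the freeze
# and restores it on exit. The helper below is hypothetical and never
# called; it only demonstrates the intended flow.
def _example_freeze_and_bracketed_mutation():
    disable_global_flags()
    assert flags_frozen()
    with __allow_nonbracketed_mutation():
        # Inside the bracket, guarded setters may run again.
        assert not flags_frozen()
    # On exit, the frozen state is restored.
    assert flags_frozen()
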
class ContextProp:
    # Data descriptor that forwards attribute reads to `getter` and writes
    # to `setter`, refusing writes once the global flags have been frozen
    # by disable_global_flags().
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter()

    def __set__(self, obj, val):
        if not flags_frozen():
            self.setter(val)
        else:
            raise RuntimeError(
                "not allowed to set %s flags "
                "after disable_global_flags; please use flags() context manager instead"
                % obj.__name__
            )
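
# Illustrative sketch (hypothetical, not in the original file): ContextProp
# is a data descriptor, so it must live on a class. The fake backend below
# stores its flag in a dict to stand in for real global state; reads go
# through the getter and writes through the setter while flags are not
# frozen.
def _example_context_prop():
    state = {"enabled": True}

    class _FakeFlags:
        __name__ = "fake_backend"  # ContextProp's error message reads obj.__name__
        enabled = ContextProp(
            lambda: state["enabled"],
            lambda val: state.__setitem__("enabled", val),
        )

    flags = _FakeFlags()
    assert flags.enabled is True  # routed through the getter
    flags.enabled = False  # routed through the setter (raises if frozen)
    assert state["enabled"] is False
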
class PropModule(types.ModuleType):
    # Module subclass used to stand in for a backend module in sys.modules,
    # so that ContextProp descriptors on the class intercept attribute
    # access; anything not found on the class is delegated to the original
    # module `m`.
    def __init__(self, m, name):
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        return self.m.__getattribute__(attr)
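
# Illustrative sketch (an assumption about usage, not part of this file):
# backend submodules such as torch.backends.cudnn replace themselves in
# sys.modules with a PropModule subclass carrying ContextProp attributes,
# so that e.g. `torch.backends.cudnn.enabled` routes through getters and
# setters instead of plain module globals. The helper below is hypothetical
# and never called.
def _example_prop_module_swap():
    import sys

    class _FakeBackendModule(PropModule):
        # A trivial flag; a real backend would wrap C-level state here.
        enabled = ContextProp(lambda: True, lambda val: None)

    this_module = sys.modules[__name__]
    sys.modules[__name__] = _FakeBackendModule(this_module, __name__)
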
from torch.backends import (
    cpu as cpu,
    cuda as cuda,
    cudnn as cudnn,
    mha as mha,
    mkl as mkl,
    mkldnn as mkldnn,
    mps as mps,
    nnpack as nnpack,
    openmp as openmp,
    quantized as quantized,
)