Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from torch._subclasses.fake_tensor import _is_tensor_constructor | |
| from torch.utils._python_dispatch import TorchDispatchMode | |
class MetaTensorContext(TorchDispatchMode):
    """Dispatch mode that redirects tensor-constructor ops to the ``meta`` device.

    For every op for which ``_is_tensor_constructor`` returns True, the
    ``device`` argument is overwritten with ``'meta'`` before the op is
    invoked. All other ops are forwarded with their arguments unchanged.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Invoke ``func``, forcing constructor ops onto the ``meta`` device.

        Args:
            func: The operator overload being dispatched.
            types: Tensor subclass types involved in the call (unused here).
            args: Positional arguments for ``func``.
            kwargs: Keyword arguments for ``func``; ``None`` is treated as
                an empty mapping.

        Returns:
            Whatever ``func`` returns for the (possibly rewritten) arguments.
        """
        # Tolerate the declared ``None`` default and work on a copy so the
        # caller's dict is never mutated.
        kwargs = {} if kwargs is None else dict(kwargs)
        if _is_tensor_constructor(func):
            arg_names = [arg.name for arg in func._schema.arguments]
            # A constructor schema without a ``device`` argument cannot be
            # redirected; forward it unchanged instead of raising ValueError.
            if 'device' in arg_names:
                device_idx = arg_names.index('device')
                if len(args) > device_idx:
                    # ``device`` was supplied positionally.
                    args = list(args)
                    args[device_idx] = 'meta'
                else:
                    kwargs['device'] = 'meta'
        return func(*args, **kwargs)