Spaces:
Running
Running
import torch | |
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner | |
from torch.fx.passes.operator_support import OperatorSupport | |
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS | |
from torch.fx.passes.fake_tensor_prop import FakeTensorProp | |
from torch.utils import _pytree as pytree | |
import operator | |
class CudaGraphsSupport(OperatorSupport): | |
# TODO: why is submodules passed here | |
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: | |
if node.op not in CALLABLE_NODE_OPS: | |
return False | |
if node.target in [torch.ops.aten.embedding_dense_backward.default]: | |
return False | |
if node.target in [operator.getitem]: | |
return True | |
found_not_cuda = False | |
def meta_fk(meta): | |
return meta["val"] if "val" in meta else meta["fake_result"] | |
def find_not_cuda(t): | |
nonlocal found_not_cuda | |
if isinstance(t, torch.Tensor) and t.device.type != 'cuda': | |
found_not_cuda = True | |
for n in node.all_input_nodes: | |
pytree.tree_map_(find_not_cuda, meta_fk(n.meta)) | |
pytree.tree_map_(find_not_cuda, meta_fk(node.meta)) | |
# NB: factory function is accounted for because the result would be | |
# cpu or cuda | |
return not found_not_cuda | |
def partition_cudagraphs(gm, inputs):
    """Split ``gm`` into sub-GraphModules that can be validly run under CUDA graphs.

    A region qualifies only when every operation in it involves CUDA tensors
    exclusively. Device information is obtained by running fake-tensor
    propagation over ``gm`` with ``inputs``; ``CudaGraphsSupport`` then decides
    per node whether it may join a partition.

    Args:
        gm: The ``torch.fx.GraphModule`` to partition.
        inputs: Example inputs for ``gm``; unpacked positionally into the
            fake-tensor propagation.

    Returns:
        The graph module returned by ``CapabilityBasedPartitioner.fuse_partitions``.
    """
    # Record per-node values in node.meta; CudaGraphsSupport reads them to
    # determine tensor devices.
    FakeTensorProp(gm).propagate(*inputs)

    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data. Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(
        gm,
        CudaGraphsSupport(),
        allows_single_node_partition=True,
    )
    return partitioner.fuse_partitions(partitioner.propose_partitions())