# mypy: ignore-errors

import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional
from unittest import mock

import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node

# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")


def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)
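
# For example (illustrative only), a nested input structure is summarized as:
#   args_str((torch.randn(2, 3), [torch.randn(4)]))
#   -> "tuple(T[torch.Size([2, 3])], list(T[torch.Size([4])]))"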


@dataclass
class Bucket:
    size: int = 0
    params: List[str] = field(default_factory=list)
    nodes: List[fx.Node] = field(default_factory=list)

    # param_ids is just used for unit testing
    param_ids: List = field(default_factory=list)

    # keep track of any buckets that were extended for logging purposes
    opcount_increased_to_capture_external_output: int = 0
    paramsize_before_opcount_increase: int = 0
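
    # A bucket is filled during the reverse graph walk in DDPOptimizer.compile_fn
    # below; roughly (a sketch of the existing accumulation logic, not new code):
    #   b = Bucket()
    #   b.size += param.untyped_storage().nbytes()
    #   b.params.append(f"{node.target}_{name}")
    #   b.nodes.append(node)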


def bucket_has_external_output(bucket: Bucket) -> bool:
    nodes_in_bucket = set()
    # we want to iterate in reverse order, and conveniently the bucket.nodes list
    # was already created in reverse order, so we don't reverse it here
    for node in bucket.nodes:
        # assume node.op != output, since those are filtered in the original iteration
        nodes_in_bucket.add(node)
        for user in node.users:
            if user not in nodes_in_bucket:
                return True
    return False
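
# For example, a candidate bucket consisting only of nodes whose results have no
# users outside the bucket (e.g. in-place mutations producing no logical outputs)
# has no external output, so the caller keeps extending it rather than closing it
# (see the warning emitted by pretty_print_buckets below).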


def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
    headers = ("Index", "Size (b)", "Param Names")
    rows = []
    extended_buckets = []
    for idx, bucket in enumerate(reversed(buckets)):
        if len(bucket.params) > 0:
            rows.append((idx, bucket.size, bucket.params[0]))
            for param in bucket.params[1:]:
                rows.append((None, None, param))
        if bucket.opcount_increased_to_capture_external_output > 0:
            extended_buckets.append(
                (
                    idx,
                    bucket.opcount_increased_to_capture_external_output,
                    bucket.size - bucket.paramsize_before_opcount_increase,
                )
            )

    if len(rows):
        log.info(
            "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
            bucket_bytes_cap,
            len(buckets),
        )

        if len(extended_buckets):
            log.warning(
                "Some buckets were extended beyond their requested parameter capacities"
                " in order to ensure each subgraph has an output node, required for fx graph partitioning."
                " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
                " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
                " partitions for optimal DDP performance."
            )

        try:
            from tabulate import tabulate

            log.debug(
                "\nDDPOptimizer produced the following bucket assignments:\n%s",
                tabulate(rows, headers=headers, tablefmt="simple_grid"),
            )

            if len(extended_buckets):
                log.warning(
                    "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
                    tabulate(
                        extended_buckets,
                        headers=("Index", "Extra Ops", "Extra Param Size (b)"),
                        tablefmt="simple_grid",
                    ),
                )
        except ImportError:
            log.debug(
                "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
            )
    else:
        log.debug("DDPOptimizer captured no parameters and did not split this graph.")


def has_higher_order_op(gm):
    # Check if there is a higher order op in the graph
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            maybe_param = getattr(gm, node.target)
            if isinstance(maybe_param, torch.fx.GraphModule):
                return True
    return False
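
# For example, control-flow higher order ops such as torch.cond store their traced
# branches as nested GraphModules on the parent module and reference them via
# get_attr nodes, which is the pattern this check detects.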


# 3 (lazy compile): Replace submodules with lazily compiling submodule
class SubmoduleReplacer(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler):
        super().__init__(module)
        self.compiler = compiler

    def lazily_compiled_submod(self, input_mod):
        """

        Create a wrapper around submodules which:

        - lazily compiles each of the partitioned submodules using the user-provided compiler

        - unpacks singleton tuples/lists into flat arg

        """

        class LazilyCompiledModule(torch.nn.Module):
            def __init__(self, submod, compiler, unwrap_singleton_tuple):
                super().__init__()
                self.submod = submod
                self.compiler = compiler
                self.compiled = False
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args):
                if not self.compiled:
                    # First compile with args as example_inputs
                    # These args will be fakeified if using Inductor/AOTAutograd
                    new_submod = self.compiler(self.submod, args)
                    del self.submod
                    self.submod = new_submod
                    self.compiled = True
                    self.compiler = None

                x = self.submod(*args)
                # we must let 'input_mod' return a tuple, to make AOT happy.
                # (aot_autograd compile_fn literally requires that the output of a graph it compiles is a tuple).
                # however, we don't actually want this tuple to be returned, since the fx logic that calls the submod
                # will again wrap outputs from the submod in a tuple.  So we unwrap it, and count on it being re-wrapped
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, DDPOptimizer),
            ],
        )
        wrapper = LazilyCompiledModule(
            input_mod,
            self.compiler,
            unwrap_singleton_tuple,
        )
        return wrapper
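
    # Example of the output-node rewrite above (describing existing behavior): a
    # partitioned submodule whose output node returned a single tensor `t` is
    # rewritten so its graph returns `(t,)`, since AOT-based compilers require a
    # tuple output; at runtime the wrapper unwraps the singleton tuple again so
    # the outer split graph sees the original value.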

    # We replace the submodules with lazy submodules which compile
    # the corresponding submodules when they are run with real values.
    # Always returns `None`; we do not need to propagate values in order
    # to replace submodules.
    def run_node(self, n: Node) -> Any:
        if n.op == "call_module":
            real_mod = self.fetch_attr(n.target)

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, real_mod.graph)

            assert len(n.kwargs) == 0, "We assume only args for these modules"
            lazily_compiled_submod = self.lazily_compiled_submod(real_mod)

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
            self.module.delete_submodule(n.target)
            n.target = "compiled_" + n.target
            self.module.add_submodule(n.target, lazily_compiled_submod)


# 3 (no lazy compile): compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler, fake_mode):
        super().__init__(module)
        self.compiler = compiler
        self.fake_mode = fake_mode

    def compile_submod(self, input_mod, args, kwargs):
        """

        Compile the submodule,

        using a wrapper to make sure its output is always a tuple,

        which is required by AotAutograd based compilers

        """
        assert len(kwargs) == 0, "We assume only args for these modules"

        class WrapperModule(torch.nn.Module):
            def __init__(self, submod, unwrap_singleton_tuple):
                super().__init__()
                self.submod = submod
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args):
                x = self.submod(*args)
                # TODO(whc)
                # for some reason the isinstance check is necessary if I split one node per submod
                # - even though I supposedly wrapped the output in a tuple in those cases, the real
                # compiled module was still returning a tensor
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, DDPOptimizer),
            ],
        )

        wrapper = WrapperModule(
            self.compiler(input_mod, args),
            unwrap_singleton_tuple,
        )
        return wrapper

    # Note:
    #
    # The way distributed works today around fake tensors can be somewhat confusing.
    # Some of these codepaths are shared between runtime and compile time. The presence
    # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
    #
    # A few things to keep in mind:
    #
    # 1) We invoke `compile_submod` with a real module. The output of that gets stored
    # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
    #
    # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
    # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
    #
    # 3) Fake tensors should always be around during compile time.
    #
    # 4) Fake tensors should never be around at runtime.
    #
    # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
    # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        new_args = []
        assert self.fake_mode
        for arg in args:
            if isinstance(arg, torch.Tensor) and not isinstance(
                arg, torch._subclasses.FakeTensor
            ):
                new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
            else:
                new_args.append(arg)

        log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        if n.op == "call_module":
            real_mod = self.fetch_attr(n.target)
            if self.fake_mode:
                curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
            else:
                curr_submod = real_mod

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)

            # When calling the compiler on the submod, inputs (new_args) are expected to
            # be FakeTensors already since Dynamo would have made them FakeTensors in the
            # non-DDP flow.  However, the parameters are _not_ expected to be FakeTensors,
            # since this wrapping happens during compilation

            # Note: Returning Fake Tensors on First AOT Autograd Call
            #
            # Inductor will optimize strides of outputs when it deems it profitable.
            # For instance, converting to channels last. When we split the graph here
            # into multiple inductor compilations, we need to make sure that the
            # output strides of one compilation is appropriately passed to the subsequent
            # compilations. However, the mapping from inductor output to dynamo output
            # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
            # subclass handling, etc. In order to replay all this logic we set a flag such that
            # the first invocation of inductor in aot_autograd will return Fake Tensors with
            # appropriate strides. Then, all of aot autograd's runtime logic is replayed.
            # This gives us the appropriately strided outputs here which will reflect runtime strides.

            class FakeifyFirstAOTInvocationGuard:
                def __init__(self):
                    self.tc = torch._guards.TracingContext.try_get()
                    assert self.tc
                    torch._guards.TracingContext.try_get().fakify_first_call = True

                def __del__(self):
                    self.tc.fakify_first_call = False

            # For aot_eager and other backends, tracing context is not set
            has_tracing_context = torch._guards.TracingContext.try_get() is not None
            if has_tracing_context:
                g = FakeifyFirstAOTInvocationGuard()

            from torch._dynamo.utils import counters

            init = counters["aot_autograd"]["total"]
            compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)

            # TODO - better way of doing this?
            # Only aot autograd handles fakifying first call
            invoked_aot_autograd = init != counters["aot_autograd"]["total"]

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
            self.module.delete_submodule(n.target)
            n.target = "compiled_" + n.target
            self.module.add_submodule(n.target, compiled_submod_real)

            # Finally, we have to produce inputs to use when compiling the next submodule,
            # and these need to be FakeTensors, so we execute the module under fake_mode.
            # Because parameters are not fake, we patch the fake tensor mode to allow non-fake inputs.
            with self.fake_mode, mock.patch.object(
                self.fake_mode, "allow_non_fake_inputs", True
            ):
                if has_tracing_context and invoked_aot_autograd:
                    out = compiled_submod_real(*new_args, **kwargs)
                    # output should be fake or subclass
                    assert all(
                        (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
                        for t in (out if isinstance(out, (list, tuple)) else [out])
                    )
                    return out
                else:
                    return curr_submod(*new_args, **kwargs)
        else:
            # placeholder or output nodes don't need to get compiled, just executed
            return getattr(self, n.op)(n.target, new_args, kwargs)


class DDPOptimizer:

    """Note [DDPOptimizer]

    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),

    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to

    the boundaries of gradient-allreduce buckets chosen by DDP.



    Background/Motivation

     - DDP uses allreduce collectives to synchronize partial gradients computed on different workers

     - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce

     - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready

       at around the same time during backward and thus can share the same allreduce efficiently

     - Allreduces must overlap with backward compute for optimal training performance

     - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which

       operates when individual grads become 'ready'

     - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the

       autograd engine, such that all gradients become 'ready' at the same time.  Hooks fire after the whole

       fused backward function executes, preventing any overlap of compute and communication



    Algorithm

     - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward.  It can traverse

       this graph in reverse order to determine the true order that gradients will become ready during backward.

     - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started

       and a graph break introduced

     - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together

       into an outer module that is returned to the user



    Notes

     - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,

       and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does

       in eager.

     - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently

       produce splits that do not necessarily align with the buckets used by DDP.  This should result in performance

       degradation approaching the baseline case where graph-splits are not used, but not worse.

     - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the

       subgraphs being compiled

     - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers

       left by DDP on individual parameters.  In cases where other transformations, such as reparameterization, are

       also used, the ignore markers could be lost.  If DDPOptimizer fails to ignore a parameter ignored by DDP,

       it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.

     - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,

       and therefore aren't allreduced by DDP.  (They are broadcast during forward, but this is not covered by

       DDPOptimizer)



    Debugging

     - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.

     - In many cases, the log messages are helpful (they show bucket size assignments)-

       just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'.

     - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model

       in a single process (or with torchrun, in multiple processes)



    Args:

        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks.  Should be

            set to match the equivalent parameter on the original DDP module.



        backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.



        first_bucket_cap (int): Controls the size of the first bucket.  Should match DDP's first bucket cap.  DDP

            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.



    """

    def __init__(
        self,
        bucket_bytes_cap: int,
        backend_compile_fn,
        first_bucket_cap: Optional[int] = None,
    ):
        if first_bucket_cap is not None:
            self.first_bucket_cap = first_bucket_cap
        elif torch.distributed.is_available():
            # this constant comes from C10D lib which is not always built
            self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
        else:
            self.first_bucket_cap = bucket_bytes_cap

        self.bucket_bytes_cap = bucket_bytes_cap
        assert (
            self.first_bucket_cap <= self.bucket_bytes_cap
        ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"

        self.backend_compile_fn = backend_compile_fn

    def _ignore_parameter(self, parameter):
        return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """

        Implements graph splitting, first determining a set of of buckets by counting

        parameter sizes in reverse graph order, then invoking the user/backend compiler

        to compile each subgraph. Finally, stiches compiled graphs into one graphmodule

        and returns its callable.

        """
        if has_higher_order_op(gm):
            # This indicates presence of a higher order op. For now, we
            # have no way to break the higher order op into two buckets.
            # Allowing higher order ops in the graph also requires
            # changes in split_module, because the graph splitter
            # currently assumes that all the args of all ops are
            # tensors, but in the case of higher order ops, an arg could
            # be a graph module. As a workaround, we short-circuit and raise.
            raise NotImplementedError(
                "DDPOptimizer backend: Found a higher order op in the graph. "
                "This is not supported. Please turn off DDP optimizer using "
                "torch._dynamo.config.optimize_ddp=False. Note that this can "
                "cause performance degradation because there will be one bucket "
                "for the entire Dynamo graph. Please refer to this issue - "
                "https://github.com/pytorch/pytorch/issues/104674."
            )

        # 1: compute the partition map according to DDP bucket logic
        buckets = [Bucket()]  # (size, param_names)
        for node in reversed(gm.graph.nodes):
            if node.op in ("output", "placeholder"):
                continue

            if (
                buckets[0].size >= self.bucket_bytes_cap
                or len(buckets) == 1
                and buckets[0].size >= self.first_bucket_cap
            ):
                if bucket_has_external_output(buckets[0]):
                    buckets.insert(0, Bucket())
                else:
                    # continue building this bucket past the point of filling its parameter capacity,
                    # to increase chances it contains at least one node that is either a global output or
                    # passed as input to a subsequent graph

                    if buckets[0].opcount_increased_to_capture_external_output == 0:
                        buckets[0].paramsize_before_opcount_increase = buckets[0].size
                    buckets[0].opcount_increased_to_capture_external_output += 1

            if node.op == "call_module":
                target = gm.get_submodule(node.target)
                for name, param in target.named_parameters():
                    if param.requires_grad and not self._ignore_parameter(param):
                        buckets[0].size += param.untyped_storage().nbytes()
                        buckets[0].params.append(f"{node.target}_{name}")
                        buckets[0].param_ids.append(id(param))
            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if maybe_param.requires_grad and not self._ignore_parameter(
                    maybe_param
                ):
                    buckets[0].size += maybe_param.untyped_storage().nbytes()
                    buckets[0].params.append(node.target)
                    buckets[0].param_ids.append(id(maybe_param))

            # All nodes have to be mapped to a bucket, even if they don't have their own params.
            # Ignored params still end up in buckets; we just don't count them towards the capacity.
            buckets[0].nodes.append(node)

        if len(buckets) > 1 and buckets[0].size == 0:
            # we collected a small preamble graph with ops that don't include parameters, so fuse it back
            buckets[1].nodes.extend(buckets[0].nodes)
            assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
            del buckets[0]

        # stash buckets for testing/debugging purposes
        self.buckets = buckets
        pretty_print_buckets(buckets, self.bucket_bytes_cap)

        if len(buckets) == 1:
            # bypass split/fuse logic if there is only one bucket
            return self.backend_compile_fn(gm, example_inputs)

        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for idx, b in enumerate(buckets):
            for node in b.nodes:
                partition_map[node] = idx

        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )

        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                # only print the submod graphs, not their children
                debug_str += f"\n---{name} graph---\n{module.graph}\n"
        debug_str += "\n---------------\n"
        ddp_graph_log.debug(debug_str)

        trace_structured(
            "optimize_ddp_split_graph",
            payload_fn=lambda: split_gm.print_readable(print_output=False),
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                trace_structured(
                    "optimize_ddp_split_child",
                    lambda: {"name": name},
                    payload_fn=lambda: module.print_readable(print_output=False),
                )

        # NOTE: we want to enable `optimize_ddp_lazy_compile` by default as soon as possible,
        # because it will fix stride mismatch errors (see motivation: https://github.com/pytorch/pytorch/pull/114154).
        # However, lazy compile currently causes shape mismatch in other cases (`test_graph_split_inductor_transpose`)
        # and we need to fix them before we can enable it by default.
        if not torch._dynamo.config.optimize_ddp_lazy_compile:
            # Today, optimize_ddp=True and keep_output_stride=False can lead to silent
            # correctness issues. The problem is that ddp_optimizer works by partitioning
            # the dynamo graph, sending each subgraph through aot autograd to inductor,
            # and creates example inputs by eagerly interpreting each subgraph to get
            # an output with the same metadata that we'd get from eager mode.
            # This is a problem though, for torch._inductor.config.keep_output_stride.
            # The above config can cause the outputs of the first graph to have
            # **different** strides from eager, causing the inputs that we pass
            # to the second graph to be wrong.
            # To really fix this, we would need to faithfully ask inductor
            # what the outputs to each graph it expects are.
            fake_mode = detect_fake_mode(example_inputs)
            if fake_mode is None:
                fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

        if torch._dynamo.config.optimize_ddp_lazy_compile:
            submod_compiler = SubmoduleReplacer(split_gm, self.backend_compile_fn)
        else:
            submod_compiler = SubmodCompiler(
                split_gm, self.backend_compile_fn, fake_mode
            )
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        ddp_graph_log.debug(
            "\n---final graph---\n%s\n---------------\n", split_gm.graph
        )
        return split_gm
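

# Illustrative sketch (not part of DDPOptimizer): step 2 above relies on
# torch.fx.passes.split_module to carve a traced module into per-bucket
# submodules from a node -> partition-id mapping. `Toy` and `ids` below are
# hypothetical:
#
#   import torch
#   import torch.fx as fx
#   from torch.fx.passes.split_module import split_module
#
#   class Toy(torch.nn.Module):
#       def forward(self, x):
#           return (x + 1) * 2 - 3
#
#   gm = fx.symbolic_trace(Toy())
#   ids = {n: (0 if i < 3 else 1) for i, n in enumerate(gm.graph.nodes)}
#   split = split_module(gm, None, lambda n: ids[n])
#   # `split` now contains submod_0 and submod_1, plus an outer graph that calls
#   # them in order, mirroring split_gm above.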