Spaces:
Sleeping
Sleeping
import torch.fx as fx | |
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Set a breakpoint in ``gm``'s generated python code.

    A code transformer that prepends ``import pdb; pdb.set_trace()`` to the
    generated body lines is registered on ``gm.graph`` while ``gm`` is
    recompiled, so the regenerated code drops into pdb whenever ``gm`` is run.
    If a transformer was already registered, it is applied first and the
    breakpoint line is prepended to its output.

    Args:
        gm: graph module to insert the breakpoint into. It is recompiled in
            place for the breakpoint to take effect.

    Returns:
        The same ``gm`` object, with the breakpoint inserted.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def make_transformer(previous):
        # ``previous`` is the transformer currently registered on the graph,
        # or None if there is none; chain onto it rather than replacing it.
        def transform(body):
            lines = previous(body) if previous is not None else body
            return [breakpoint_line, *lines]

        return transform

    # Recompile while the transformer is registered so the breakpoint is
    # baked into gm's generated code.
    with gm.graph.on_generate_code(make_transformer=make_transformer):
        gm.recompile()

    return gm