Spaces:
Sleeping
Sleeping
| import torch.fx as fx | |
| def set_trace(gm: fx.GraphModule) -> fx.GraphModule: | |
| """ | |
| Sets a breakpoint in `gm`'s generated python code. It drops into pdb when | |
| `gm` gets run. | |
| Args: | |
| gm: graph module to insert breakpoint. It is then recompiled for it to | |
| take effect. | |
| Returns: | |
| the `gm` with breakpoint inserted. | |
| """ | |
| def insert_pdb(body): | |
| return ["import pdb; pdb.set_trace()\n", *body] | |
| with gm.graph.on_generate_code( | |
| make_transformer=lambda cur_transform: ( | |
| # new code transformer to register | |
| lambda body: ( | |
| insert_pdb( | |
| cur_transform(body) if cur_transform | |
| else body | |
| ) | |
| ) | |
| ) | |
| ): | |
| gm.recompile() | |
| return gm | |