| from typing import Callable, Dict, Union | |
| from torch import Tensor | |
# A tree is either a leaf Tensor or a dict mapping string keys to sub-trees.
Tree = Union[Dict[str, "Tree"], Tensor]


def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree:
    """Merge identically-structured nested dictionaries of tensors leaf by leaf.

    Every tree must have the same nested key structure. At each leaf position,
    the tensors found there in all trees are passed together, in input order,
    to ``merge_fn``. Its result is placed at that position in the output tree.

    Args:
        trees: Non-empty list of trees that all share the same structure.
        merge_fn: Combines the list of leaf tensors at one position into a
            single tensor (for example ``torch.stack`` or ``torch.cat``).

    Returns:
        A tree with the same structure as the inputs, whose leaves are the
        outputs of ``merge_fn``. Dict key order follows ``trees[0]``.

    Raises:
        ValueError: If ``trees`` is empty.
        TypeError: If some trees have a leaf where others have a dict.
        KeyError: If the trees do not all have the same keys at some level.
    """
    if not trees:
        # With no trees there is no structure to follow.
        raise ValueError("collate() requires at least one tree")

    first = trees[0]
    if isinstance(first, Tensor):
        # Leaf level: every tree must also be a tensor here. Otherwise the
        # structures diverge and merge_fn would be handed a dict.
        for index, tree in enumerate(trees):
            if not isinstance(tree, Tensor):
                raise TypeError(
                    f"tree {index} is a {type(tree).__name__} "
                    "where a Tensor was expected"
                )
        return merge_fn(trees)

    # Interior level: insist on identical key sets. Otherwise keys present
    # only in later trees would be silently dropped.
    expected_keys = set(first)
    for index, tree in enumerate(trees):
        if isinstance(tree, Tensor):
            raise TypeError(f"tree {index} is a Tensor where a dict was expected")
        keys = set(tree)
        if keys != expected_keys:
            missing = sorted(expected_keys - keys, key=str)
            extra = sorted(keys - expected_keys, key=str)
            raise KeyError(
                f"tree {index} keys differ from tree 0: "
                f"missing {missing}, unexpected {extra}"
            )

    # Recurse per key, preserving the key order of the first tree.
    return {key: collate([tree[key] for tree in trees], merge_fn) for key in first}