import torch
class TensorMerger:
    """Callable that combines two tensors using a strategy selected by name.

    The only strategy implemented is ``'concat'``, which joins the two
    tensors along dimension 1.
    """

    def __init__(self, merger_type: str) -> None:
        """Remember which merge strategy to apply.

        Args:
            merger_type: Name of the strategy. It is stored without
                validation; an unsupported name is only reported when the
                instance is called.
        """
        self.merger_type = merger_type

    def concat(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` and ``y`` concatenated along dimension 1.

        Args:
            x: First tensor; its entries come first along dimension 1.
            y: Second tensor, appended after ``x`` along dimension 1.

        Returns:
            The result of ``torch.cat([x, y], dim=1)``.
        """
        return torch.cat([x, y], dim=1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Merge ``x`` and ``y`` with the configured strategy.

        Args:
            x: First tensor to merge.
            y: Second tensor to merge.

        Returns:
            The merged tensor.

        Raises:
            ValueError: If ``merger_type`` is anything other than
                ``'concat'``.
        """
        if self.merger_type == 'concat':
            return self.concat(x, y)
        # Any other strategy name is unsupported.
        raise ValueError(f'Unknown merger type: {self.merger_type}')