import torch
class TensorMerger:
    """Callable that merges two tensors using a strategy chosen by name.

    The only strategy implemented here is ``'concat'``. The strategy name
    is stored as given and is only checked when the instance is called.
    """

    def __init__(self, merger_type: str) -> None:
        """Store the name of the merge strategy.

        Args:
            merger_type: Strategy name; ``'concat'`` is the only value
                accepted by ``__call__``.
        """
        self.merger_type = merger_type

    def concat(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` and ``y`` concatenated along dimension 1.

        Both inputs must satisfy the requirements of ``torch.cat`` for
        ``dim=1``; this method performs no checks of its own.
        """
        return torch.cat([x, y], dim=1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Merge ``x`` and ``y`` with the configured strategy.

        Raises:
            ValueError: If ``merger_type`` is not a supported strategy.
        """
        # Guard clause: reject unsupported strategies before doing any work.
        if self.merger_type != 'concat':
            raise ValueError(f'Unknown merger type: {self.merger_type}')
        return self.concat(x, y)