File size: 459 Bytes
3a2aa34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch


class TensorMerger:
    """Callable that merges two tensors using a strategy chosen by name.

    The strategy name is stored as given at construction time and is only
    interpreted when the instance is called. ``'concat'`` is the sole
    strategy this class recognises.
    """

    def __init__(self, merger_type) -> None:
        # Stored without validation; an unrecognised name fails on call.
        self.merger_type = merger_type

    def concat(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` and ``y`` concatenated along dimension 1."""
        return torch.cat((x, y), dim=1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor):
        """Merge ``x`` and ``y`` with the configured strategy.

        Raises:
            ValueError: If ``merger_type`` is not ``'concat'``.
        """
        strategy = self.merger_type
        if strategy == 'concat':
            return self.concat(x, y)
        raise ValueError(f'Unknown merger type: {strategy}')