# pal-b-large-opt-350m / tensor_merger.py
# daiweichen's picture
# Upload PAL_B_RM_opt
# 3a2aa34 verified
import torch
class TensorMerger:
    """Combine two tensors using a strategy selected by name.

    The strategy name is stored at construction and looked up on every
    call. ``'concat'`` is the only strategy implemented here.
    """

    def __init__(self, merger_type: str) -> None:
        """Remember which merge strategy to apply.

        Args:
            merger_type: Name of the strategy. It is not validated here;
                an unsupported name is only rejected when the instance
                is called.
        """
        self.merger_type = merger_type

    def concat(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` and ``y`` joined along dimension 1.

        Shape compatibility is left to ``torch.cat``; nothing is checked
        here.
        """
        return torch.cat([x, y], dim=1)

    def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Merge ``x`` and ``y`` with the configured strategy.

        Returns:
            The merged tensor.

        Raises:
            ValueError: If ``merger_type`` is not a supported strategy.
        """
        if self.merger_type == 'concat':
            return self.concat(x, y)
        # Any other name falls through to the error; the attribute is read
        # at call time, so a value assigned after construction is honoured.
        raise ValueError(f'Unknown merger type: {self.merger_type}')