File size: 340 Bytes
e0202f8 |
1 2 3 4 5 6 7 8 |
import torch
def center_trim(tensor, reference):
    """Trim ``tensor`` symmetrically along its last dimension.

    Args:
        tensor: Tensor to trim. Only its last dimension is shortened.
        reference: Either a ``torch.Tensor``, whose last-dimension size is
            used as the target length, or the target length itself.

    Returns:
        ``tensor`` itself when its last dimension already has the target
        length, otherwise a slice of it with the target length. When the
        excess is odd, the extra element is removed from the right side.

    Raises:
        ValueError: If the last dimension of ``tensor`` is shorter than the
            target length.
    """
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference

    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError(
            f"tensor must be at least as long as reference; delta is {delta}."
        )
    if delta:
        # Floor half goes to the left; the remainder (one more when delta is
        # odd) is cut from the right. The zero case is skipped above because
        # a stop index of -0 would produce an empty slice.
        left = delta // 2
        tensor = tensor[..., left:-(delta - left)]
    return tensor