File size: 348 Bytes
1e4a2ab
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import torch

def center_trim(tensor: torch.Tensor, reference: "torch.Tensor | int") -> torch.Tensor:
    """Trim the last dimension of *tensor* around its center to a target length.

    Args:
        tensor: Tensor whose last dimension is trimmed.
        reference: Either a tensor, whose last-dimension size is used as the
            target length, or the target length itself.

    Returns:
        ``tensor`` itself when its last dimension already has the target
        length; otherwise a slice of ``tensor`` along the last dimension with
        the target length. When the number of elements to drop is odd, the
        extra element is removed from the end.

    Raises:
        ValueError: If the last dimension of ``tensor`` is shorter than the
            target length.
    """
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference

    length = tensor.size(-1)
    delta = length - ref_size

    if delta < 0:
        raise ValueError(
            "tensor must be at least as long as reference along the last "
            f"dimension (got {length} < {ref_size}, delta = {delta})."
        )

    if delta == 0:
        # Nothing to trim: hand back the very same object, not a new view.
        return tensor

    # Drop floor(delta / 2) elements on the left and the remainder on the
    # right, so the right side loses one more element when delta is odd.
    start = delta // 2
    return tensor[..., start : start + ref_size]