File size: 1,903 Bytes
eb339cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
from mld.utils.temos_utils import lengths_to_mask
def collate_tensors(batch: list) -> torch.Tensor:
    """Zero-pad a list of tensors to a common shape and stack them.

    The number of dimensions is taken from ``batch[0]``. Each padded
    dimension is the maximum size of that dimension over the whole batch.
    Every input is added into the leading corner of its own slot, and the
    remainder of the slot stays zero.

    Args:
        batch: non-empty list of tensors, each with at least
            ``batch[0].dim()`` dimensions.

    Returns:
        Tensor of shape ``(len(batch), *max_sizes)`` created with
        ``batch[0].new_zeros`` (so it shares ``batch[0]``'s dtype/device).
    """
    first = batch[0]
    ndim = first.dim()
    padded_shape = [max(t.size(axis) for t in batch) for axis in range(ndim)]
    out = first.new_zeros(size=(len(batch), *padded_shape))
    for row, t in enumerate(batch):
        # Integer + slice indexing is basic indexing and returns a view,
        # so the in-place add writes straight into `out`.
        region = (row, *(slice(0, t.size(axis)) for axis in range(ndim)))
        out[region].add_(t)
    return out
def mld_collate(batch: list) -> dict:
    """Collate text/motion samples into a dict of padded batch fields.

    ``None`` samples are dropped. The rest are ordered by their ``text_len``
    field (index 3), largest first. Tensor fields are zero-padded with
    ``collate_tensors``; the other fields are gathered into plain lists.

    Sample index -> output key, as read below:
        0 -> "word_embs", 1 -> "pos_ohot", 2 -> "text", 3 -> "text_len",
        4 -> "motion", 5 -> "length", 6 -> "tokens",
        [-1][0] -> "hint", [-1][1] -> "hint_mask" (only when present).

    Returns:
        dict with the keys above plus "mask", which is built by
        ``lengths_to_mask`` from "length", the motion tensor's device and
        ``motion.shape[1]``. NOTE(review): presumably a padding mask over
        the frame axis; confirm against ``lengths_to_mask``.
    """
    # Filter out missing samples and sort (stable) by text_len, descending.
    samples = sorted(
        (sample for sample in batch if sample is not None),
        key=lambda sample: sample[3],
        reverse=True,
    )

    def _float_batch(pick):
        # Turn the picked field of every sample into a float tensor, then pad.
        return collate_tensors([torch.tensor(pick(s)).float() for s in samples])

    adapted_batch = {}
    adapted_batch["motion"] = _float_batch(lambda s: s[4])
    adapted_batch["text"] = [s[2] for s in samples]
    adapted_batch["length"] = [s[5] for s in samples]
    adapted_batch["word_embs"] = _float_batch(lambda s: s[0])
    adapted_batch["pos_ohot"] = _float_batch(lambda s: s[1])
    # text_len keeps the dtype torch.tensor infers (no float cast).
    adapted_batch["text_len"] = collate_tensors(
        [torch.tensor(s[3]) for s in samples]
    )
    adapted_batch["tokens"] = [s[6] for s in samples]

    motion = adapted_batch["motion"]
    adapted_batch["mask"] = lengths_to_mask(
        adapted_batch["length"], motion.device, motion.shape[1]
    )

    # Trajectory hints: only the first sample decides whether they exist.
    # NOTE(review): `s[-1]` is the hint pair only if samples have 8+
    # elements; with exactly 7, `s[-1]` is the tokens field — confirm.
    if samples[0][-1][0] is not None:
        adapted_batch["hint"] = _float_batch(lambda s: s[-1][0])
        adapted_batch["hint_mask"] = _float_batch(lambda s: s[-1][1])
    return adapted_batch
def mld_collate_motion_only(batch: list) -> dict:
    """Collate motion-only samples into a zero-padded batch dict.

    Element 0 of each sample is converted to a float tensor and padded
    with ``collate_tensors`` into "motion". Element 1 is gathered unchanged
    into the "length" list. Unlike ``mld_collate``, this does not filter
    ``None`` samples and does not reorder them.
    """
    motion = collate_tensors(
        [torch.tensor(sample[0]).float() for sample in batch]
    )
    return {"motion": motion, "length": [sample[1] for sample in batch]}
|