File size: 1,903 Bytes
eb339cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch

from mld.utils.temos_utils import lengths_to_mask


def collate_tensors(batch: list) -> torch.Tensor:
    """Zero-pad a list of tensors to a common shape and stack them.

    The result has a leading batch axis of size ``len(batch)`` followed by,
    for every axis of the first tensor, the largest extent found on that axis
    across the whole batch. Each tensor is written into the leading corner of
    its own slot; the remainder of the slot stays zero.

    Args:
        batch: Non-empty list of tensors. The number of axes is read from
            ``batch[0]`` and the output is allocated with
            ``batch[0].new_zeros``, so it follows that tensor's dtype and
            device.

    Returns:
        The padded tensor of shape ``(len(batch), *per_axis_maxima)``.
    """
    reference = batch[0]
    n_axes = reference.dim()

    # Output shape: batch axis first, then the per-axis maximum extents.
    padded_shape = [len(batch)]
    for axis in range(n_axes):
        padded_shape.append(max(item.size(axis) for item in batch))
    padded = reference.new_zeros(size=tuple(padded_shape))

    for slot, item in enumerate(batch):
        # Basic indexing (int + slices) yields a view into ``padded`` covering
        # exactly the region this item occupies.
        corner = tuple(slice(0, item.size(axis)) for axis in range(n_axes))
        # The region is still all zeros, so an in-place add writes the data.
        padded[(slot, *corner)].add_(item)
    return padded


def mld_collate(batch: list) -> dict:
    """Collate per-sample tuples into one padded batch dictionary.

    ``None`` entries are discarded and the remaining samples are ordered by
    element 3 (the value stored under ``"text_len"``), largest first. Each
    sample is read positionally:

    * index 0 -> ``"word_embs"`` (float, zero-padded)
    * index 1 -> ``"pos_ohot"`` (float, zero-padded)
    * index 2 -> ``"text"`` (kept as a Python list)
    * index 3 -> ``"text_len"`` (tensor, original dtype)
    * index 4 -> ``"motion"`` (float, zero-padded)
    * index 5 -> ``"length"`` (kept as a Python list)
    * index 6 -> ``"tokens"`` (kept as a Python list)
    * last element -> a pair whose items feed ``"hint"`` and ``"hint_mask"``

    Args:
        batch: List of samples; entries may be ``None``. At least one entry
            must be non-``None``.

    Returns:
        Dictionary with the keys listed above plus ``"mask"``, produced by
        ``lengths_to_mask`` from the ``"length"`` list, the motion tensor's
        device and its size along axis 1. ``"hint"`` and ``"hint_mask"`` are
        present only when the first sorted sample carries a non-``None`` hint.
    """
    samples = sorted(
        (sample for sample in batch if sample is not None),
        key=lambda sample: sample[3],
        reverse=True,
    )

    def float_column(index: int) -> torch.Tensor:
        # Convert one positional field of every sample to float and pad-stack.
        return collate_tensors(
            [torch.tensor(sample[index]).float() for sample in samples])

    # Keys are inserted in a fixed order so the dict layout stays stable.
    adapted_batch = {}
    adapted_batch["motion"] = float_column(4)
    adapted_batch["text"] = [sample[2] for sample in samples]
    adapted_batch["length"] = [sample[5] for sample in samples]
    adapted_batch["word_embs"] = float_column(0)
    adapted_batch["pos_ohot"] = float_column(1)
    adapted_batch["text_len"] = collate_tensors(
        [torch.tensor(sample[3]) for sample in samples])
    adapted_batch["tokens"] = [sample[6] for sample in samples]

    motion = adapted_batch["motion"]
    # NOTE(review): presumably a per-frame validity mask of width
    # motion.shape[1] -- confirm against lengths_to_mask.
    adapted_batch["mask"] = lengths_to_mask(
        adapted_batch["length"], motion.device, motion.shape[1])

    # Trajectory hints are optional; only the first sample is inspected to
    # decide whether the batch carries them.
    if samples[0][-1][0] is None:
        return adapted_batch

    adapted_batch["hint"] = collate_tensors(
        [torch.tensor(sample[-1][0]).float() for sample in samples])
    adapted_batch["hint_mask"] = collate_tensors(
        [torch.tensor(sample[-1][1]).float() for sample in samples])
    return adapted_batch


def mld_collate_motion_only(batch: list) -> dict:
    """Collate samples that carry only a motion and its length.

    Args:
        batch: List of samples; element 0 of each sample is converted to a
            float tensor and element 1 is passed through unchanged.

    Returns:
        Dictionary with ``"motion"`` (the zero-padded stack produced by
        ``collate_tensors``) and ``"length"`` (list of the per-sample
        element-1 values, in input order).
    """
    motions = [torch.tensor(sample[0]).float() for sample in batch]
    padded_motion = collate_tensors(motions)
    lengths = [sample[1] for sample in batch]
    return {"motion": padded_motion, "length": lengths}