import collections.abc
from itertools import repeat
from typing import List, Optional, Union

import torch

from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd


def _ntuple(n):
    """Return a parser that turns a scalar (or length-1 iterable) into an n-tuple,
    e.g. to_2tuple(3) -> (3, 3); any other iterable is passed through as a tuple."""
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            x = tuple(x)
            if len(x) == 1:
                x = tuple(repeat(x[0], n))
            return x
        return tuple(repeat(x, n))
    return parse

to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)

def get_rope_freq_from_size(latents_size, ndim, target_ndim, args,
                            rope_theta_rescale_factor: Union[float, List[float]] = 1.0,
                            rope_interpolation_factor: Union[float, List[float]] = 1.0,
                            concat_dict: Optional[dict] = None):
    """Compute RoPE cos/sin frequency tables for a latent grid of the given size."""
    if concat_dict is None:
        concat_dict = {}

    # The rotary grid is defined over patches, so every latent dimension must be
    # divisible by the corresponding patch size.
    if isinstance(args.patch_size, int):
        assert all(s % args.patch_size == 0 for s in latents_size), \
            f"Latent size (last {ndim} dimensions) should be divisible by patch size ({args.patch_size}), " \
            f"but got {latents_size}."
        rope_sizes = [s // args.patch_size for s in latents_size]
    elif isinstance(args.patch_size, list):
        assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
            f"Latent size (last {ndim} dimensions) should be divisible by patch size ({args.patch_size}), " \
            f"but got {latents_size}."
        rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
    else:
        raise TypeError(f"Unsupported patch_size type: {type(args.patch_size)}")

    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # pad missing leading axes (e.g. time) with size 1

    head_dim = args.hidden_size // args.num_heads
    rope_dim_list = args.rope_dim_list
    if rope_dim_list is None:
        # Split the head dimension evenly across the target axes.
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                       rope_sizes,
                                                       theta=args.rope_theta,
                                                       use_real=True,
                                                       theta_rescale_factor=rope_theta_rescale_factor,
                                                       interpolation_factor=rope_interpolation_factor,
                                                       concat_dict=concat_dict)
    return freqs_cos, freqs_sin
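
# Illustrative call (assumed values, not from the repo): `args` stands in for the
# config namespace this function reads (patch_size, hidden_size, num_heads,
# rope_dim_list, rope_theta).
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(patch_size=2, hidden_size=3072, num_heads=24,
#                          rope_dim_list=[16, 56, 56], rope_theta=10000)
#   freqs_cos, freqs_sin = get_rope_freq_from_size([16, 64, 64], ndim=3,
#                                                  target_ndim=3, args=args)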
    
def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000.0, use_real=False,
                                theta_rescale_factor: Union[float, List[float]] = 1.0,
                                interpolation_factor: Union[float, List[float]] = 1.0,
                                concat_dict: Optional[dict] = None):
    """Build an n-dimensional rotary position embedding over a meshgrid, optionally
    prepending one extra position (controlled by concat_dict) before the grid."""
    if concat_dict is None:
        concat_dict = {}

    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if concat_dict:
        if concat_dict['mode'] == 'timecat':
            # Prepend one extra position whose time coordinate is a fixed bias.
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            grid = torch.cat([bias, grid], dim=1)
        elif concat_dict['mode'] == 'timecat-w':
            # Same as 'timecat', but also shift the last spatial coordinate past the original grid.
            # ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            bias[2] += start[-1]
            grid = torch.cat([bias, grid], dim=1)

    # Broadcast scalar rescale/interpolation factors to one value per axis.
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), \
        "len(theta_rescale_factor) should equal len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), \
        "len(interpolation_factor) should equal len(rope_dim_list)"

    # Use 1/ndim of the head dimensions to encode each grid axis.
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
                                      theta_rescale_factor=theta_rescale_factor[i],
                                      interpolation_factor=interpolation_factor[i])  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)    # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)    # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)    # (WHD, D/2)
        return emb
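

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # hymm_sp package is importable and uses illustrative values: head_dim=128
    # split across (T, H, W) as [16, 56, 56] for a 16 x 32 x 32 patch grid.
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(
        [16, 56, 56],   # per-axis rotary dims; their sum is the attention head_dim
        [16, 32, 32],   # grid size along (T, H, W)
        theta=10000.0,
        use_real=True,
    )
    print(freqs_cos.shape, freqs_sin.shape)  # cos/sin tables indexed by flattened grid position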