File size: 14,845 Bytes
1674828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F

import numpy as np
from tensorrt_llm._common import default_net
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
from ...functional import (
    Tensor,
    chunk,
    concat,
    constant,
    expand,
    shape,
    silu,
    slice,
    permute,
    expand_mask,
    expand_dims_like,
    unsqueeze,
    matmul,
    softmax,
    squeeze,
    cast,
    gelu,
)
from ...functional import expand_dims, view, bert_attention
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
from ...module import Module


class FeedForward(Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        self.project_in = Linear(dim, inner_dim)
        self.ff = Linear(inner_dim, dim_out)

    def forward(self, x):
        return self.ff(gelu(self.project_in(x)))


class AdaLayerNormZero(Module):
    def __init__(self, dim):
        super().__init__()

        self.linear = Linear(dim, dim * 6)
        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
        x = self.norm(x)
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            x = x * (ones + scale_msa) + shift_msa
        else:
            x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZero_Final(Module):
    def __init__(self, dim):
        super().__init__()

        self.linear = Linear(dim, dim * 2)

        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(silu(emb))
        scale, shift = chunk(emb, 2, dim=1)
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            x = self.norm(x) * (ones + scale) + shift
        else:
            x = self.norm(x) * unsqueeze((ones + scale), 1)
            x = x + unsqueeze(shift, 1)
        return x


class ConvPositionEmbedding(Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.mish = Mish()

    def forward(self, x, mask=None):  # noqa: F722
        if default_net().plugin_config.remove_input_padding:
            x = unsqueeze(x, 0)
        x = permute(x, [0, 2, 1])
        x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
        out = permute(x, [0, 2, 1])
        if default_net().plugin_config.remove_input_padding:
            out = squeeze(out, 0)
        return out


class Attention(Module):
    def __init__(
        self,
        processor: AttnProcessor,
        dim: int,
        heads: int = 16,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim  # hidden_size
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout
        self.attention_head_size = dim_head
        self.context_dim = context_dim
        self.context_pre_only = context_pre_only
        self.tp_size = 1
        self.num_attention_heads = heads // self.tp_size
        self.num_attention_kv_heads = heads // self.tp_size  # 8
        self.dtype = str_dtype_to_trt("float32")
        self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
        self.to_q = ColumnLinear(
            dim,
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )
        self.to_k = ColumnLinear(
            dim,
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )
        self.to_v = ColumnLinear(
            dim,
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )

        if self.context_dim is not None:
            self.to_k_c = Linear(context_dim, self.inner_dim)
            self.to_v_c = Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = Linear(context_dim, self.inner_dim)

        self.to_out = RowLinear(
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            dim,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = Linear(self.inner_dim, dim)

    def forward(
        self,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        c=None,  # context c
        scale=1.0,
        rope=None,
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
        else:
            return self.processor(
                self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
            )


def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
    shape_tensor = concat(
        [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
    )
    if default_net().plugin_config.remove_input_padding:
        assert tensor.ndim() == 2
        x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
        x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
        x1 = expand_dims(x1, 2)
        x2 = expand_dims(x2, 2)
        zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
        x2 = zero - x2
        x = concat([x2, x1], 2)
        out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
    else:
        assert tensor.ndim() == 3

        x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
        x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
        x1 = expand_dims(x1, 3)
        x2 = expand_dims(x2, 3)
        zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
        x2 = zero - x2
        x = concat([x2, x1], 3)
        out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))

    return out


def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
    if default_net().plugin_config.remove_input_padding:
        rot_dim = shape(rope_cos, -1)  # 64
        new_t_shape = concat([shape(x, 0), rot_dim])  # (-1, 64)
        x_ = slice(x, [0, 0], new_t_shape, [1, 1])
        end_dim = shape(x, -1) - shape(rope_cos, -1)
        new_t_unrotated_shape = concat([shape(x, 0), end_dim])  # (2, -1, 960)
        x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
        out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
    else:
        rot_dim = shape(rope_cos, 2)  # 64
        new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim])  # (2, -1, 64)
        x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
        end_dim = shape(x, 2) - shape(rope_cos, 2)
        new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim])  # (2, -1, 960)
        x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
        out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
    return out


class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
        rope=None,
    ) -> torch.FloatTensor:
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)
        # k,v,q all (2,1226,1024)
        query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
        key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)

        # attention
        inner_dim = key.shape[-1]
        norm_factor = math.sqrt(attn.attention_head_size)
        q_scaling = 1.0 / norm_factor
        mask = None
        if not default_net().plugin_config.remove_input_padding:
            N = shape(x, 1)
            B = shape(x, 0)
            seq_len_2d = concat([1, N])
            max_position_embeddings = 4096
            # create position ids
            position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
            tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
            tmp_position_ids = expand(tmp_position_ids, concat([B, N]))  # BxL
            tmp_input_lengths = unsqueeze(input_lengths, 1)  # Bx1
            tmp_input_lengths = expand(tmp_input_lengths, concat([B, N]))  # BxL
            mask = tmp_position_ids < tmp_input_lengths  # BxL
            mask = mask.cast("int32")

        if default_net().plugin_config.bert_attention_plugin:
            qkv = concat([query, key, value], dim=-1)
            # TRT plugin mode
            assert input_lengths is not None
            if default_net().plugin_config.remove_input_padding:
                qkv = qkv.view(concat([-1, 3 * inner_dim]))
                max_input_length = constant(
                    np.zeros(
                        [
                            2048,
                        ],
                        dtype=np.int32,
                    )
                )
            else:
                max_input_length = None
            context = bert_attention(
                qkv,
                input_lengths,
                attn.num_attention_heads,
                attn.attention_head_size,
                q_scaling=q_scaling,
                max_input_length=max_input_length,
            )
        else:
            assert not default_net().plugin_config.remove_input_padding

            def transpose_for_scores(x):
                new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])

                y = x.view(new_x_shape)
                y = y.transpose(1, 2)
                return y

            def transpose_for_scores_k(x):
                new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])

                y = x.view(new_x_shape)
                y = y.permute([0, 2, 3, 1])
                return y

            query = transpose_for_scores(query)
            key = transpose_for_scores_k(key)
            value = transpose_for_scores(value)

            attention_scores = matmul(query, key, use_fp32_acc=False)

            if mask is not None:
                attention_mask = expand_mask(mask, shape(query, 2))
                attention_mask = cast(attention_mask, attention_scores.dtype)
                attention_scores = attention_scores + attention_mask

            attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
            context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
        context = attn.to_out(context)
        if mask is not None:
            mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
            mask = expand_dims_like(mask, context)
            mask = cast(mask, context.dtype)
            context = context * mask
        return context


# DiT Block
class DiTBlock(Module):
    def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

    def forward(
        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
    ):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
        # attention
        # norm ----> (2,1226,1024)
        attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)

        # process attention output for input x
        if default_net().plugin_config.remove_input_padding:
            x = x + gate_msa * attn_output
        else:
            x = x + unsqueeze(gate_msa, 1) * attn_output
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
        else:
            norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
            # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
        ff_output = self.ff(norm)
        if default_net().plugin_config.remove_input_padding:
            x = x + gate_mlp * ff_output
        else:
            x = x + unsqueeze(gate_mlp, 1) * ff_output

        return x


class TimestepEmbedding(Module):
    def __init__(self, dim, freq_embed_dim=256, dtype=None):
        super().__init__()
        # self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
        self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)

    def forward(self, timestep):
        t_freq = self.mlp1(timestep)
        t_freq = silu(t_freq)
        t_emb = self.mlp2(t_freq)
        return t_emb