asigalov61 committed
Commit
838f8c4
verified
1 Parent(s): fd7b81e

Delete x_transformer_1_23_2.py

Files changed (1)
  1. x_transformer_1_23_2.py +0 -2481
x_transformer_1_23_2.py DELETED
@@ -1,2481 +0,0 @@
1
- #===================================================================================================================
2
- #
3
- # X Transformer Module
4
- #
5
- # Partial x-transformers code with useful modifications
6
- #
7
- # Version 1.0
8
- #
9
- # Original source code courtesy of lucidrains
10
- # https://github.com/lucidrains/x-transformers
11
- #
12
- # Original source code retrieved on 10/10/2023
13
- #
14
- # Project Los Angeles
15
- # Tegridy Code 2023
16
-
17
- #===================================================================================================================
18
-
19
- # Critical dependencies
20
- #
21
- # !pip install torch
22
- # !pip install einops
23
-
24
- #===================================================================================================================
25
-
26
- from functools import partial
27
- from typing import Optional, Tuple
28
-
29
- import os
30
- os.environ['USE_FLASH_ATTENTION'] = '1'
31
-
32
- import torch
33
- from torch import nn, einsum, Tensor
34
- import torch.nn.functional as F
35
-
36
- # Flash attention
37
- from torch.nn.attention import SDPBackend, sdpa_kernel
38
- torch.backends.cuda.enable_flash_sdp(True)
39
-
40
- from collections import namedtuple
41
- from functools import wraps
42
- from packaging import version
43
- from dataclasses import dataclass
44
-
45
- from einops import rearrange, repeat
46
-
47
- # constants
48
-
49
- EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
50
-
51
- @dataclass
52
- class Intermediates:
53
- qk_similarities: Optional[Tensor] = None
54
- pre_softmax_attn: Optional[Tensor] = None
55
- post_softmax_attn: Optional[Tensor] = None
56
- cached_kv: Optional[Tuple[Tensor, Tensor]] = None
57
-
58
- def to_tuple(self):
59
- return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
60
-
61
- # helpers
62
-
63
- def exists(val):
64
- return val is not None
65
-
66
- def default(val, d):
67
- return val if exists(val) else d
68
-
69
- def compact(arr):
70
- return [*filter(exists, arr)]
71
-
72
- def once(fn):
73
- called = False
74
- @wraps(fn)
75
- def inner(x):
76
- nonlocal called
77
- if called:
78
- return
79
- called = True
80
- return fn(x)
81
- return inner
82
-
83
- print_once = once(print)
84
-
85
- # functions for creating causal mask
86
- # need a special one for onnx cpu (no support for .triu)
87
-
88
- def create_causal_mask(i, j, device):
89
- return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
90
-
91
- def onnx_create_causal_mask(i, j, device):
92
- r = torch.arange(i, device = device)
93
- causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
94
- causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
95
- return causal_mask
96
-
97
- # main class
98
-
99
- class Attend(nn.Module):
100
- def __init__(
101
- self,
102
- *,
103
- dropout = 0.,
104
- causal = False,
105
- heads = None,
106
- talking_heads = False,
107
- sparse_topk = None,
108
- scale = None,
109
- qk_norm = False,
110
- flash = False,
111
- add_zero_kv = False,
112
- onnxable = False
113
- ):
114
- super().__init__()
115
- self.scale = scale
116
- self.qk_norm = qk_norm
117
-
118
- self.causal = causal
119
- self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
120
-
121
- self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
122
-
123
- self.dropout = dropout
124
- self.attn_dropout = nn.Dropout(dropout)
125
-
126
- # talking heads
127
-
128
- assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
129
-
130
- self.talking_heads = talking_heads
131
- if talking_heads:
132
- self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
133
- self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
134
-
135
- # sparse topk
136
-
137
- assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
138
- self.sparse_topk = sparse_topk
139
-
140
- # add a key / value token composed of zeros
141
- # in case this helps control outliers, as proposed by https://www.evanmiller.org/attention-is-off-by-one.html
142
-
143
- self.add_zero_kv = add_zero_kv
144
-
145
- # flash attention
146
-
147
- self.flash = flash
148
- assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
149
-
150
- # determine efficient attention configs for cuda and cpu
151
-
152
- self.cpu_config = EfficientAttentionConfig(True, True, True)
153
- self.cuda_config = None
154
-
155
- if not torch.cuda.is_available() or not flash:
156
- return
157
-
158
- device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
159
-
160
- major, minor = device_properties.major, device_properties.minor
161
-
162
- if (major, minor) == (8, 0):
163
- print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
164
- self.cuda_config = EfficientAttentionConfig(True, False, False)
165
- elif (major, minor) == (9, 0):
166
- print_once('H100 GPU detected, using flash attention')
167
- self.cuda_config = EfficientAttentionConfig(True, False, False)
168
- else:
169
- print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
170
- self.cuda_config = EfficientAttentionConfig(False, True, True)
171
-
172
- def flash_attn(
173
- self,
174
- q, k, v,
175
- mask = None,
176
- attn_bias = None
177
- ):
178
- batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
179
-
180
- # Recommended for multi-query single-key-value attention by Tri Dao
181
- # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
182
-
183
- if k.ndim == 3:
184
- k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
185
-
186
- if v.ndim == 3:
187
- v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
188
-
189
- # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
190
-
191
- if self.qk_norm:
192
- default_scale = q.shape[-1] ** -0.5
193
- q = q * (self.scale / default_scale)
194
-
195
- # Check if mask exists and expand to compatible shape
196
- # The mask is B L, so it would have to be expanded to B H N L
197
-
198
- causal = self.causal
199
-
200
- # in the case of kv caching with one token (q_len == 1), just turn off causal masking
201
- # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
202
-
203
- if q_len == 1 and causal:
204
- causal = False
205
-
206
- # expand key padding mask
207
-
208
- if exists(mask):
209
- assert mask.ndim == 4
210
- mask = mask.expand(batch, heads, q_len, k_len)
211
-
212
- # handle kv cache - this should be bypassable in updated flash attention 2
213
-
214
- if k_len > q_len and causal:
215
- causal_mask = self.create_causal_mask(q_len, k_len, device = device)
216
- if not exists(mask):
217
- mask = ~causal_mask
218
- else:
219
- mask = mask & ~causal_mask
220
- causal = False
221
-
222
- # manually handle causal mask, if another mask was given
223
-
224
- row_is_entirely_masked = None
225
-
226
- if exists(mask) and causal:
227
- causal_mask = self.create_causal_mask(q_len, k_len, device = device)
228
- mask = mask & ~causal_mask
229
-
230
- # protect against an entire row being masked out
231
-
232
- row_is_entirely_masked = ~mask.any(dim = -1)
233
- mask[..., 0] = mask[..., 0] | row_is_entirely_masked
234
-
235
- causal = False
236
-
237
- # handle alibi positional bias
238
- # convert from bool to float
239
-
240
- if exists(attn_bias):
241
- attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
242
-
243
- # if mask given, the mask would already contain the causal mask from above logic
244
- # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
245
-
246
- mask_value = -torch.finfo(q.dtype).max
247
-
248
- if exists(mask):
249
- attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
250
- elif causal:
251
- causal_mask = self.create_causal_mask(q_len, k_len, device = device)
252
- attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
253
- causal = False
254
-
255
- # scaled_dot_product_attention handles attn_mask either as bool or additive bias
256
- # make it an additive bias here
257
-
258
- mask = attn_bias
259
-
260
- # Check if there is a compatible device for flash attention
261
-
262
- config = self.cuda_config if is_cuda else self.cpu_config
263
-
264
- # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
265
-
266
- # Legacy code...
267
- # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
268
- # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
269
-
270
- # PyTorch 2.3-2.4 SDPA backend code...
271
- # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
272
- with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
273
-
274
- # New PyTorch 2.5 SDPA backend code:
275
- # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
276
-
277
- out = F.scaled_dot_product_attention(
278
- q, k, v,
279
- attn_mask = mask,
280
- dropout_p = self.dropout if self.training else 0.,
281
- is_causal = causal
282
- )
283
-
284
- # for a row that is entirely masked out, should zero out the output of that row token
285
-
286
- if exists(row_is_entirely_masked):
287
- out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
288
-
289
- return out, Intermediates()
290
-
291
- def forward(
292
- self,
293
- q, k, v,
294
- mask = None,
295
- attn_bias = None,
296
- prev_attn = None
297
- ):
298
- """
299
- einstein notation
300
- b - batch
301
- h - heads
302
- n, i, j - sequence length (base sequence length, source, target)
303
- d - feature dimension
304
- """
305
-
306
- n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
307
-
308
- scale = default(self.scale, q.shape[-1] ** -0.5)
309
-
310
- causal = self.causal
311
-
312
- # handle kv cached decoding
313
-
314
- if n == 1 and causal:
315
- causal = False
316
-
317
- # handle grouped multi-query attention
318
-
319
- if kv_heads == 1:
320
- k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
321
- elif kv_heads < heads:
322
- k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
323
-
324
- # handle zero kv, as means for allowing network to attend to nothing
325
-
326
- if self.add_zero_kv:
327
- k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
328
-
329
- if exists(mask):
330
- mask = F.pad(mask, (1, 0), value = True)
331
-
332
- if exists(attn_bias):
333
- attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
334
-
335
- if self.flash:
336
- assert not exists(prev_attn), 'residual attention not compatible with flash attention'
337
- return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
338
-
339
- kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
340
-
341
- dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
342
-
343
- if exists(prev_attn):
344
- dots = dots + prev_attn
345
-
346
- qk_similarities = dots.clone()
347
-
348
- if self.talking_heads:
349
- dots = self.pre_softmax_talking_heads(dots)
350
-
351
- if exists(attn_bias):
352
- dots = dots + attn_bias
353
-
354
- i, j, dtype = *dots.shape[-2:], dots.dtype
355
-
356
- mask_value = -torch.finfo(dots.dtype).max
357
-
358
- if exists(self.sparse_topk) and self.sparse_topk < j:
359
- top_values, _ = dots.topk(self.sparse_topk, dim = -1)
360
- sparse_topk_mask = dots < top_values[..., -1:]
361
- mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
362
-
363
- if exists(mask):
364
- dots = dots.masked_fill(~mask, mask_value)
365
-
366
- if causal:
367
- causal_mask = self.create_causal_mask(i, j, device = device)
368
- dots = dots.masked_fill(causal_mask, mask_value)
369
-
370
- pre_softmax_attn = dots.clone()
371
-
372
- attn = self.attn_fn(dots, dim = -1)
373
- attn = attn.type(dtype)
374
-
375
- post_softmax_attn = attn.clone()
376
-
377
- attn = self.attn_dropout(attn)
378
-
379
- if self.talking_heads:
380
- attn = self.post_softmax_talking_heads(attn)
381
-
382
- out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
383
-
384
- intermediates = Intermediates(
385
- qk_similarities = qk_similarities,
386
- pre_softmax_attn = pre_softmax_attn,
387
- post_softmax_attn = post_softmax_attn
388
- )
389
-
390
- return out, intermediates
391
-
392
- #===================================================================================================================
393
-
394
- from math import ceil, log
395
- from typing import Optional, Union, Tuple, Callable
396
-
397
- import torch
398
- from torch import nn, Tensor
399
- from torch.nn import Module
400
- import torch.nn.functional as F
401
-
402
- from einops import rearrange, pack, unpack
403
-
404
- def exists(val):
405
- return val is not None
406
-
407
- def default(val, d):
408
- return val if exists(val) else d
409
-
410
- def identity(t, *args, **kwargs):
411
- return t
412
-
413
- def cast_tuple(t, length = 1):
414
- return t if isinstance(t, tuple) else (t,) * length
415
-
416
- def eval_decorator(fn):
417
- def inner(self, *args, **kwargs):
418
- was_training = self.training
419
- self.eval()
420
- out = fn(self, *args, **kwargs)
421
- self.train(was_training)
422
- return out
423
- return inner
424
-
425
- # for variable-length prefixes
426
-
427
- def align_right(t, lens, pad_id = 0):
428
- batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
429
-
430
- assert lens.ndim == 1 and lens.shape[0] == batch
431
- assert lens.amax() <= seq_len
432
-
433
- pad_lens = seq_len - lens
434
- max_pad_len = pad_lens.amax()
435
-
436
- batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
437
- prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
438
-
439
- t = F.pad(t, (max_pad_len, 0), value = 0)
440
- offset = max_pad_len - pad_lens
441
-
442
- aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
443
- return aligned
444
-
445
- # nucleus
446
-
447
- def top_p(logits, thres = 0.9):
448
- sorted_logits, sorted_indices = torch.sort(logits, descending = True)
449
- cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
450
-
451
- sorted_indices_to_remove = cum_probs > thres
452
- sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
453
-
454
- sorted_logits[sorted_indices_to_remove] = float('-inf')
455
- return sorted_logits.scatter(1, sorted_indices, sorted_logits)
456
-
457
- # topk
458
-
459
- def top_k(logits, frac_num_tokens = 0.1, k = None):
460
- num_tokens = logits.shape[-1]
461
-
462
- k = default(k, ceil(frac_num_tokens * num_tokens))
463
- k = min(k, num_tokens)
464
-
465
- val, ind = torch.topk(logits, k)
466
- probs = torch.full_like(logits, float('-inf'))
467
- probs.scatter_(1, ind, val)
468
- return probs
469
-
470
- # top_a
471
-
472
- def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
473
- probs = F.softmax(logits, dim = -1)
474
- max_probs = torch.amax(probs, dim = -1, keepdim = True)
475
- limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
476
- return torch.where(probs < limit, float('-inf'), logits)
477
-
478
- # contrastive decoding function
479
-
480
- def contrastive_decode_fn(
481
- expert_logits,
482
- amateur_logits,
483
- alpha = 0.1,
484
- beta = 0.5
485
- ):
486
- """
487
- Appendix A Algorithm 2
488
- https://arxiv.org/abs/2309.09117
489
- """
490
-
491
- cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
492
- diffs = (1 + beta) * expert_logits - beta * amateur_logits
493
- contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
494
- return contrastive_decode_logits
495
-
496
- # autoregressive wrapper class
497
-
498
- class AutoregressiveWrapper(Module):
499
- def __init__(
500
- self,
501
- net,
502
- ignore_index = -100,
503
- pad_value = 0,
504
- mask_prob = 0.,
505
- add_attn_z_loss = False,
506
- return_cache=False
507
- ):
508
- super().__init__()
509
- self.pad_value = pad_value
510
- self.ignore_index = ignore_index
511
-
512
- self.net = net
513
- self.max_seq_len = net.max_seq_len
514
-
515
- # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
516
- assert mask_prob < 1.
517
- self.mask_prob = mask_prob
518
-
519
- # whether to add router z-loss
520
- self.add_attn_z_loss = add_attn_z_loss
521
- self.return_cache = return_cache
522
-
523
- @torch.inference_mode()
524
- @eval_decorator
525
- def generate(
526
- self,
527
- prompts,
528
- seq_len,
529
- eos_token = None,
530
- temperature = 1.,
531
- prompt_lens: Optional[Tensor] = None,
532
- filter_logits_fn: Callable = top_k,
533
- restrict_to_max_seq_len = True,
534
- amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
535
- filter_kwargs: dict = dict(),
536
- contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
537
- beta = 0.5,
538
- alpha = 0.1
539
- ),
540
- cache_kv = True,
541
- verbose=True,
542
- return_prime=False,
543
- **kwargs
544
- ):
545
- max_seq_len, device = self.max_seq_len, prompts.device
546
-
547
- prompts, ps = pack([prompts], '* n')
548
-
549
- b, t = prompts.shape
550
-
551
- # handle variable-length prompts (prefixes)
552
-
553
- seq_start_pos = None
554
- if exists(prompt_lens):
555
- prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
556
- seq_start_pos = t - prompt_lens
557
-
558
- # output to which sampled tokens are appended
559
-
560
- out = prompts
561
-
562
- if verbose:
563
- print("Generating sequence of max length:", seq_len)
564
-
565
- # kv caches
566
-
567
- cache = None
568
-
569
- # if doing contrastive decoding, turn off filter automatically
570
-
571
- if exists(amateur_model):
572
- amateur_model = cast_tuple(amateur_model)
573
- contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
574
-
575
- assert len(amateur_model) == len(contrastive_decode_kwargs)
576
-
577
- amateur_caches = [None] * len(amateur_model)
578
- filter_logits_fn = identity
579
-
580
- for i, module in enumerate(amateur_model):
581
- if isinstance(module, AutoregressiveWrapper):
582
- amateur_model[i] = module.net
583
-
584
- module.eval()
585
-
586
- # sampling up to seq_len
587
-
588
- for sl in range(seq_len):
589
-
590
- if restrict_to_max_seq_len:
591
- x = out[:, -max_seq_len:]
592
-
593
- if exists(cache):
594
- for inter in cache.attn_intermediates:
595
- inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
596
-
597
- logits, new_cache = self.net(
598
- x,
599
- return_intermediates = True,
600
- cache = cache,
601
- seq_start_pos = seq_start_pos,
602
- **kwargs
603
- )
604
-
605
- if cache_kv and self.net.can_cache_kv:
606
- cache = new_cache
607
-
608
- logits = logits[:, -1]
609
-
610
- # handle contrastive decoding, Li et al.
611
- # https://arxiv.org/abs/2210.15097
612
-
613
- if exists(amateur_model):
614
- for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
615
- amateur_logits, next_amateur_cache = amateur(
616
- x,
617
- return_intermediates = True,
618
- cache = amateur_cache,
619
- seq_start_pos = seq_start_pos,
620
- **kwargs
621
- )
622
-
623
- amateur_logits = amateur_logits[:, -1]
624
-
625
- assert amateur_logits.shape == logits.shape, 'logits dimensions are not the same between amateur and expert models'
626
- logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
627
-
628
- if cache_kv and amateur.can_cache_kv:
629
- amateur_caches[i] = next_amateur_cache
630
-
631
- # filter by top_k, top_p (nucleus), top_a, or custom
632
-
633
- filtered_logits = filter_logits_fn(logits, **filter_kwargs)
634
-
635
- probs = F.softmax(filtered_logits / temperature, dim=-1)
636
-
637
- sample = torch.multinomial(probs, 1)
638
-
639
- out = torch.cat((out, sample), dim=-1)
640
-
641
- if verbose:
642
- if sl % 32 == 0:
643
- print(sl, '/', seq_len)
644
-
645
- if exists(eos_token):
646
- is_eos_tokens = (out == eos_token)
647
-
648
- if is_eos_tokens.any(dim = -1).all():
649
- # mask out everything after the eos tokens
650
- shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
651
- mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
652
- out = out.masked_fill(mask, self.pad_value)
653
-
654
- if verbose:
655
- print('Model called the end of sequence at:', sl, '/', seq_len)
656
-
657
- break
658
-
659
- if return_prime:
660
- return out[:, :]
661
-
662
- else:
663
- return out[:, t:]
664
-
665
- # out, = unpack(out, ps, '* n')
666
-
667
- # return out
668
-
669
- def compute_accuracy(self, logits, labels):
670
- out = torch.argmax(logits, dim=-1)
671
- out = out.flatten()
672
- labels = labels.flatten()
673
-
674
- mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
675
- out = out[mask]
676
- labels = labels[mask]
677
-
678
- num_right = (out == labels)
679
- num_right = torch.sum(num_right).type(torch.float32)
680
-
681
- acc = num_right / len(labels)
682
- return acc
683
-
684
- def forward(self, x, **kwargs):
685
- seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
686
-
687
- inp, target = x[:, :-1], x[:, 1:]
688
- inp = torch.where(inp == ignore_index, self.pad_value, inp)
689
-
690
- if self.mask_prob > 0.:
691
- rand = torch.randn(inp.shape, device = x.device)
692
- rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
693
- num_mask = min(int(seq * self.mask_prob), seq - 1)
694
- indices = rand.topk(num_mask, dim = -1).indices
695
- mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
696
- kwargs.update(self_attn_kv_mask = mask)
697
-
698
- logits, cache = self.net(
699
- inp,
700
- return_intermediates = True,
701
- return_attn_z_loss = add_attn_z_loss,
702
- **kwargs
703
- )
704
-
705
- acc = self.compute_accuracy(logits, target)
706
-
707
- loss = F.cross_entropy(
708
- rearrange(logits, 'b n c -> b c n'),
709
- target,
710
- ignore_index = ignore_index
711
- )
712
-
713
- if add_attn_z_loss:
714
- loss = loss + cache.attn_z_loss
715
-
716
- if self.return_cache:
717
- return loss, acc, cache
718
-
719
- else:
720
- return loss, acc
721
-
722
- #===============================================================================
723
-
724
- import math
725
- from random import random
726
-
727
- import torch
728
- from torch import nn, einsum, Tensor
729
- import torch.nn.functional as F
730
-
731
- from functools import partial, wraps
732
- from inspect import isfunction
733
- from collections import namedtuple
734
- from dataclasses import dataclass
735
- from typing import List, Callable, Optional
736
-
737
- from einops import rearrange, repeat, reduce, pack, unpack
738
- from einops.layers.torch import Rearrange
739
-
740
- # constants
741
-
742
- DEFAULT_DIM_HEAD = 64
743
-
744
- @dataclass
745
- class LayerIntermediates:
746
- hiddens: Optional[List[Tensor]] = None
747
- attn_intermediates: Optional[List[Intermediates]] = None
748
- layer_hiddens: Optional[List[Tensor]] = None
749
- attn_z_loss: Optional[Tensor] = None
750
- mems: Optional[Tensor] = None
751
-
752
- # helpers
753
-
754
- def exists(val):
755
- return val is not None
756
-
757
- def default(val, d):
758
- if exists(val):
759
- return val
760
- return d() if isfunction(d) else d
761
-
762
- def cast_tuple(val, depth):
763
- return val if isinstance(val, tuple) else (val,) * depth
764
-
765
- def divisible_by(num, den):
766
- return (num % den) == 0
767
-
768
- def maybe(fn):
769
- @wraps(fn)
770
- def inner(x, *args, **kwargs):
771
- if not exists(x):
772
- return x
773
- return fn(x, *args, **kwargs)
774
- return inner
775
-
776
- class always():
777
- def __init__(self, val):
778
- self.val = val
779
- def __call__(self, *args, **kwargs):
780
- return self.val
781
-
782
- class not_equals():
783
- def __init__(self, val):
784
- self.val = val
785
- def __call__(self, x, *args, **kwargs):
786
- return x != self.val
787
-
788
- class equals():
789
- def __init__(self, val):
790
- self.val = val
791
- def __call__(self, x, *args, **kwargs):
792
- return x == self.val
793
-
794
- def Sequential(*modules):
795
- return nn.Sequential(*filter(exists, modules))
796
-
797
- # tensor helpers
798
-
799
- def max_neg_value(tensor):
800
- return -torch.finfo(tensor.dtype).max
801
-
802
- def l2norm(t, groups = 1):
803
- t = rearrange(t, '... (g d) -> ... g d', g = groups)
804
- t = F.normalize(t, p = 2, dim = -1)
805
- return rearrange(t, '... g d -> ... (g d)')
806
-
807
- def pad_at_dim(t, pad, dim = -1, value = 0.):
808
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
809
- zeros = ((0, 0) * dims_from_right)
810
- return F.pad(t, (*zeros, *pad), value = value)
811
-
812
- def or_reduce(masks):
813
- head, *body = masks
814
- for rest in body:
815
- head = head | rest
816
- return head
817
-
818
- # auxiliary loss helpers
819
-
820
- def calc_z_loss(
821
- pre_softmax_attns: List[Tensor],
822
- mask = None,
823
- weight = 1.
824
- ):
825
- # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
826
- # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
827
- # also used in PaLM as one of the measures
828
-
829
- lse = 0.
830
-
831
- for attn in pre_softmax_attns:
832
- lse = lse + attn.logsumexp(dim = -1)
833
-
834
- loss = torch.square(lse)
835
- loss = reduce(loss, 'b h n -> b n', 'sum')
836
-
837
- if not exists(mask):
838
- return loss.mean() * weight
839
-
840
- loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
841
- return loss * weight
842
-
843
- # init helpers
844
-
845
- def init_zero_(layer):
846
- nn.init.constant_(layer.weight, 0.)
847
- if exists(layer.bias):
848
- nn.init.constant_(layer.bias, 0.)
849
-
850
- # keyword argument helpers
851
-
852
- def pick_and_pop(keys, d):
853
- values = list(map(lambda key: d.pop(key), keys))
854
- return dict(zip(keys, values))
855
-
856
- def group_dict_by_key(cond, d):
857
- return_val = [dict(),dict()]
858
- for key in d.keys():
859
- match = bool(cond(key))
860
- ind = int(not match)
861
- return_val[ind][key] = d[key]
862
- return (*return_val,)
863
-
864
- def string_begins_with(prefix, str):
865
- return str.startswith(prefix)
866
-
867
- def group_by_key_prefix(prefix, d):
868
- return group_dict_by_key(partial(string_begins_with, prefix), d)
869
-
870
- def groupby_prefix_and_trim(prefix, d):
871
- kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
872
- kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
873
- return kwargs_without_prefix, kwargs
874
-
875
- # structured dropout, more effective than traditional attention dropouts
876
-
877
- def dropout_seq(seq, mask, dropout):
878
- b, n, *_, device = *seq.shape, seq.device
879
- logits = torch.randn(b, n, device = device)
880
-
881
- if exists(mask):
882
- mask_value = max_neg_value(logits)
883
- logits = logits.masked_fill(~mask, mask_value)
884
-
885
- keep_prob = 1. - dropout
886
- num_keep = max(1, int(keep_prob * n))
887
- keep_indices = logits.topk(num_keep, dim = 1).indices
888
-
889
- batch_indices = torch.arange(b, device = device)
890
- batch_indices = rearrange(batch_indices, 'b -> b 1')
891
-
892
- seq = seq[batch_indices, keep_indices]
893
-
894
- if exists(mask):
895
- seq_counts = mask.sum(dim = -1)
896
- seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
897
- keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
898
-
899
- mask = mask[batch_indices, keep_indices] & keep_mask
900
-
901
- return seq, mask
902
-
903
- # activations
904
-
905
- class ReluSquared(nn.Module):
906
- def forward(self, x):
907
- return F.relu(x) ** 2
908
-
909
- # embedding
910
-
911
- class TokenEmbedding(nn.Module):
912
- def __init__(self, dim, num_tokens, l2norm_embed = False):
913
- super().__init__()
914
- self.l2norm_embed = l2norm_embed
915
- self.emb = nn.Embedding(num_tokens, dim)
916
-
917
- def forward(self, x):
918
- token_emb = self.emb(x)
919
- return l2norm(token_emb) if self.l2norm_embed else token_emb
920
-
921
- # positional embeddings
922
-
923
- class AbsolutePositionalEmbedding(nn.Module):
924
- def __init__(self, dim, max_seq_len, l2norm_embed = False):
925
- super().__init__()
926
- self.scale = dim ** -0.5 if not l2norm_embed else 1.
927
- self.max_seq_len = max_seq_len
928
- self.l2norm_embed = l2norm_embed
929
- self.emb = nn.Embedding(max_seq_len, dim)
930
-
931
- def forward(self, x, pos = None, seq_start_pos = None):
932
- seq_len, device = x.shape[1], x.device
933
- assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
934
-
935
- if not exists(pos):
936
- pos = torch.arange(seq_len, device = device)
937
-
938
- if exists(seq_start_pos):
939
- pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
940
-
941
- pos_emb = self.emb(pos)
942
- pos_emb = pos_emb * self.scale
943
- return l2norm(pos_emb) if self.l2norm_embed else pos_emb
944
-
945
- class ScaledSinusoidalEmbedding(nn.Module):
946
- def __init__(self, dim, theta = 10000):
947
- super().__init__()
948
- assert divisible_by(dim, 2)
949
- self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
950
-
951
- half_dim = dim // 2
952
- freq_seq = torch.arange(half_dim).float() / half_dim
953
- inv_freq = theta ** -freq_seq
954
- self.register_buffer('inv_freq', inv_freq, persistent = False)
955
-
956
- def forward(self, x, pos = None, seq_start_pos = None):
957
- seq_len, device = x.shape[1], x.device
958
-
959
- if not exists(pos):
960
- pos = torch.arange(seq_len, device = device)
961
-
962
- if exists(seq_start_pos):
963
- pos = pos - seq_start_pos[..., None]
964
-
965
- emb = einsum('i, j -> i j', pos, self.inv_freq)
966
- emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
967
- return emb * self.scale
968
-
969
- class RelativePositionBias(nn.Module):
970
- def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
971
- super().__init__()
972
- self.scale = scale
973
- self.causal = causal
974
- self.num_buckets = num_buckets
975
- self.max_distance = max_distance
976
- self.relative_attention_bias = nn.Embedding(num_buckets, heads)
977
-
978
- @staticmethod
979
- def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
980
- ret = 0
981
- n = -relative_position
982
- if not causal:
983
- num_buckets //= 2
984
- ret += (n < 0).long() * num_buckets
985
- n = torch.abs(n)
986
- else:
987
- n = torch.max(n, torch.zeros_like(n))
988
-
989
- max_exact = num_buckets // 2
990
- is_small = n < max_exact
991
-
992
- val_if_large = max_exact + (
993
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
994
- ).long()
995
- val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
996
-
997
- ret += torch.where(is_small, n, val_if_large)
998
- return ret
999
-
1000
- @property
1001
- def device(self):
1002
- return next(self.parameters()).device
1003
-
1004
- def forward(self, i, j):
1005
- device = self.device
1006
- q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
1007
- k_pos = torch.arange(j, dtype = torch.long, device = device)
1008
- rel_pos = k_pos[None, :] - q_pos[:, None]
1009
- rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
1010
- values = self.relative_attention_bias(rp_bucket)
1011
- bias = rearrange(values, 'i j h -> h i j')
1012
- return bias * self.scale
1013
-
1014
- class DynamicPositionBias(nn.Module):
1015
- def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
1016
- super().__init__()
1017
- assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
1018
- self.log_distance = log_distance
1019
-
1020
- self.mlp = nn.ModuleList([])
1021
-
1022
- self.mlp.append(Sequential(
1023
- nn.Linear(1, dim),
1024
- nn.LayerNorm(dim) if norm else None,
1025
- nn.SiLU()
1026
- ))
1027
-
1028
- for _ in range(depth - 1):
1029
- self.mlp.append(Sequential(
1030
- nn.Linear(dim, dim),
1031
- nn.LayerNorm(dim) if norm else None,
1032
- nn.SiLU()
1033
- ))
1034
-
1035
- self.mlp.append(nn.Linear(dim, heads))
1036
-
1037
- @property
1038
- def device(self):
1039
- return next(self.parameters()).device
1040
-
1041
- def forward(self, i, j):
1042
- assert i == j
1043
- n, device = j, self.device
1044
-
1045
- # get the (n x n) matrix of distances
1046
- seq_arange = torch.arange(n, device = device)
1047
- context_arange = torch.arange(n, device = device)
1048
- indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
1049
- indices += (n - 1)
1050
-
1051
- # input to continuous positions MLP
1052
- pos = torch.arange(-n + 1, n, device = device).float()
1053
- pos = rearrange(pos, '... -> ... 1')
1054
-
1055
- if self.log_distance:
1056
- pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
1057
-
1058
- for layer in self.mlp:
1059
- pos = layer(pos)
1060
-
1061
- # get position biases
1062
- bias = pos[indices]
1063
- bias = rearrange(bias, 'i j h -> h i j')
1064
- return bias
1065
-
1066
- class AlibiPositionalBias(nn.Module):
1067
- def __init__(self, heads, total_heads, **kwargs):
1068
- super().__init__()
1069
- self.heads = heads
1070
- self.total_heads = total_heads
1071
-
1072
- slopes = Tensor(self._get_slopes(heads))
1073
- slopes = rearrange(slopes, 'h -> h 1 1')
1074
- self.register_buffer('slopes', slopes, persistent = False)
1075
- self.register_buffer('bias', None, persistent = False)
1076
-
1077
- def get_bias(self, i, j, device):
1078
- i_arange = torch.arange(j - i, j, device = device)
1079
- j_arange = torch.arange(j, device = device)
1080
- bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
1081
- return bias
1082
-
1083
- @staticmethod
1084
- def _get_slopes(heads):
1085
- def get_slopes_power_of_2(n):
1086
- start = (2**(-2**-(math.log2(n)-3)))
1087
- ratio = start
1088
- return [start*ratio**i for i in range(n)]
1089
-
1090
- if math.log2(heads).is_integer():
1091
- return get_slopes_power_of_2(heads)
1092
-
1093
- closest_power_of_2 = 2 ** math.floor(math.log2(heads))
1094
- return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
1095
-
1096
- @property
1097
- def device(self):
1098
- return next(self.buffers()).device
1099
-
1100
- def forward(self, i, j):
1101
- h, device = self.total_heads, self.device
1102
-
1103
- if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
1104
- return self.bias[..., -i:, -j:]
1105
-
1106
- bias = self.get_bias(i, j, device)
1107
- bias = bias * self.slopes
1108
-
1109
- num_heads_unalibied = h - bias.shape[0]
1110
- bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
1111
- self.register_buffer('bias', bias, persistent = False)
1112
-
1113
- return self.bias
1114
-
1115
- class RotaryEmbedding(nn.Module):
1116
- def __init__(
1117
- self,
1118
- dim,
1119
- use_xpos = False,
1120
- scale_base = 512,
1121
- interpolation_factor = 1.,
1122
- base = 10000,
1123
- base_rescale_factor = 1.
1124
- ):
1125
- super().__init__()
1126
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
1127
- # has some connection to NTK literature
1128
- # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
1129
- base *= base_rescale_factor ** (dim / (dim - 2))
1130
-
1131
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
1132
- self.register_buffer('inv_freq', inv_freq)
1133
-
1134
- assert interpolation_factor >= 1.
1135
- self.interpolation_factor = interpolation_factor
1136
-
1137
- if not use_xpos:
1138
- self.register_buffer('scale', None)
1139
- return
1140
-
1141
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
1142
-
1143
- self.scale_base = scale_base
1144
- self.register_buffer('scale', scale)
1145
-
1146
- def forward(self, seq_len):
1147
- device = self.inv_freq.device
1148
- t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
1149
-
1150
- t = t / self.interpolation_factor
1151
-
1152
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
1153
- freqs = torch.cat((freqs, freqs), dim = -1)
1154
-
1155
- if not exists(self.scale):
1156
- return freqs, 1.
1157
-
1158
- power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
1159
- scale = self.scale ** rearrange(power, 'n -> n 1')
1160
- scale = torch.cat((scale, scale), dim = -1)
1161
-
1162
- return freqs, scale
1163
-
1164
-
1165
- def rotate_half(x):
1166
- x = rearrange(x, '... (j d) -> ... j d', j = 2)
1167
- x1, x2 = x.unbind(dim = -2)
1168
- return torch.cat((-x2, x1), dim = -1)
1169
-
1170
- def apply_rotary_pos_emb(t, freqs, scale = 1):
1171
- rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
1172
- freqs = freqs[-seq_len:, :]
1173
-
1174
- if t.ndim == 4 and freqs.ndim == 3:
1175
- freqs = rearrange(freqs, 'b n d -> b 1 n d')
1176
-
1177
- # partial rotary embeddings, Wang et al. GPT-J
1178
- t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
1179
- t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
1180
- return torch.cat((t, t_unrotated), dim = -1)
1181
-
1182
- # norms
1183
-
1184
- class Scale(nn.Module):
1185
- def __init__(self, value, fn):
1186
- super().__init__()
1187
- self.value = value
1188
- self.fn = fn
1189
-
1190
- def forward(self, x, **kwargs):
1191
- out = self.fn(x, **kwargs)
1192
- scale_fn = lambda t: t * self.value
1193
-
1194
- if not isinstance(out, tuple):
1195
- return scale_fn(out)
1196
-
1197
- return (scale_fn(out[0]), *out[1:])
1198
-
1199
- class ScaleNorm(nn.Module):
1200
- def __init__(self, dim, eps = 1e-5):
1201
- super().__init__()
1202
- self.eps = eps
1203
- self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
1204
-
1205
- def forward(self, x):
1206
- norm = torch.norm(x, dim = -1, keepdim = True)
1207
- return x / norm.clamp(min = self.eps) * self.g
1208
-
1209
- class RMSNorm(nn.Module):
1210
- def __init__(self, dim):
1211
- super().__init__()
1212
- self.scale = dim ** 0.5
1213
- self.g = nn.Parameter(torch.ones(dim))
1214
-
1215
- def forward(self, x):
1216
- return F.normalize(x, dim = -1) * self.scale * self.g
1217
-
1218
- class SimpleRMSNorm(nn.Module):
1219
- def __init__(self, dim):
1220
- super().__init__()
1221
- self.scale = dim ** 0.5
1222
-
1223
- def forward(self, x):
1224
- return F.normalize(x, dim = -1) * self.scale
1225
-
1226
- # residual and residual gates
1227
-
1228
- class Residual(nn.Module):
1229
- def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
1230
- super().__init__()
1231
- self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1232
- self.scale_residual_constant = scale_residual_constant
1233
-
1234
- def forward(self, x, residual):
1235
- if exists(self.residual_scale):
1236
- residual = residual * self.residual_scale
1237
-
1238
- if self.scale_residual_constant != 1:
1239
- residual = residual * self.scale_residual_constant
1240
-
1241
- return x + residual
1242
-
1243
- class GRUGating(nn.Module):
1244
- def __init__(self, dim, scale_residual = False, **kwargs):
1245
- super().__init__()
1246
- self.gru = nn.GRUCell(dim, dim)
1247
- self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1248
-
1249
- def forward(self, x, residual):
1250
- if exists(self.residual_scale):
1251
- residual = residual * self.residual_scale
1252
-
1253
- gated_output = self.gru(
1254
- rearrange(x, 'b n d -> (b n) d'),
1255
- rearrange(residual, 'b n d -> (b n) d')
1256
- )
1257
-
1258
- return gated_output.reshape_as(x)
1259
-
1260
- # token shifting
1261
-
1262
- def shift(t, amount, mask = None):
1263
- if amount == 0:
1264
- return t
1265
- else:
1266
- amount = min(amount, t.shape[1])
1267
-
1268
- if exists(mask):
1269
- t = t.masked_fill(~mask[..., None], 0.)
1270
-
1271
- return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1272
-
1273
- class ShiftTokens(nn.Module):
1274
- def __init__(self, shifts, fn):
1275
- super().__init__()
1276
- self.fn = fn
1277
- self.shifts = tuple(shifts)
1278
-
1279
- def forward(self, x, **kwargs):
1280
- mask = kwargs.get('mask', None)
1281
- shifts = self.shifts
1282
- segments = len(shifts)
1283
- feats_per_shift = x.shape[-1] // segments
1284
- splitted = x.split(feats_per_shift, dim = -1)
1285
- segments_to_shift, rest = splitted[:segments], splitted[segments:]
1286
- segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
1287
- x = torch.cat((*segments_to_shift, *rest), dim = -1)
1288
- return self.fn(x, **kwargs)
1289
-
1290
- # feedforward
1291
-
1292
- class GLU(nn.Module):
1293
- def __init__(
1294
- self,
1295
- dim_in,
1296
- dim_out,
1297
- activation: Callable,
1298
- mult_bias = False
1299
- ):
1300
- super().__init__()
1301
- self.act = activation
1302
- self.proj = nn.Linear(dim_in, dim_out * 2)
1303
- self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.
1304
-
1305
- def forward(self, x):
1306
- x, gate = self.proj(x).chunk(2, dim = -1)
1307
- return x * self.act(gate) * self.mult_bias
1308
-
1309
- class FeedForward(nn.Module):
1310
- def __init__(
1311
- self,
1312
- dim,
1313
- dim_out = None,
1314
- mult = 4,
1315
- glu = False,
1316
- glu_mult_bias = False,
1317
- swish = False,
1318
- relu_squared = False,
1319
- post_act_ln = False,
1320
- dropout = 0.,
1321
- no_bias = False,
1322
- zero_init_output = False
1323
- ):
1324
- super().__init__()
1325
- inner_dim = int(dim * mult)
1326
- dim_out = default(dim_out, dim)
1327
-
1328
- if relu_squared:
1329
- activation = ReluSquared()
1330
- elif swish:
1331
- activation = nn.SiLU()
1332
- else:
1333
- activation = nn.GELU()
1334
-
1335
- if glu:
1336
- project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
1337
- else:
1338
- project_in = nn.Sequential(
1339
- nn.Linear(dim, inner_dim, bias = not no_bias),
1340
- activation
1341
- )
1342
-
1343
- self.ff = Sequential(
1344
- project_in,
1345
- nn.LayerNorm(inner_dim) if post_act_ln else None,
1346
- nn.Dropout(dropout),
1347
- nn.Linear(inner_dim, dim_out, bias = not no_bias)
1348
- )
1349
-
1350
- # init last linear layer to 0
1351
- if zero_init_output:
1352
- init_zero_(self.ff[-1])
1353
-
1354
- def forward(self, x):
1355
- return self.ff(x)
1356
-
1357
- # attention. it is all we need
1358
-
1359
- class Attention(nn.Module):
1360
- def __init__(
1361
- self,
1362
- dim,
1363
- dim_head = DEFAULT_DIM_HEAD,
1364
- heads = 8,
1365
- causal = False,
1366
- flash = False,
1367
- talking_heads = False,
1368
- head_scale = False,
1369
- sparse_topk = None,
1370
- num_mem_kv = 0,
1371
- dropout = 0.,
1372
- on_attn = False,
1373
- gate_value_heads = False,
1374
- gate_values = False,
1375
- zero_init_output = False,
1376
- max_attend_past = None,
1377
- qk_norm = False,
1378
- qk_norm_groups = 1,
1379
- qk_norm_scale = 10,
1380
- qk_norm_dim_scale = False,
1381
- one_kv_head = False,
1382
- kv_heads = None,
1383
- shared_kv = False,
1384
- value_dim_head = None,
1385
- tensor_product = False, # https://arxiv.org/abs/2208.06061
1386
- add_zero_kv = False, # same as add_zero_attn in pytorch
1387
- rotary_embed_values = False,
1388
- onnxable = False
1389
- ):
1390
- super().__init__()
1391
- self.scale = dim_head ** -0.5
1392
-
1393
- self.heads = heads
1394
- self.causal = causal
1395
- self.max_attend_past = max_attend_past
1396
-
1397
- assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
1398
-
1399
- value_dim_head = default(value_dim_head, dim_head)
1400
- kv_heads = default(kv_heads, heads)
1401
-
1402
- kv_heads = 1 if one_kv_head else kv_heads
1403
- assert divisible_by(heads, kv_heads)
1404
-
1405
- self.kv_heads = kv_heads
1406
-
1407
- q_dim = dim_head * heads
1408
- k_dim = dim_head * kv_heads
1409
- v_dim = value_dim_head * kv_heads
1410
- out_dim = value_dim_head * heads
1411
-
1412
- self.to_q = nn.Linear(dim, q_dim, bias = False)
1413
- self.to_k = nn.Linear(dim, k_dim, bias = False)
1414
-
1415
- # shared key / values, for further memory savings during inference
1416
- assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
1417
- self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
1418
-
1419
- # relations projection from tp-attention
1420
- self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
1421
-
1422
- # add GLU gating for aggregated values, from alphafold2
1423
- self.to_v_gate = None
1424
- if gate_values:
1425
- self.to_v_gate = nn.Linear(dim, out_dim)
1426
- nn.init.constant_(self.to_v_gate.weight, 0)
1427
- nn.init.constant_(self.to_v_gate.bias, 10)
1428
-
1429
- # add per head gating of the output values, from 'Attend to nothing' paper
1430
- self.to_v_head_gate = None
1431
- if gate_value_heads:
1432
- self.to_v_head_gate = nn.Linear(dim, heads)
1433
- nn.init.constant_(self.to_v_head_gate.weight, 0)
1434
- nn.init.constant_(self.to_v_head_gate.bias, 10)
1435
-
1436
- # cosine sim attention
1437
- self.qk_norm = qk_norm
1438
- self.qk_norm_groups = qk_norm_groups
1439
- self.qk_norm_scale = qk_norm_scale
1440
-
1441
- # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
1442
- self.qk_norm_dim_scale = qk_norm_dim_scale
1443
-
1444
- self.qk_norm_q_scale = self.qk_norm_k_scale = 1
1445
- if qk_norm and qk_norm_dim_scale:
1446
- self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1447
- self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1448
-
1449
- assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1450
- assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1451
-
1452
- # attend class - includes core attention algorithm + talking heads
1453
-
1454
- self.attend = Attend(
1455
- heads = heads,
1456
- causal = causal,
1457
- talking_heads = talking_heads,
1458
- dropout = dropout,
1459
- sparse_topk = sparse_topk,
1460
- qk_norm = qk_norm,
1461
- scale = qk_norm_scale if qk_norm else self.scale,
1462
- add_zero_kv = add_zero_kv,
1463
- flash = flash,
1464
- onnxable = onnxable
1465
- )
1466
-
1467
- # head scaling
1468
- self.head_scale = head_scale
1469
- if head_scale:
1470
- self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
1471
-
1472
- # explicit topk sparse attention
1473
- self.sparse_topk = sparse_topk
1474
-
1475
- # add memory key / values
1476
- self.num_mem_kv = num_mem_kv
1477
- if num_mem_kv > 0:
1478
- self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1479
- self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1480
-
1481
- # attention on attention
1482
- self.attn_on_attn = on_attn
1483
- self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
1484
-
1485
- # whether to rotate positions into values, for absolute positions in addition to relative
1486
- self.rotary_embed_values = rotary_embed_values
1487
-
1488
- # init output projection 0
1489
- if zero_init_output:
1490
- init_zero_(self.to_out)
1491
-
1492
- def forward(
1493
- self,
1494
- x,
1495
- context = None,
1496
- mask = None,
1497
- context_mask = None,
1498
- attn_mask = None,
1499
- rel_pos = None,
1500
- rotary_pos_emb = None,
1501
- prev_attn = None,
1502
- mem = None,
1503
- return_intermediates = False,
1504
- cache: Optional[Intermediates] = None,
1505
- ):
1506
- b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
1507
- kv_input = default(context, x)
1508
-
1509
- q_input = x
1510
- k_input = kv_input
1511
- v_input = kv_input
1512
- r_input = x
1513
-
1514
- if exists(mem):
1515
- k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
1516
- v_input, _ = pack([mem, v_input], 'b * d')
1517
-
1518
- q = self.to_q(q_input)
1519
- k = self.to_k(k_input)
1520
- v = self.to_v(v_input) if exists(self.to_v) else k
1521
- r = self.to_r(r_input) if exists(self.to_r) else None
1522
-
1523
- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1524
-
1525
- k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
1526
-
1527
- if exists(cache) and not has_context:
1528
- ck, cv = cache.cached_kv
1529
-
1530
- if exists(mem):
1531
- mk, k = unpack(k, mem_packed_shape, 'b h * d')
1532
- mv, v = unpack(v, mem_packed_shape, 'b h * d')
1533
-
1534
- k = torch.cat((ck, k), dim = -2)
1535
- v = torch.cat((cv, v), dim = -2)
1536
-
1537
- if exists(mem):
1538
- k = torch.cat((mk, k), dim = -2)
1539
- v = torch.cat((mv, v), dim = -2)
1540
-
1541
- if return_intermediates:
1542
- mem_len = mem.shape[-2] if exists(mem) else 0
1543
- cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
1544
-
1545
- if self.qk_norm:
1546
- qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1547
- q, k = map(qk_l2norm, (q, k))
1548
- scale = self.qk_norm_scale
1549
-
1550
- q = q * self.qk_norm_q_scale
1551
- k = k * self.qk_norm_k_scale
1552
-
1553
- if exists(rotary_pos_emb) and not has_context:
1554
- freqs, xpos_scale = rotary_pos_emb
1555
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1556
-
1557
- q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1558
- k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1559
-
1560
- if self.rotary_embed_values:
1561
- v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
1562
-
1563
- input_mask = context_mask
1564
-
1565
- if not exists(input_mask) and not has_context:
1566
- input_mask = mask
1567
-
1568
- if self.num_mem_kv > 0:
1569
- mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1570
-
1571
- if self.qk_norm:
1572
- mem_k = l2norm(mem_k)
1573
- mem_k = mem_k * self.qk_norm_k_scale
1574
-
1575
- k = torch.cat((mem_k, k), dim = -2)
1576
- v = torch.cat((mem_v, v), dim = -2)
1577
-
1578
- if exists(input_mask):
1579
- input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1580
-
1581
- i, j = map(lambda t: t.shape[-2], (q, k))
1582
-
1583
- # determine masking
1584
-
1585
- mask_value = max_neg_value(q)
1586
- masks = []
1587
- final_attn_mask = None
1588
-
1589
- if exists(input_mask):
1590
- input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1591
- masks.append(~input_mask)
1592
-
1593
- if exists(attn_mask):
1594
- assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
1595
- if attn_mask.ndim == 2:
1596
- attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1597
- elif attn_mask.ndim == 3:
1598
- attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1599
- masks.append(~attn_mask)
1600
-
1601
- if exists(self.max_attend_past):
1602
- range_q = torch.arange(j - i, j, device = device)
1603
- range_k = torch.arange(j, device = device)
1604
- dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1605
- max_attend_past_mask = dist > self.max_attend_past
1606
- masks.append(max_attend_past_mask)
1607
-
1608
- if len(masks) > 0:
1609
- final_attn_mask = ~or_reduce(masks)
1610
-
1611
- # prepare relative positional bias, if needed
1612
-
1613
- attn_bias = None
1614
- if exists(rel_pos):
1615
- attn_bias = rel_pos(i, j)
1616
-
1617
- # attention is all we need
1618
-
1619
- out, intermediates = self.attend(
1620
- q, k, v,
1621
- mask = final_attn_mask,
1622
- attn_bias = attn_bias,
1623
- prev_attn = prev_attn
1624
- )
1625
-
1626
- # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1627
-
1628
- if exists(r):
1629
- out = out * r + out
1630
-
1631
- # normformer scaling of heads
1632
-
1633
- if head_scale:
1634
- out = out * self.head_scale_params
1635
-
1636
- # per head gating, from https://arxiv.org/abs/2306.12929
1637
-
1638
- if exists(self.to_v_head_gate):
1639
- head_gate = self.to_v_head_gate(x)
1640
- out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
1641
-
1642
- # merge heads
1643
-
1644
- out = rearrange(out, 'b h n d -> b n (h d)')
1645
-
1646
- # alphafold2 styled gating of the values
1647
-
1648
- if exists(self.to_v_gate):
1649
- gates = self.to_v_gate(x)
1650
- out = out * gates.sigmoid()
1651
-
1652
- # combine the heads
1653
-
1654
- out = self.to_out(out)
1655
-
1656
- if exists(mask):
1657
- mask = rearrange(mask, 'b n -> b n 1')
1658
- out = out.masked_fill(~mask, 0.)
1659
-
1660
- if not return_intermediates:
1661
- return out
1662
-
1663
- intermediates.cached_kv = cached_kv
1664
-
1665
- return out, intermediates
1666
-
1667
- class AttentionLayers(nn.Module):
1668
- def __init__(
1669
- self,
1670
- dim,
1671
- depth,
1672
- heads = 8,
1673
- causal = False,
1674
- cross_attend = False,
1675
- only_cross = False,
1676
- use_scalenorm = False,
1677
- use_rmsnorm = False,
1678
- use_simple_rmsnorm = False,
1679
- alibi_pos_bias = False,
1680
- alibi_num_heads = None,
1681
- rel_pos_bias = False,
1682
- rel_pos_num_buckets = 32,
1683
- rel_pos_max_distance = 128,
1684
- dynamic_pos_bias = False,
1685
- dynamic_pos_bias_log_distance = False,
1686
- dynamic_pos_bias_mlp_depth = 2,
1687
- dynamic_pos_bias_norm = False,
1688
- rotary_pos_emb = False,
1689
- rotary_emb_dim = None,
1690
- rotary_xpos = False,
1691
- rotary_interpolation_factor = 1.,
1692
- rotary_xpos_scale_base = 512,
1693
- rotary_base_rescale_factor = 1.,
1694
- custom_layers = None,
1695
- sandwich_coef = None,
1696
- par_ratio = None,
1697
- weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1698
- layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1699
- residual_attn = False,
1700
- cross_residual_attn = False,
1701
- macaron = False,
1702
- pre_norm = True,
1703
- pre_norm_has_final_norm = True,
1704
- gate_residual = False,
1705
- scale_residual = False,
1706
- scale_residual_constant = 1.,
1707
- shift_tokens = 0,
1708
- sandwich_norm = False,
1709
- resi_dual = False,
1710
- resi_dual_scale = 1.,
1711
- zero_init_branch_output = False,
1712
- layer_dropout = 0.,
1713
- cross_attn_tokens_dropout = 0.,
1714
- **kwargs
1715
- ):
1716
- super().__init__()
1717
- rotary_pos_emb = rotary_pos_emb or rotary_xpos
1718
-
1719
- ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1720
- attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1721
-
1722
- dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1723
-
1724
- self.dim = dim
1725
- self.depth = depth
1726
- self.causal = causal
1727
- self.layers = nn.ModuleList([])
1728
-
1729
- self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1730
-
1731
- rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1732
-
1733
- assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1734
- self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1735
-
1736
- assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1737
- assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
1738
-
1739
- # relative positional bias
1740
-
1741
- flash_attn = attn_kwargs.get('flash', False)
1742
- assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1743
-
1744
- self.rel_pos = None
1745
- if rel_pos_bias:
1746
- assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1747
- self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1748
- elif dynamic_pos_bias:
1749
- assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1750
- self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1751
- elif alibi_pos_bias:
1752
- alibi_num_heads = default(alibi_num_heads, heads)
1753
- assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than or equal to the total number of heads'
1754
- self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1755
-
1756
- assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1757
- assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1758
-
1759
- if resi_dual:
1760
- pre_norm = False
1761
-
1762
- self.pre_norm = pre_norm
1763
- self.sandwich_norm = sandwich_norm
1764
-
1765
- self.resi_dual = resi_dual
1766
- assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1767
- self.resi_dual_scale = resi_dual_scale
1768
-
1769
- self.residual_attn = residual_attn
1770
- self.cross_residual_attn = cross_residual_attn
1771
- assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1772
-
1773
- self.cross_attend = cross_attend
1774
-
1775
- assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1776
-
1777
- if use_scalenorm:
1778
- norm_class = ScaleNorm
1779
- elif use_rmsnorm:
1780
- norm_class = RMSNorm
1781
- elif use_simple_rmsnorm:
1782
- norm_class = SimpleRMSNorm
1783
- else:
1784
- norm_class = nn.LayerNorm
1785
-
1786
- norm_fn = partial(norm_class, dim)
1787
-
1788
- if cross_attend and not only_cross:
1789
- default_block = ('a', 'c', 'f')
1790
- elif cross_attend and only_cross:
1791
- default_block = ('c', 'f')
1792
- else:
1793
- default_block = ('a', 'f')
1794
-
1795
- if macaron:
1796
- default_block = ('f',) + default_block
1797
-
1798
- # zero init
1799
-
1800
- if zero_init_branch_output:
1801
- attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1802
- ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1803
-
1804
- # setup weight tying, which is a special case of `layer_execute_order`
1805
-
1806
- assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))])), 'weight tied layers cannot be combined with custom_layers, par_ratio or sandwich_coef'
1807
-
1808
- if weight_tie_layers:
1809
- assert not exists(layers_execute_order)
1810
- layers_execute_order = tuple(range(len(default_block))) * depth
1811
- depth = 1
1812
-
1813
- # calculate layer block order
1814
-
1815
- if exists(custom_layers):
1816
- layer_types = custom_layers
1817
- elif exists(par_ratio):
1818
- par_depth = depth * len(default_block)
1819
- assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1820
- default_block = tuple(filter(not_equals('f'), default_block))
1821
- par_attn = par_depth // par_ratio
1822
- depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1823
- par_width = (depth_cut + depth_cut // par_attn) // par_attn
1824
- assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1825
- par_block = default_block + ('f',) * (par_width - len(default_block))
1826
- par_head = par_block * par_attn
1827
- layer_types = par_head + ('f',) * (par_depth - len(par_head))
1828
- elif exists(sandwich_coef):
1829
- assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be greater than 0 and less than or equal to the depth'
1830
- layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1831
- else:
1832
- layer_types = default_block * depth
1833
-
1834
- self.layer_types = layer_types
1835
- self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1836
-
1837
- assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1838
-
1839
- self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1840
-
1841
- # stochastic depth
1842
-
1843
- self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1844
-
1845
- # structured dropout for cross attending
1846
-
1847
- self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1848
-
1849
- # calculate token shifting
1850
-
1851
- shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1852
-
1853
- # whether it has post norm
1854
-
1855
- self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1856
-
1857
- # iterate and construct layers
1858
-
1859
- for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1860
- is_last_layer = ind == (len(self.layer_types) - 1)
1861
-
1862
- if layer_type == 'a':
1863
- layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1864
- elif layer_type == 'c':
1865
- layer = Attention(dim, heads = heads, **attn_kwargs)
1866
- elif layer_type == 'f':
1867
- layer = FeedForward(dim, **ff_kwargs)
1868
- layer = layer if not macaron else Scale(0.5, layer)
1869
- else:
1870
- raise ValueError(f'invalid layer type {layer_type}')
1871
-
1872
- if layer_shift_tokens > 0:
1873
- shift_range_upper = layer_shift_tokens + 1
1874
- shift_range_lower = -layer_shift_tokens if not causal else 0
1875
- layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1876
-
1877
- residual_fn = GRUGating if gate_residual else Residual
1878
- residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1879
-
1880
- pre_branch_norm = norm_fn() if pre_norm else None
1881
- post_branch_norm = norm_fn() if sandwich_norm else None
1882
- post_main_norm = norm_fn() if not pre_norm else None
1883
-
1884
- norms = nn.ModuleList([
1885
- pre_branch_norm,
1886
- post_branch_norm,
1887
- post_main_norm
1888
- ])
1889
-
1890
- self.layers.append(nn.ModuleList([
1891
- norms,
1892
- layer,
1893
- residual
1894
- ]))
1895
-
1896
- def forward(
1897
- self,
1898
- x,
1899
- context = None,
1900
- mask = None,
1901
- context_mask = None,
1902
- attn_mask = None,
1903
- self_attn_kv_mask = None,
1904
- mems = None,
1905
- seq_start_pos: Optional[Tensor] = None,
1906
- cache: Optional[LayerIntermediates] = None,
1907
- cache_age = 1,
1908
- return_hiddens = False
1909
- ):
1910
- assert not (self.cross_attend ^ exists(context)), 'context must be passed in if and only if cross_attend is set to True'
1911
-
1912
- # initialize accums
1913
-
1914
- hiddens = []
1915
- layer_hiddens = []
1916
- intermediates = []
1917
-
1918
- prev_attn = None
1919
- prev_cross_attn = None
1920
-
1921
- mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1922
-
1923
- # handle left padded sequences
1924
-
1925
- if exists(seq_start_pos):
1926
- seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
1927
- left_pad_mask = seq_arange >= seq_start_pos[..., None]
1928
-
1929
- if exists(self_attn_kv_mask):
1930
- self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
1931
- else:
1932
- self_attn_kv_mask = left_pad_mask
1933
-
1934
- # rotary positions
1935
-
1936
- rotary_pos_emb = None
1937
-
1938
- if exists(self.rotary_pos_emb):
1939
- max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1940
- rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
1941
-
1942
- # assume cached key / values
1943
-
1944
- attn_cache = []
1945
-
1946
- if exists(cache):
1947
- assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
1948
-
1949
- if cache_age > 0:
1950
- x = x[:, -cache_age:] # for spec decoding, may be greater than 1
1951
-
1952
- attn_cache = cache.attn_intermediates
1953
-
1954
- iter_attn_cache = iter(attn_cache)
1955
-
1956
- # outer residual - for resiDual paper
1957
-
1958
- outer_residual = x * self.resi_dual_scale
1959
-
1960
- # get layers to be executed
1961
-
1962
- layer_variables = (
1963
- self.layer_types,
1964
- self.layers,
1965
- self.layer_dropouts
1966
- )
1967
-
1968
- layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
1969
-
1970
- # go through the attention and feedforward layers
1971
-
1972
- for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
1973
- is_last = ind == (len(self.layers) - 1)
1974
-
1975
- if self.training and layer_dropout > 0. and random() < layer_dropout:
1976
- continue
1977
-
1978
- if layer_type == 'a':
1979
- if return_hiddens:
1980
- hiddens.append(x)
1981
- layer_mem = mems.pop(0) if mems else None
1982
-
1983
- if layer_type == 'c':
1984
- if self.training and self.cross_attn_tokens_dropout > 0.:
1985
- context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1986
-
1987
- inner_residual = x
1988
-
1989
- if return_hiddens:
1990
- layer_hiddens.append(x)
1991
-
1992
- pre_norm, post_branch_norm, post_main_norm = norm
1993
-
1994
- if exists(pre_norm):
1995
- x = pre_norm(x)
1996
-
1997
- if layer_type == 'a':
1998
- out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
1999
- elif layer_type == 'c':
2000
- out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
2001
- elif layer_type == 'f':
2002
- out = block(x)
2003
-
2004
- if self.resi_dual:
2005
- outer_residual = outer_residual + out * self.resi_dual_scale
2006
-
2007
- if exists(post_branch_norm):
2008
- out = post_branch_norm(out)
2009
-
2010
- x = residual_fn(out, inner_residual)
2011
-
2012
- if layer_type in ('a', 'c') and return_hiddens:
2013
- intermediates.append(inter)
2014
-
2015
- if layer_type == 'a' and self.residual_attn:
2016
- prev_attn = inter.pre_softmax_attn
2017
- elif layer_type == 'c' and self.cross_residual_attn:
2018
- prev_cross_attn = inter.pre_softmax_attn
2019
-
2020
- if exists(post_main_norm):
2021
- x = post_main_norm(x)
2022
-
2023
- if return_hiddens:
2024
- layer_hiddens.append(x)
2025
-
2026
- if self.resi_dual:
2027
- x = x + self.final_norm(outer_residual)
2028
- else:
2029
- x = self.final_norm(x)
2030
-
2031
- if not return_hiddens:
2032
- return x
2033
-
2034
- intermediates = LayerIntermediates(
2035
- hiddens = hiddens,
2036
- attn_intermediates = intermediates,
2037
- layer_hiddens = layer_hiddens
2038
- )
2039
-
2040
- return x, intermediates
2041
-
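To make the block-ordering logic in `AttentionLayers.__init__` concrete, here is a small pure-Python sketch of the default case (self attention only, no macaron / par / sandwich options) and of the weight-tied case; the depth value is an arbitrary illustration.

default_block = ('a', 'f')                # self attention followed by feedforward
depth = 3

# without weight tying, the block is simply repeated `depth` times
layer_types = default_block * depth
print(layer_types)                        # ('a', 'f', 'a', 'f', 'a', 'f')

# with weight_tie_layers = True, depth collapses to 1 and the same block
# is re-executed according to layers_execute_order
layers_execute_order = tuple(range(len(default_block))) * depth
print(layers_execute_order)               # (0, 1, 0, 1, 0, 1)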
2042
- class Encoder(AttentionLayers):
2043
- def __init__(self, **kwargs):
2044
- assert 'causal' not in kwargs, 'cannot set causality on encoder'
2045
- super().__init__(causal = False, **kwargs)
2046
-
2047
- class Decoder(AttentionLayers):
2048
- def __init__(self, **kwargs):
2049
- assert 'causal' not in kwargs, 'cannot set causality on decoder'
2050
- super().__init__(causal = True, **kwargs)
2051
-
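`Encoder` and `Decoder` are thin wrappers that only fix the `causal` flag. A minimal usage sketch, assuming the classes defined in this file are importable (the module name below is simply this file's name) and with illustrative sizes:

import torch
from x_transformer_1_23_2 import Decoder

layers = Decoder(dim = 256, depth = 2, heads = 4, rotary_pos_emb = True)   # causal self-attention stack

x = torch.randn(1, 64, 256)               # (batch, seq, dim) embeddings
out = layers(x)                           # (1, 64, 256)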
2052
- class CrossAttender(AttentionLayers):
2053
- def __init__(self, **kwargs):
2054
- super().__init__(cross_attend = True, only_cross = True, **kwargs)
2055
-
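`CrossAttender` keeps only cross-attention plus feedforward blocks, so a `context` tensor must be supplied on every forward call. A brief sketch under the same import assumption:

import torch
from x_transformer_1_23_2 import CrossAttender

cross = CrossAttender(dim = 256, depth = 2, heads = 4)

queries = torch.randn(1, 16, 256)         # tokens doing the attending
context = torch.randn(1, 64, 256)         # tokens being attended to
out = cross(queries, context = context)   # (1, 16, 256)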
2056
- class ViTransformerWrapper(nn.Module):
2057
- def __init__(
2058
- self,
2059
- *,
2060
- image_size,
2061
- patch_size,
2062
- attn_layers,
2063
- channels = 3,
2064
- num_classes = None,
2065
- post_emb_norm = False,
2066
- num_register_tokens = 0,
2067
- emb_dropout = 0.
2068
- ):
2069
- super().__init__()
2070
- assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
2071
- assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
2072
- dim = attn_layers.dim
2073
- num_patches = (image_size // patch_size) ** 2
2074
- patch_dim = channels * patch_size ** 2
2075
-
2076
- self.patch_size = patch_size
2077
-
2078
- self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
2079
-
2080
- has_register_tokens = num_register_tokens > 0
2081
- self.has_register_tokens = has_register_tokens
2082
-
2083
- if has_register_tokens:
2084
- self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
2085
-
2086
- self.patch_to_embedding = nn.Sequential(
2087
- nn.LayerNorm(patch_dim),
2088
- nn.Linear(patch_dim, dim),
2089
- nn.LayerNorm(dim)
2090
- )
2091
-
2092
- self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2093
- self.dropout = nn.Dropout(emb_dropout)
2094
-
2095
- self.attn_layers = attn_layers
2096
-
2097
- self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
2098
-
2099
- def forward(
2100
- self,
2101
- img,
2102
- return_embeddings = False
2103
- ):
2104
- b, p = img.shape[0], self.patch_size
2105
-
2106
- x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
2107
- x = self.patch_to_embedding(x)
2108
- n = x.shape[1]
2109
-
2110
- x = x + self.pos_embedding[:, :n]
2111
-
2112
- x = self.post_emb_norm(x)
2113
- x = self.dropout(x)
2114
-
2115
- if self.has_register_tokens:
2116
- r = repeat(self.register_tokens, 'n d -> b n d', b = b)
2117
- x, ps = pack((x, r), 'b * d')
2118
-
2119
- x = self.attn_layers(x)
2120
-
2121
- if self.has_register_tokens:
2122
- x, _ = unpack(x, ps, 'b * d')
2123
-
2124
- if not exists(self.mlp_head) or return_embeddings:
2125
- return x
2126
-
2127
- x = x.mean(dim = -2)
2128
- return self.mlp_head(x)
2129
-
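A short sketch of `ViTransformerWrapper` used as an image classifier, under the same import assumption; image size, patch size and class count are illustrative:

import torch
from x_transformer_1_23_2 import Encoder, ViTransformerWrapper

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 10,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

img = torch.randn(1, 3, 256, 256)         # (batch, channels, height, width)
logits = vit(img)                         # (1, 10) after mean pooling and the mlp head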
2130
- class TransformerWrapper(nn.Module):
2131
- def __init__(
2132
- self,
2133
- *,
2134
- num_tokens,
2135
- max_seq_len,
2136
- attn_layers,
2137
- emb_dim = None,
2138
- max_mem_len = 0,
2139
- shift_mem_down = 0,
2140
- emb_dropout = 0.,
2141
- post_emb_norm = False,
2142
- num_memory_tokens = None,
2143
- memory_tokens_interspersed_every = None,
2144
- tie_embedding = False,
2145
- logits_dim = None,
2146
- use_abs_pos_emb = True,
2147
- scaled_sinu_pos_emb = False,
2148
- l2norm_embed = False,
2149
- emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
2150
- attn_z_loss_weight = 1e-4,
2151
- ):
2152
- super().__init__()
2153
- assert isinstance(attn_layers, AttentionLayers), 'attention layers must be an instance of AttentionLayers (e.g. Encoder or Decoder)'
2154
-
2155
- dim = attn_layers.dim
2156
- emb_dim = default(emb_dim, dim)
2157
- self.emb_dim = emb_dim
2158
- self.num_tokens = num_tokens
2159
-
2160
- self.max_seq_len = max_seq_len
2161
- self.max_mem_len = max_mem_len
2162
- self.shift_mem_down = shift_mem_down
2163
-
2164
- self.l2norm_embed = l2norm_embed
2165
- self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
2166
-
2167
- if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2168
- self.pos_emb = always(0)
2169
- elif scaled_sinu_pos_emb:
2170
- self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
2171
- else:
2172
- self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
2173
-
2174
- self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
2175
-
2176
- self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
2177
- self.emb_dropout = nn.Dropout(emb_dropout)
2178
-
2179
- self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
2180
- self.attn_layers = attn_layers
2181
-
2182
- self.init_()
2183
-
2184
- logits_dim = default(logits_dim, num_tokens)
2185
- self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
2186
-
2187
- # memory tokens (like [cls]) from Memory Transformers paper
2188
-
2189
- num_memory_tokens = default(num_memory_tokens, 0)
2190
- self.num_memory_tokens = num_memory_tokens
2191
- if num_memory_tokens > 0:
2192
- self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
2193
-
2194
- self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
2195
-
2196
- # whether can do cached kv decoding
2197
-
2198
- self.can_cache_kv = self.num_memory_tokens == 0
2199
-
2200
- def init_(self):
2201
- if self.l2norm_embed:
2202
- nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
2203
- if not isinstance(self.pos_emb, always):
2204
- nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
2205
- return
2206
-
2207
- nn.init.kaiming_normal_(self.token_emb.emb.weight)
2208
-
2209
- def forward(
2210
- self,
2211
- x,
2212
- return_embeddings = False,
2213
- return_logits_and_embeddings = False,
2214
- return_intermediates = False,
2215
- mask = None,
2216
- return_mems = False,
2217
- return_attn = False,
2218
- mems = None,
2219
- pos = None,
2220
- prepend_embeds = None,
2221
- sum_embeds = None,
2222
- return_attn_z_loss = False,
2223
- attn_z_loss_weight = 1e-4,
2224
- seq_start_pos = None,
2225
- cache: Optional[LayerIntermediates] = None,
2226
- **kwargs
2227
- ):
2228
- b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
2229
- return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2230
-
2231
- # absolute positional embedding
2232
-
2233
- external_pos_emb = exists(pos) and pos.dtype != torch.long
2234
- pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
2235
- x = self.token_emb(x) + pos_emb
2236
-
2237
- # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
2238
-
2239
- if exists(sum_embeds):
2240
- x = x + sum_embeds
2241
-
2242
- # post embedding norm, purportedly leads to greater stabilization
2243
-
2244
- x = self.post_emb_norm(x)
2245
-
2246
- # whether to prepend embeds, as in PaLI, for image embeddings
2247
-
2248
- if exists(prepend_embeds):
2249
- prepend_seq, prepend_dim = prepend_embeds.shape[1:]
2250
- assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
2251
-
2252
- x = torch.cat((prepend_embeds, x), dim = -2)
2253
-
2254
- # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
2255
-
2256
- if emb_frac_gradient < 1:
2257
- assert emb_frac_gradient > 0
2258
- x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
2259
-
2260
- # embedding dropout
2261
-
2262
- x = self.emb_dropout(x)
2263
-
2264
- x = self.project_emb(x)
2265
-
2266
- if has_memory_tokens:
2267
- mem_every = self.memory_tokens_interspersed_every
2268
-
2269
- if exists(mem_every):
2270
- assert mem_every > 0
2271
- assert isinstance(self.attn_layers, Decoder), 'interspersing memory tokens is only supported for the Decoder'
2272
- next_seq_len = math.ceil(n / mem_every) * mem_every
2273
-
2274
- x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
2275
- x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
2276
-
2277
- mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
2278
- x, mem_packed_shape = pack((mem, x), 'b * d')
2279
-
2280
- # auto-handle masking after appending memory tokens
2281
- if not exists(mem_every) and exists(mask):
2282
- mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
2283
-
2284
- if exists(mem_every):
2285
- x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2286
-
2287
- if self.shift_mem_down and exists(mems):
2288
- mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
2289
- mems = [*mems_r, *mems_l]
2290
-
2291
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
2292
-
2293
- if has_memory_tokens:
2294
- if exists(mem_every):
2295
- x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
2296
-
2297
- mem, x = unpack(x, mem_packed_shape, 'b * d')
2298
-
2299
- if exists(mem_every):
2300
- x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2301
-
2302
- x = x[:, :n]
2303
-
2304
- if return_logits_and_embeddings:
2305
- out = (self.to_logits(x), x)
2306
- elif return_embeddings:
2307
- out = x
2308
- else:
2309
- out = self.to_logits(x)
2310
-
2311
- if return_attn_z_loss:
2312
- pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
2313
- intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
2314
- return_intermediates = True
2315
-
2316
- if return_mems:
2317
- hiddens = intermediates.hiddens
2318
- new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
2319
- new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
2320
-
2321
- if not return_intermediates:
2322
- return out, new_mems
2323
-
2324
- intermediates.mems = new_mems
2325
-
2326
- if return_intermediates:
2327
- return out, intermediates
2328
-
2329
- if return_attn:
2330
- attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2331
- return out, attn_maps
2332
-
2333
- return out
2334
-
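A minimal language-model style sketch of `TransformerWrapper` around a causal `Decoder`, under the same import assumption; vocabulary and sequence sizes are illustrative:

import torch
from x_transformer_1_23_2 import Decoder, TransformerWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)                              # (1, 1024, 20000)
embeds = model(tokens, return_embeddings = True)    # (1, 1024, 512) hidden states instead of logits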
2335
- class ContinuousTransformerWrapper(nn.Module):
2336
- def __init__(
2337
- self,
2338
- *,
2339
- max_seq_len,
2340
- attn_layers,
2341
- dim_in = None,
2342
- dim_out = None,
2343
- emb_dim = None,
2344
- max_mem_len = 0,
2345
- post_emb_norm = False,
2346
- emb_dropout = 0.,
2347
- use_abs_pos_emb = True,
2348
- scaled_sinu_pos_emb = False
2349
- ):
2350
- super().__init__()
2351
- assert isinstance(attn_layers, AttentionLayers), 'attention layers must be an instance of AttentionLayers (e.g. Encoder or Decoder)'
2352
-
2353
- dim = attn_layers.dim
2354
-
2355
- self.max_seq_len = max_seq_len
2356
-
2357
- self.max_mem_len = max_mem_len
2358
-
2359
- if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2360
- self.pos_emb = always(0)
2361
- elif scaled_sinu_pos_emb:
2362
- self.pos_emb = ScaledSinusoidalEmbedding(dim)
2363
- else:
2364
- self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
2365
-
2366
- self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2367
- self.emb_dropout = nn.Dropout(emb_dropout)
2368
-
2369
- self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
2370
-
2371
- self.attn_layers = attn_layers
2372
-
2373
- self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
2374
-
2375
- def forward(
2376
- self,
2377
- x,
2378
- return_embeddings = False,
2379
- return_intermediates = False,
2380
- return_mems = False,
2381
- mask = None,
2382
- return_attn = False,
2383
- mems = None,
2384
- pos = None,
2385
- prepend_embeds = None,
2386
- **kwargs
2387
- ):
2388
- x = self.project_in(x)
2389
- x = x + self.pos_emb(x, pos = pos)
2390
-
2391
- x = self.post_emb_norm(x)
2392
-
2393
- # whether to prepend embeds, as in PaLI, for image embeddings
2394
-
2395
- if exists(prepend_embeds):
2396
- _, prepend_dim = prepend_embeds.shape[1:]
2397
- assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
2398
-
2399
- x = torch.cat((prepend_embeds, x), dim = -2)
2400
-
2401
- x = self.emb_dropout(x)
2402
-
2403
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
2404
-
2405
- out = self.project_out(x) if not return_embeddings else x
2406
-
2407
- if return_intermediates:
2408
- return out, intermediates
2409
-
2410
- if return_mems:
2411
- hiddens = intermediates.hiddens
2412
- new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
2413
- return out, new_mems
2414
-
2415
- if return_attn:
2416
- attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2417
- return out, attn_maps
2418
-
2419
- return out
2420
-
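`ContinuousTransformerWrapper` skips the token embedding and maps real-valued features in and out through `project_in` / `project_out`. A brief sketch with illustrative feature dimensions, under the same import assumption:

import torch
from x_transformer_1_23_2 import Encoder, ContinuousTransformerWrapper

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)               # (batch, seq, dim_in) continuous features
mask = torch.ones(1, 256).bool()
out = model(x, mask = mask)               # (1, 256, 100)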
2421
- class XTransformer(nn.Module):
2422
- def __init__(
2423
- self,
2424
- *,
2425
- dim,
2426
- tie_token_emb = False,
2427
- ignore_index = -100,
2428
- pad_value = 0,
2429
- cross_attn_tokens_dropout = 0.,
2430
- **kwargs
2431
- ):
2432
- super().__init__()
2433
- enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
2434
- dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
2435
-
2436
- assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimensions of encoder and decoder must be set with the shared `dim` keyword'
2437
- enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
2438
- enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
2439
- enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
2440
- enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
2441
- enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
2442
-
2443
- dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
2444
- dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
2445
- dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
2446
- dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
2447
-
2448
- self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # fraction of encoder tokens to drop out when the decoder cross attends; seen in a couple of papers, including Perceiver AR, and also an effective regularization when cross attending to very long memories
2449
-
2450
- self.encoder = TransformerWrapper(
2451
- **enc_transformer_kwargs,
2452
- attn_layers = Encoder(dim = dim, **enc_kwargs)
2453
- )
2454
-
2455
- self.decoder = TransformerWrapper(
2456
- **dec_transformer_kwargs,
2457
- attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
2458
- )
2459
-
2460
- if tie_token_emb:
2461
- self.decoder.token_emb = self.encoder.token_emb
2462
-
2463
- self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
2464
-
2465
- @torch.no_grad()
2466
- def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
2467
- encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
2468
- return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
2469
-
2470
- def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
2471
-
2472
- if exists(src_prepend_embeds) and exists(mask):
2473
- mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
2474
-
2475
- enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
2476
-
2477
- if self.training and self.cross_attn_tokens_dropout > 0:
2478
- enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
2479
-
2480
- out = self.decoder(tgt, context = enc, context_mask = mask)
2481
- return out
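Finally, a short encoder/decoder sketch for `XTransformer`: keyword arguments prefixed with `enc_` / `dec_` are routed to the encoder and decoder by `groupby_prefix_and_trim`, and because the decoder is wrapped in `AutoregressiveWrapper`, the forward pass returns a training loss. The sizes below are illustrative and the import assumes this module is on the path.

import torch
from x_transformer_1_23_2 import XTransformer

model = XTransformer(
    dim = 512,
    tie_token_emb = True,
    enc_num_tokens = 256,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))

loss = model(src, tgt, mask = src_mask)   # scalar loss from the autoregressive wrapper
loss.backward()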