asigalov61 committed
Commit b43c9c0 · verified · 1 Parent(s): d236a0b

Upload 7 files

Files changed (7)
  1. README.md +17 -6
  2. TMIDIX.py +0 -0
  3. app.py +499 -0
  4. midi_to_colab_audio.py +0 -0
  5. packages.txt +1 -0
  6. requirements.txt +5 -0
  7. x_transformer_1_23_2.py +2481 -0
README.md CHANGED
@@ -1,14 +1,25 @@
1
  ---
2
  title: MIDI Loops Mixer
3
- emoji: 📉
4
- colorFrom: purple
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 5.13.2
8
  app_file: app.py
9
- pinned: false
10
  license: apache-2.0
11
- short_description: Mix random MIDI loops into one coherent MIDI composition
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: MIDI Loops Mixer
3
+ emoji: 🎨
4
+ colorFrom: green
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.13.2
8
  app_file: app.py
9
+ pinned: true
10
  license: apache-2.0
11
+ short_description: Mix random MIDI loops into one coherent composition
12
+ tags:
13
+ - music
14
+ - music ai
15
+ - music transformer
16
+ - MIDI
17
+ - guided
18
+ - accompaniment
19
+ - accompaniment generation
20
+ - accompaniment transformer
21
+ thumbnail: >-
22
+ https://cdn-uploads.huggingface.co/production/uploads/5f57ea2d3f32f12a3c0692e6/RvzjdORKBps7rkYYfsbVo.jpeg
23
  ---
24
 
25
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
TMIDIX.py ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,499 @@
1
+ #==================================================================================
2
+ # https://huggingface.co/spaces/asigalov61/MIDI-Loops-Mixer
3
+ #==================================================================================
4
+
5
+ print('=' * 70)
6
+ print('MIDI Loops Mixer Gradio App')
7
+
8
+ print('=' * 70)
9
+ print('Loading core MIDI Loops Mixer modules...')
10
+
11
+ import os
12
+ import copy
13
+
14
+ import time as reqtime
15
+ import datetime
16
+ from pytz import timezone
17
+
18
+ print('=' * 70)
19
+ print('Loading main MIDI Loops Mixer modules...')
20
+
21
+ os.environ['USE_FLASH_ATTENTION'] = '1'
22
+
23
+ import torch
24
+
25
+ torch.set_float32_matmul_precision('high')
26
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
27
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
28
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
29
+ torch.backends.cuda.enable_math_sdp(True)
30
+ torch.backends.cuda.enable_flash_sdp(True)
31
+ torch.backends.cuda.enable_cudnn_sdp(True)
32
+
33
+ from huggingface_hub import hf_hub_download
34
+
35
+ import TMIDIX
36
+
37
+ from midi_to_colab_audio import midi_to_colab_audio
38
+
39
+ from x_transformer_1_23_2 import *
40
+
41
+ import random
42
+
43
+ import tqdm
44
+
45
+ print('=' * 70)
46
+ print('Loading aux MIDI Loops Mixer modules...')
47
+
48
+ import matplotlib.pyplot as plt
49
+
50
+ import gradio as gr
51
+ import spaces
52
+
53
+ print('=' * 70)
54
+ print('PyTorch version:', torch.__version__)
55
+ print('=' * 70)
56
+ print('Done!')
57
+ print('Enjoy! :)')
58
+ print('=' * 70)
59
+
60
+ #==================================================================================
61
+
62
+ MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
63
+
64
+ SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
65
+
66
+ #==================================================================================
67
+
68
+ print('=' * 70)
69
+ print('Instantiating model...')
70
+
71
+ device_type = 'cuda'
72
+ dtype = 'bfloat16'
73
+
74
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
75
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
76
+
77
+ SEQ_LEN = 4096
78
+ PAD_IDX = 1794
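+ # Vocabulary layout: tokens 0-1791 encode music events, 1792/1793 bracket the seed melody (start/end), and 1794 is the pad index (hence num_tokens = PAD_IDX + 1)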
79
+
80
+ model = TransformerWrapper(
81
+ num_tokens = PAD_IDX+1,
82
+ max_seq_len = SEQ_LEN,
83
+ attn_layers = Decoder(dim = 2048,
84
+ depth = 4,
85
+ heads = 32,
86
+ rotary_pos_emb = True,
87
+ attn_flash = True
88
+ )
89
+ )
90
+
91
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
92
+
93
+ print('=' * 70)
94
+ print('Loading model checkpoint...')
95
+
96
+ model_checkpoint = hf_hub_download(repo_id='asigalov61/MIDI-Loops-Mixer', filename=MODEL_CHECKPOINT)
97
+
98
+ model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
99
+
100
+ model = torch.compile(model, mode='max-autotune')
101
+
102
+ print('=' * 70)
103
+ print('Done!')
104
+ print('=' * 70)
105
+ print('Model will use', dtype, 'precision...')
106
+ print('=' * 70)
107
+
108
+ #==================================================================================
109
+
110
+ def load_midi(input_midi, melody_patch=-1):
111
+
112
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
113
+
114
+ escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
115
+ escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
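+ # Timings are divided by 32 here (roughly a 32 ms grid); the rendering code in Generate_Accompaniment multiplies times and durations by 32 to restore milliseconds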
116
+
117
+ sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)
118
+
119
+ if melody_patch == -1:
120
+ zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
121
+
122
+ else:
123
+ mel_score = [e for e in sp_escore_notes if e[6] == melody_patch]
124
+
125
+ if mel_score:
126
+ zscore = TMIDIX.recalculate_score_timings(mel_score)
127
+
128
+ else:
129
+ zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
130
+
131
+ cscore = TMIDIX.chordify_score([1000, zscore])
132
+
133
+ score = []
134
+
135
+ score_list = []
136
+
137
+ pc = cscore[0]
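+ # Seed melody tokens: 0-127 = delta start time, 129-255 = duration + 128, 257-383 = pitch + 256 (first note of each chord only)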
138
+
139
+ for c in cscore:
140
+ score.append(max(0, min(127, c[0][1]-pc[0][1])))
141
+
142
+ scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]
143
+
144
+ n = c[0]
145
+
146
+ score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
147
+ scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
148
+
149
+ score_list.append(scl)
150
+
151
+ pc = c
152
+
153
+ score_list.append(scl)
154
+
155
+ return score, score_list
156
+
157
+ #==================================================================================
158
+
159
+ @spaces.GPU
160
+ def Generate_Accompaniment(input_midi,
161
+ generation_type,
162
+ melody_patch,
163
+ model_temperature
164
+ ):
165
+
166
+ #===============================================================================
167
+
168
+ def generate_full_seq(input_seq, temperature=0.9, verbose=True):
169
+
170
+ seq_abs_run_time = sum([t for t in input_seq if t < 128])
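+ # Freestyle mode: keep sampling tokens until the accumulated delta time covers the seed melody's total run time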
171
+
172
+ cur_time = 0
173
+
174
+ full_seq = copy.deepcopy(input_seq)
175
+
176
+ toks_counter = 0
177
+
178
+ while cur_time <= seq_abs_run_time:
179
+
180
+ if verbose:
181
+ if toks_counter % 128 == 0:
182
+ print('Generated', toks_counter, 'tokens')
183
+
184
+ x = torch.LongTensor(full_seq).cuda()
185
+
186
+ with ctx:
187
+ out = model.generate(x,
188
+ 1,
189
+ temperature=temperature,
190
+ return_prime=False,
191
+ verbose=False)
192
+
193
+ y = out.tolist()[0][0]
194
+
195
+ if y < 128:
196
+ cur_time += y
197
+
198
+ full_seq.append(y)
199
+
200
+ toks_counter += 1
201
+
202
+ return full_seq
203
+
204
+ #===============================================================================
205
+
206
+ def generate_block_seq(input_seq, trg_dtime, temperature=0.9):
207
+
208
+ inp_seq = copy.deepcopy(input_seq)
209
+
210
+ block_seq = []
211
+
212
+ cur_time = 0
213
+
214
+ while cur_time < trg_dtime:
215
+
216
+ x = torch.LongTensor(inp_seq).cuda()
217
+
218
+ with ctx:
219
+ out = model.generate(x,
220
+ 1,
221
+ temperature=temperature,
222
+ return_prime=False,
223
+ verbose=False)
224
+
225
+ y = out.tolist()[0][0]
226
+
227
+ if y < 128:
228
+ cur_time += y
229
+
230
+ inp_seq.append(y)
231
+ block_seq.append(y)
232
+
233
+ if cur_time != trg_dtime:
234
+ return []
235
+
236
+ else:
237
+ return block_seq
238
+
239
+ #===============================================================================
240
+
241
+ print('=' * 70)
242
+ print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
243
+ start_time = reqtime.time()
244
+ print('=' * 70)
245
+
246
+ fn = os.path.basename(input_midi)
247
+ fn1 = fn.split('.')[0]
248
+
249
+ print('=' * 70)
250
+ print('Requested settings:')
251
+ print('=' * 70)
252
+ print('Input MIDI file name:', fn)
253
+ print('Generation type:', generation_type)
254
+ print('Source melody patch:', melody_patch)
255
+ print('Model temperature:', model_temperature)
256
+
257
+ print('=' * 70)
258
+
259
+ #==================================================================
260
+
261
+ score, score_list = load_midi(input_midi.name)
262
+
263
+ print('Sample score events', score[:12])
264
+
265
+ #==================================================================
266
+
267
+ print('=' * 70)
268
+ print('Generating...')
269
+
270
+ model.to(device_type)
271
+ model.eval()
272
+
273
+ #==================================================================
274
+
275
+ start_score_seq = [1792] + score + [1793]
276
+
277
+ #==================================================================
278
+
279
+ if generation_type == 'Guided':
280
+
281
+ input_seq = []
282
+
283
+ input_seq.extend(start_score_seq)
284
+ input_seq.extend(score_list[0][0])
285
+
286
+ block_seq_lens = []
287
+
288
+ idx = 0
289
+
290
+ max_retries = 3
291
+ mrt = 0
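+ # Guided mode: for each melody step, sample an accompaniment block that must land exactly on the next onset; after max_retries failures, roll back the previous block and retry it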
292
+
293
+ while idx < len(score_list)-1:
294
+
295
+ if idx % 10 == 0:
296
+ print('Generating', idx, 'block')
297
+
298
+ input_seq.extend(score_list[idx][1])
299
+
300
+ block_seq = []
301
+
302
+ for _ in range(max_retries):
303
+
304
+ block_seq = generate_block_seq(input_seq, score_list[idx+1][0][0])
305
+
306
+ if block_seq:
307
+ break
308
+
309
+ if block_seq:
310
+ input_seq.extend(block_seq)
311
+ block_seq_lens.append(len(block_seq))
312
+ idx += 1
313
+ mrt = 0
314
+
315
+ else:
316
+
317
+ if block_seq_lens:
318
+ input_seq = input_seq[:-(block_seq_lens[-1]+2)]
319
+ block_seq_lens.pop()
320
+ idx -= 1
321
+ mrt += 1
322
+
323
+ else:
324
+ break
325
+
326
+ if mrt == max_retries:
327
+ break
328
+
329
+ else:
330
+ input_seq = generate_full_seq(start_score_seq, temperature=model_temperature)
331
+
332
+ final_song = input_seq[len(start_score_seq):]
333
+
334
+ print('=' * 70)
335
+ print('Done!')
336
+ print('=' * 70)
337
+
338
+ #===============================================================================
339
+
340
+ print('Rendering results...')
341
+
342
+ print('=' * 70)
343
+ print('Sample INTs', final_song[:15])
344
+ print('=' * 70)
345
+
346
+ song_f = []
347
+
348
+ if len(final_song) != 0:
349
+
350
+ time = 0
351
+ dur = 0
352
+ vel = 90
353
+ pitch = 0
354
+ channel = 0
355
+ patch = 0
356
+
357
+ channels_map = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 9, 12, 13, 14, 15]
358
+ patches_map = [40, 0, 10, 19, 24, 35, 40, 52, 56, 9, 65, 73, 0, 0, 0, 0]
359
+ velocities_map = [125, 80, 100, 80, 90, 100, 100, 80, 110, 110, 110, 110, 80, 80, 80, 80]
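+ # Decode tokens back into notes: m < 128 = delta time (x 32 ms), 128 < m < 256 = duration (m - 128) x 32 ms, 256 < m < 1792 = channel ((m - 256) // 128) and pitch ((m - 256) % 128)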
360
+
361
+ for m in final_song:
362
+
363
+ if 0 <= m < 128:
364
+ time += m * 32
365
+
366
+ elif 128 < m < 256:
367
+ dur = (m-128) * 32
368
+
369
+ elif 256 < m < 1792:
370
+ cha = (m-256) // 128
371
+ pitch = (m-256) % 128
372
+
373
+ channel = channels_map[cha]
374
+ patch = patches_map[channel]
375
+ vel = velocities_map[channel]
376
+
377
+ song_f.append(['note', time, dur, channel, pitch, vel, patch])
378
+
379
+ fn1 = "MIDI-Loops-Mixer-Composition"
380
+
381
+ detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
382
+ output_signature = 'MIDI Loops Mixer',
383
+ output_file_name = fn1,
384
+ track_name='Project Los Angeles',
385
+ list_of_MIDI_patches=patches_map
386
+ )
387
+
388
+ new_fn = fn1+'.mid'
389
+
390
+
391
+ audio = midi_to_colab_audio(new_fn,
392
+ soundfont_path=SOUNDFONT_PATH,
393
+ sample_rate=16000,
394
+ volume_scale=10,
395
+ output_for_gradio=True
396
+ )
397
+
398
+ print('Done!')
399
+ print('=' * 70)
400
+
401
+ #========================================================
402
+
403
+ output_midi = str(new_fn)
404
+ output_audio = (16000, audio)
405
+
406
+ output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)
407
+
408
+ print('Output MIDI file name:', output_midi)
409
+ print('=' * 70)
410
+
411
+ #========================================================
412
+
413
+ print('-' * 70)
414
+ print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
415
+ print('-' * 70)
416
+ print('Req execution time:', (reqtime.time() - start_time), 'sec')
417
+
418
+ return output_audio, output_plot, output_midi
419
+
420
+ #==================================================================================
421
+
422
+ PDT = timezone('US/Pacific')
423
+
424
+ print('=' * 70)
425
+ print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
426
+ print('=' * 70)
427
+
428
+ #==================================================================================
429
+
430
+ with gr.Blocks() as demo:
431
+
432
+ #==================================================================================
433
+
434
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Loops Mixer</h1>")
435
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Guided melody accompaniment generation with transformers</h1>")
436
+ gr.HTML("""
437
+ <p>
438
+ <a href="https://huggingface.co/spaces/asigalov61/MIDI-Loops-Mixer?duplicate=true">
439
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
440
+ </a>
441
+ </p>
442
+
443
+ for faster execution and endless generation!
444
+ """)
445
+
446
+ #==================================================================================
447
+
448
+ gr.Markdown("## Upload source melody MIDI or select an example MIDI below")
449
+
450
+ input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
451
+
452
+ gr.Markdown("## Generation options")
453
+
454
+ generation_type = gr.Radio(["Guided", "Freestyle"], value="Guided", label="Generation type")
455
+ melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
456
+ model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
457
+
458
+ generate_btn = gr.Button("Generate", variant="primary")
459
+
460
+ gr.Markdown("## Generation results")
461
+
462
+ output_audio = gr.Audio(label="MIDI audio", format="wav", elem_id="midi_audio")
463
+ output_plot = gr.Plot(label="MIDI score plot")
464
+ output_midi = gr.File(label="MIDI file", file_types=[".mid"])
465
+
466
+ generate_btn.click(Generate_Accompaniment,
467
+ [input_midi,
468
+ generation_type,
469
+ melody_patch,
470
+ model_temperature
471
+ ],
472
+ [output_audio,
473
+ output_plot,
474
+ output_midi
475
+ ]
476
+ )
477
+
478
+ gr.Examples(
479
+ [["USSR-National-Anthem-Seed-Melody.mid", "Guided", -1, 0.9],
480
+ ["Hotel-California-Seed-Melody.mid", "Guided", -1, 0.9],
481
+ ["Sparks-Fly-Seed-Melody.mid", "Guided", -1, 0.9]
482
+ ],
483
+ [input_midi,
484
+ generation_type,
485
+ melody_patch,
486
+ model_temperature
487
+ ],
488
+ [output_audio,
489
+ output_plot,
490
+ output_midi
491
+ ],
492
+ Generate_Accompaniment
493
+ )
494
+
495
+ #==================================================================================
496
+
497
+ demo.launch()
498
+
499
+ #==================================================================================
midi_to_colab_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
packages.txt ADDED
@@ -0,0 +1 @@
1
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ huggingface_hub
2
+ gradio
3
+ matplotlib
4
+ numpy
5
+ tqdm
x_transformer_1_23_2.py ADDED
@@ -0,0 +1,2481 @@
1
+ #===================================================================================================================
2
+ #
3
+ # X Transformer Module
4
+ #
5
+ # Partial x-transformers code with useful modifications
6
+ #
7
+ # Version 1.0
8
+ #
9
+ # Original source code courtesy of lucidrains
10
+ # https://github.com/lucidrains/x-transformers
11
+ #
12
+ # Original source code retrieved on 10/10/2023
13
+ #
14
+ # Project Los Angeles
15
+ # Tegridy Code 2023
16
+
17
+ #===================================================================================================================
18
+
19
+ # Critical dependencies
20
+ #
21
+ # !pip install torch
22
+ # !pip install einops
23
+
24
+ #===================================================================================================================
25
+
26
+ from functools import partial
27
+ from typing import Optional, Tuple
28
+
29
+ import os
30
+ os.environ['USE_FLASH_ATTENTION'] = '1'
31
+
32
+ import torch
33
+ from torch import nn, einsum, Tensor
34
+ import torch.nn.functional as F
35
+
36
+ # Flash attention
37
+ from torch.nn.attention import SDPBackend, sdpa_kernel
38
+ torch.backends.cuda.enable_flash_sdp(True)
39
+
40
+ from collections import namedtuple
41
+ from functools import wraps
42
+ from packaging import version
43
+ from dataclasses import dataclass
44
+
45
+ from einops import rearrange, repeat
46
+
47
+ # constants
48
+
49
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
50
+
51
+ @dataclass
52
+ class Intermediates:
53
+ qk_similarities: Optional[Tensor] = None
54
+ pre_softmax_attn: Optional[Tensor] = None
55
+ post_softmax_attn: Optional[Tensor] = None
56
+ cached_kv: Optional[Tuple[Tensor, Tensor]] = None
57
+
58
+ def to_tuple(self):
59
+ return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
60
+
61
+ # helpers
62
+
63
+ def exists(val):
64
+ return val is not None
65
+
66
+ def default(val, d):
67
+ return val if exists(val) else d
68
+
69
+ def compact(arr):
70
+ return [*filter(exists, arr)]
71
+
72
+ def once(fn):
73
+ called = False
74
+ @wraps(fn)
75
+ def inner(x):
76
+ nonlocal called
77
+ if called:
78
+ return
79
+ called = True
80
+ return fn(x)
81
+ return inner
82
+
83
+ print_once = once(print)
84
+
85
+ # functions for creating causal mask
86
+ # need a special one for onnx cpu (no support for .triu)
87
+
88
+ def create_causal_mask(i, j, device):
89
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
90
+
91
+ def onnx_create_causal_mask(i, j, device):
92
+ r = torch.arange(i, device = device)
93
+ causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
94
+ causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
95
+ return causal_mask
96
+
97
+ # main class
98
+
99
+ class Attend(nn.Module):
100
+ def __init__(
101
+ self,
102
+ *,
103
+ dropout = 0.,
104
+ causal = False,
105
+ heads = None,
106
+ talking_heads = False,
107
+ sparse_topk = None,
108
+ scale = None,
109
+ qk_norm = False,
110
+ flash = False,
111
+ add_zero_kv = False,
112
+ onnxable = False
113
+ ):
114
+ super().__init__()
115
+ self.scale = scale
116
+ self.qk_norm = qk_norm
117
+
118
+ self.causal = causal
119
+ self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
120
+
121
+ self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
122
+
123
+ self.dropout = dropout
124
+ self.attn_dropout = nn.Dropout(dropout)
125
+
126
+ # talking heads
127
+
128
+ assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
129
+
130
+ self.talking_heads = talking_heads
131
+ if talking_heads:
132
+ self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
133
+ self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
134
+
135
+ # sparse topk
136
+
137
+ assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
138
+ self.sparse_topk = sparse_topk
139
+
140
+ # add a key / value token composed of zeros
141
+ # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
142
+
143
+ self.add_zero_kv = add_zero_kv
144
+
145
+ # flash attention
146
+
147
+ self.flash = flash
148
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
149
+
150
+ # determine efficient attention configs for cuda and cpu
151
+
152
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
153
+ self.cuda_config = None
154
+
155
+ if not torch.cuda.is_available() or not flash:
156
+ return
157
+
158
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
159
+
160
+ major, minor = device_properties.major, device_properties.minor
161
+
162
+ if (major, minor) == (8, 0):
163
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
164
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
165
+ elif (major, minor) == (9, 0):
166
+ print_once('H100 GPU detected, using flash attention')
167
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
168
+ else:
169
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
170
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
171
+
172
+ def flash_attn(
173
+ self,
174
+ q, k, v,
175
+ mask = None,
176
+ attn_bias = None
177
+ ):
178
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
179
+
180
+ # Recommended for multi-query single-key-value attention by Tri Dao
181
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
182
+
183
+ if k.ndim == 3:
184
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
185
+
186
+ if v.ndim == 3:
187
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
188
+
189
+ # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
190
+
191
+ if self.qk_norm:
192
+ default_scale = q.shape[-1] ** -0.5
193
+ q = q * (self.scale / default_scale)
194
+
195
+ # Check if mask exists and expand to compatible shape
196
+ # The mask is B L, so it would have to be expanded to B H N L
197
+
198
+ causal = self.causal
199
+
200
+ # in the case of kv caching with one token (q_len == 1), just turn off causal masking
201
+ # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
202
+
203
+ if q_len == 1 and causal:
204
+ causal = False
205
+
206
+ # expand key padding mask
207
+
208
+ if exists(mask):
209
+ assert mask.ndim == 4
210
+ mask = mask.expand(batch, heads, q_len, k_len)
211
+
212
+ # handle kv cache - this should be bypassable in updated flash attention 2
213
+
214
+ if k_len > q_len and causal:
215
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
216
+ if not exists(mask):
217
+ mask = ~causal_mask
218
+ else:
219
+ mask = mask & ~causal_mask
220
+ causal = False
221
+
222
+ # manually handle causal mask, if another mask was given
223
+
224
+ row_is_entirely_masked = None
225
+
226
+ if exists(mask) and causal:
227
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
228
+ mask = mask & ~causal_mask
229
+
230
+ # protect against an entire row being masked out
231
+
232
+ row_is_entirely_masked = ~mask.any(dim = -1)
233
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
234
+
235
+ causal = False
236
+
237
+ # handle alibi positional bias
238
+ # convert from bool to float
239
+
240
+ if exists(attn_bias):
241
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
242
+
243
+ # if mask given, the mask would already contain the causal mask from above logic
244
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
245
+
246
+ mask_value = -torch.finfo(q.dtype).max
247
+
248
+ if exists(mask):
249
+ attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
250
+ elif causal:
251
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
252
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
253
+ causal = False
254
+
255
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
256
+ # make it an additive bias here
257
+
258
+ mask = attn_bias
259
+
260
+ # Check if there is a compatible device for flash attention
261
+
262
+ config = self.cuda_config if is_cuda else self.cpu_config
263
+
264
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
265
+
266
+ # Legacy code...
267
+ # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
268
+ # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
269
+
270
+ # PyTorch 2.3-2.4 SDPA backend code...
271
+ # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
272
+ with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
273
+
274
+ # New PyTorch 2.5 SDPA backend code:
275
+ # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
276
+
277
+ out = F.scaled_dot_product_attention(
278
+ q, k, v,
279
+ attn_mask = mask,
280
+ dropout_p = self.dropout if self.training else 0.,
281
+ is_causal = causal
282
+ )
283
+
284
+ # for a row that is entirely masked out, should zero out the output of that row token
285
+
286
+ if exists(row_is_entirely_masked):
287
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
288
+
289
+ return out, Intermediates()
290
+
291
+ def forward(
292
+ self,
293
+ q, k, v,
294
+ mask = None,
295
+ attn_bias = None,
296
+ prev_attn = None
297
+ ):
298
+ """
299
+ einstein notation
300
+ b - batch
301
+ h - heads
302
+ n, i, j - sequence length (base sequence length, source, target)
303
+ d - feature dimension
304
+ """
305
+
306
+ n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
307
+
308
+ scale = default(self.scale, q.shape[-1] ** -0.5)
309
+
310
+ causal = self.causal
311
+
312
+ # handle kv cached decoding
313
+
314
+ if n == 1 and causal:
315
+ causal = False
316
+
317
+ # handle grouped multi-query attention
318
+
319
+ if kv_heads == 1:
320
+ k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
321
+ elif kv_heads < heads:
322
+ k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
323
+
324
+ # handle zero kv, as means for allowing network to attend to nothing
325
+
326
+ if self.add_zero_kv:
327
+ k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
328
+
329
+ if exists(mask):
330
+ mask = F.pad(mask, (1, 0), value = True)
331
+
332
+ if exists(attn_bias):
333
+ attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
334
+
335
+ if self.flash:
336
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
337
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
338
+
339
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
340
+
341
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
342
+
343
+ if exists(prev_attn):
344
+ dots = dots + prev_attn
345
+
346
+ qk_similarities = dots.clone()
347
+
348
+ if self.talking_heads:
349
+ dots = self.pre_softmax_talking_heads(dots)
350
+
351
+ if exists(attn_bias):
352
+ dots = dots + attn_bias
353
+
354
+ i, j, dtype = *dots.shape[-2:], dots.dtype
355
+
356
+ mask_value = -torch.finfo(dots.dtype).max
357
+
358
+ if exists(self.sparse_topk) and self.sparse_topk < j:
359
+ top_values, _ = dots.topk(self.sparse_topk, dim = -1)
360
+ sparse_topk_mask = dots < top_values[..., -1:]
361
+ mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
362
+
363
+ if exists(mask):
364
+ dots = dots.masked_fill(~mask, mask_value)
365
+
366
+ if causal:
367
+ causal_mask = self.create_causal_mask(i, j, device = device)
368
+ dots = dots.masked_fill(causal_mask, mask_value)
369
+
370
+ pre_softmax_attn = dots.clone()
371
+
372
+ attn = self.attn_fn(dots, dim = -1)
373
+ attn = attn.type(dtype)
374
+
375
+ post_softmax_attn = attn.clone()
376
+
377
+ attn = self.attn_dropout(attn)
378
+
379
+ if self.talking_heads:
380
+ attn = self.post_softmax_talking_heads(attn)
381
+
382
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
383
+
384
+ intermediates = Intermediates(
385
+ qk_similarities = qk_similarities,
386
+ pre_softmax_attn = pre_softmax_attn,
387
+ post_softmax_attn = post_softmax_attn
388
+ )
389
+
390
+ return out, intermediates
391
+
392
+ #===================================================================================================================
393
+
394
+ from math import ceil, log
395
+ from typing import Optional, Union, Tuple, Callable
396
+
397
+ import torch
398
+ from torch import nn, Tensor
399
+ from torch.nn import Module
400
+ import torch.nn.functional as F
401
+
402
+ from einops import rearrange, pack, unpack
403
+
404
+ def exists(val):
405
+ return val is not None
406
+
407
+ def default(val, d):
408
+ return val if exists(val) else d
409
+
410
+ def identity(t, *args, **kwargs):
411
+ return t
412
+
413
+ def cast_tuple(t, length = 1):
414
+ return t if isinstance(t, tuple) else (t,) * length
415
+
416
+ def eval_decorator(fn):
417
+ def inner(self, *args, **kwargs):
418
+ was_training = self.training
419
+ self.eval()
420
+ out = fn(self, *args, **kwargs)
421
+ self.train(was_training)
422
+ return out
423
+ return inner
424
+
425
+ # for variable lengthed prefixes
426
+
427
+ def align_right(t, lens, pad_id = 0):
428
+ batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
429
+
430
+ assert lens.ndim == 1 and lens.shape[0] == batch
431
+ assert lens.amax() <= seq_len
432
+
433
+ pad_lens = seq_len - lens
434
+ max_pad_len = pad_lens.amax()
435
+
436
+ batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
437
+ prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
438
+
439
+ t = F.pad(t, (max_pad_len, 0), value = 0)
440
+ offset = max_pad_len - pad_lens
441
+
442
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
443
+ return aligned
444
+
445
+ # nucleus
446
+
447
+ def top_p(logits, thres = 0.9):
448
+ sorted_logits, sorted_indices = torch.sort(logits, descending = True)
449
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
450
+
451
+ sorted_indices_to_remove = cum_probs > thres
452
+ sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
453
+
454
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
455
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
456
+
457
+ # topk
458
+
459
+ def top_k(logits, frac_num_tokens = 0.1, k = None):
460
+ num_tokens = logits.shape[-1]
461
+
462
+ k = default(k, ceil(frac_num_tokens * num_tokens))
463
+ k = min(k, num_tokens)
464
+
465
+ val, ind = torch.topk(logits, k)
466
+ probs = torch.full_like(logits, float('-inf'))
467
+ probs.scatter_(1, ind, val)
468
+ return probs
469
+
470
+ # top_a
471
+
472
+ def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
473
+ probs = F.softmax(logits, dim = -1)
474
+ max_probs = torch.amax(probs, dim = -1, keepdim = True)
475
+ limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
476
+ return torch.where(probs < limit, float('-inf'), logits)
477
+
478
+ # contrastive decoding function
479
+
480
+ def contrastive_decode_fn(
481
+ expert_logits,
482
+ amateur_logits,
483
+ alpha = 0.1,
484
+ beta = 0.5
485
+ ):
486
+ """
487
+ Appendix A Algorithm 2
488
+ https://arxiv.org/abs/2309.09117
489
+ """
490
+
491
+ cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
492
+ diffs = (1 + beta) * expert_logits - beta * amateur_logits
493
+ contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
494
+ return contrastive_decode_logits
495
+
496
+ # autoregressive wrapper class
497
+
498
+ class AutoregressiveWrapper(Module):
499
+ def __init__(
500
+ self,
501
+ net,
502
+ ignore_index = -100,
503
+ pad_value = 0,
504
+ mask_prob = 0.,
505
+ add_attn_z_loss = False,
506
+ return_cache=False
507
+ ):
508
+ super().__init__()
509
+ self.pad_value = pad_value
510
+ self.ignore_index = ignore_index
511
+
512
+ self.net = net
513
+ self.max_seq_len = net.max_seq_len
514
+
515
+ # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
516
+ assert mask_prob < 1.
517
+ self.mask_prob = mask_prob
518
+
519
+ # whether to add router z-loss
520
+ self.add_attn_z_loss = add_attn_z_loss
521
+ self.return_cache = return_cache
522
+
523
+ @torch.inference_mode()
524
+ @eval_decorator
525
+ def generate(
526
+ self,
527
+ prompts,
528
+ seq_len,
529
+ eos_token = None,
530
+ temperature = 1.,
531
+ prompt_lens: Optional[Tensor] = None,
532
+ filter_logits_fn: Callable = top_k,
533
+ restrict_to_max_seq_len = True,
534
+ amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
535
+ filter_kwargs: dict = dict(),
536
+ contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
537
+ beta = 0.5,
538
+ alpha = 0.1
539
+ ),
540
+ cache_kv = True,
541
+ verbose=True,
542
+ return_prime=False,
543
+ **kwargs
544
+ ):
545
+ max_seq_len, device = self.max_seq_len, prompts.device
546
+
547
+ prompts, ps = pack([prompts], '* n')
548
+
549
+ b, t = prompts.shape
550
+
551
+ # handle variable lengthed prompts (prefixes)
552
+
553
+ seq_start_pos = None
554
+ if exists(prompt_lens):
555
+ prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
556
+ seq_start_pos = t - prompt_lens
557
+
558
+ # output from which sampled tokens appended to
559
+
560
+ out = prompts
561
+
562
+ if verbose:
563
+ print("Generating sequence of max length:", seq_len)
564
+
565
+ # kv caches
566
+
567
+ cache = None
568
+
569
+ # if doing contrastive decoding, turn off filter automatically
570
+
571
+ if exists(amateur_model):
572
+ amateur_model = cast_tuple(amateur_model)
573
+ contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
574
+
575
+ assert len(amateur_model) == len(contrastive_decode_kwargs)
576
+
577
+ amateur_caches = [None] * len(amateur_model)
578
+ filter_logits_fn = identity
579
+
580
+ for i, module in enumerate(amateur_model):
581
+ if isinstance(module, AutoregressiveWrapper):
582
+ amateur_model[i] = module.net
583
+
584
+ module.eval()
585
+
586
+ # sampling up to seq_len
587
+
588
+ for sl in range(seq_len):
589
+
590
+ if restrict_to_max_seq_len:
591
+ x = out[:, -max_seq_len:]
592
+
593
+ if exists(cache):
594
+ for inter in cache.attn_intermediates:
595
+ inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
596
+
597
+ logits, new_cache = self.net(
598
+ x,
599
+ return_intermediates = True,
600
+ cache = cache,
601
+ seq_start_pos = seq_start_pos,
602
+ **kwargs
603
+ )
604
+
605
+ if cache_kv and self.net.can_cache_kv:
606
+ cache = new_cache
607
+
608
+ logits = logits[:, -1]
609
+
610
+ # handle contrastive decoding, Li et al.
611
+ # https://arxiv.org/abs/2210.15097
612
+
613
+ if exists(amateur_model):
614
+ for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
615
+ amateur_logits, next_amateur_cache = amateur(
616
+ x,
617
+ return_intermediates = True,
618
+ cache = amateur_cache,
619
+ seq_start_pos = seq_start_pos,
620
+ **kwargs
621
+ )
622
+
623
+ amateur_logits = amateur_logits[:, -1]
624
+
625
+ assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
626
+ logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
627
+
628
+ if cache_kv and amateur.can_cache_kv:
629
+ amateur_caches[i] = next_amateur_cache
630
+
631
+ # filter by top_k, top_p (nucleus), top_a, or custom
632
+
633
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
634
+
635
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
636
+
637
+ sample = torch.multinomial(probs, 1)
638
+
639
+ out = torch.cat((out, sample), dim=-1)
640
+
641
+ if verbose:
642
+ if sl % 32 == 0:
643
+ print(sl, '/', seq_len)
644
+
645
+ if exists(eos_token):
646
+ is_eos_tokens = (out == eos_token)
647
+
648
+ if is_eos_tokens.any(dim = -1).all():
649
+ # mask out everything after the eos tokens
650
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
651
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
652
+ out = out.masked_fill(mask, self.pad_value)
653
+
654
+ if verbose:
655
+ print('Model called the end of sequence at:', sl, '/', seq_len)
656
+
657
+ break
658
+
659
+ if return_prime:
660
+ return out[:, :]
661
+
662
+ else:
663
+ return out[:, t:]
664
+
665
+ # out, = unpack(out, ps, '* n')
666
+
667
+ # return out
668
+
669
+ def compute_accuracy(self, logits, labels):
670
+ out = torch.argmax(logits, dim=-1)
671
+ out = out.flatten()
672
+ labels = labels.flatten()
673
+
674
+ mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
675
+ out = out[mask]
676
+ labels = labels[mask]
677
+
678
+ num_right = (out == labels)
679
+ num_right = torch.sum(num_right).type(torch.float32)
680
+
681
+ acc = num_right / len(labels)
682
+ return acc
683
+
684
+ def forward(self, x, **kwargs):
685
+ seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
686
+
687
+ inp, target = x[:, :-1], x[:, 1:]
688
+ inp = torch.where(inp == ignore_index, self.pad_value, inp)
689
+
690
+ if self.mask_prob > 0.:
691
+ rand = torch.randn(inp.shape, device = x.device)
692
+ rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
693
+ num_mask = min(int(seq * self.mask_prob), seq - 1)
694
+ indices = rand.topk(num_mask, dim = -1).indices
695
+ mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
696
+ kwargs.update(self_attn_kv_mask = mask)
697
+
698
+ logits, cache = self.net(
699
+ inp,
700
+ return_intermediates = True,
701
+ return_attn_z_loss = add_attn_z_loss,
702
+ **kwargs
703
+ )
704
+
705
+ acc = self.compute_accuracy(logits, target)
706
+
707
+ loss = F.cross_entropy(
708
+ rearrange(logits, 'b n c -> b c n'),
709
+ target,
710
+ ignore_index = ignore_index
711
+ )
712
+
713
+ if add_attn_z_loss:
714
+ loss = loss + cache.attn_z_loss
715
+
716
+ if self.return_cache:
717
+ return loss, acc, cache
718
+
719
+ else:
720
+ return loss, acc
721
+
722
+ #===============================================================================
723
+
724
+ import math
725
+ from random import random
726
+
727
+ import torch
728
+ from torch import nn, einsum, Tensor
729
+ import torch.nn.functional as F
730
+
731
+ from functools import partial, wraps
732
+ from inspect import isfunction
733
+ from collections import namedtuple
734
+ from dataclasses import dataclass
735
+ from typing import List, Callable, Optional
736
+
737
+ from einops import rearrange, repeat, reduce, pack, unpack
738
+ from einops.layers.torch import Rearrange
739
+
740
+ # constants
741
+
742
+ DEFAULT_DIM_HEAD = 64
743
+
744
+ @dataclass
745
+ class LayerIntermediates:
746
+ hiddens: Optional[List[Tensor]] = None
747
+ attn_intermediates: Optional[List[Intermediates]] = None
748
+ layer_hiddens: Optional[List[Tensor]] = None
749
+ attn_z_loss: Optional[Tensor] = None
750
+ mems: Optional[Tensor] = None
751
+
752
+ # helpers
753
+
754
+ def exists(val):
755
+ return val is not None
756
+
757
+ def default(val, d):
758
+ if exists(val):
759
+ return val
760
+ return d() if isfunction(d) else d
761
+
762
+ def cast_tuple(val, depth):
763
+ return val if isinstance(val, tuple) else (val,) * depth
764
+
765
+ def divisible_by(num, den):
766
+ return (num % den) == 0
767
+
768
+ def maybe(fn):
769
+ @wraps(fn)
770
+ def inner(x, *args, **kwargs):
771
+ if not exists(x):
772
+ return x
773
+ return fn(x, *args, **kwargs)
774
+ return inner
775
+
776
+ class always():
777
+ def __init__(self, val):
778
+ self.val = val
779
+ def __call__(self, *args, **kwargs):
780
+ return self.val
781
+
782
+ class not_equals():
783
+ def __init__(self, val):
784
+ self.val = val
785
+ def __call__(self, x, *args, **kwargs):
786
+ return x != self.val
787
+
788
+ class equals():
789
+ def __init__(self, val):
790
+ self.val = val
791
+ def __call__(self, x, *args, **kwargs):
792
+ return x == self.val
793
+
794
+ def Sequential(*modules):
795
+ return nn.Sequential(*filter(exists, modules))
796
+
797
+ # tensor helpers
798
+
799
+ def max_neg_value(tensor):
800
+ return -torch.finfo(tensor.dtype).max
801
+
802
+ def l2norm(t, groups = 1):
803
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
804
+ t = F.normalize(t, p = 2, dim = -1)
805
+ return rearrange(t, '... g d -> ... (g d)')
806
+
807
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
808
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
809
+ zeros = ((0, 0) * dims_from_right)
810
+ return F.pad(t, (*zeros, *pad), value = value)
811
+
812
+ def or_reduce(masks):
813
+ head, *body = masks
814
+ for rest in body:
815
+ head = head | rest
816
+ return head
817
+
818
+ # auxiliary loss helpers
819
+
820
+ def calc_z_loss(
821
+ pre_softmax_attns: List[Tensor],
822
+ mask = None,
823
+ weight = 1.
824
+ ):
825
+ # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
826
+ # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
827
+ # also used in PaLM as one of the measures
828
+
829
+ lse = 0.
830
+
831
+ for attn in pre_softmax_attns:
832
+ lse = lse + attn.logsumexp(dim = -1)
833
+
834
+ loss = torch.square(lse)
835
+ loss = reduce(loss, 'b h n -> b n', 'sum')
836
+
837
+ if not exists(mask):
838
+ return loss.mean() * weight
839
+
840
+ loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
841
+ return loss * weight
842
+
843
+ # init helpers
844
+
845
+ def init_zero_(layer):
846
+ nn.init.constant_(layer.weight, 0.)
847
+ if exists(layer.bias):
848
+ nn.init.constant_(layer.bias, 0.)
849
+
850
+ # keyword argument helpers
851
+
852
+ def pick_and_pop(keys, d):
853
+ values = list(map(lambda key: d.pop(key), keys))
854
+ return dict(zip(keys, values))
855
+
856
+ def group_dict_by_key(cond, d):
857
+ return_val = [dict(),dict()]
858
+ for key in d.keys():
859
+ match = bool(cond(key))
860
+ ind = int(not match)
861
+ return_val[ind][key] = d[key]
862
+ return (*return_val,)
863
+
864
+ def string_begins_with(prefix, str):
865
+ return str.startswith(prefix)
866
+
867
+ def group_by_key_prefix(prefix, d):
868
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
869
+
870
+ def groupby_prefix_and_trim(prefix, d):
871
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
872
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
873
+ return kwargs_without_prefix, kwargs
874
+
875
+ # structured dropout, more effective than traditional attention dropouts
876
+
877
+ def dropout_seq(seq, mask, dropout):
878
+ b, n, *_, device = *seq.shape, seq.device
879
+ logits = torch.randn(b, n, device = device)
880
+
881
+ if exists(mask):
882
+ mask_value = max_neg_value(logits)
883
+ logits = logits.masked_fill(~mask, mask_value)
884
+
885
+ keep_prob = 1. - dropout
886
+ num_keep = max(1, int(keep_prob * n))
887
+ keep_indices = logits.topk(num_keep, dim = 1).indices
888
+
889
+ batch_indices = torch.arange(b, device = device)
890
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
891
+
892
+ seq = seq[batch_indices, keep_indices]
893
+
894
+ if exists(mask):
895
+ seq_counts = mask.sum(dim = -1)
896
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
897
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
898
+
899
+ mask = mask[batch_indices, keep_indices] & keep_mask
900
+
901
+ return seq, mask
902
+
903
+ # activations
904
+
905
+ class ReluSquared(nn.Module):
906
+ def forward(self, x):
907
+ return F.relu(x) ** 2
908
+
909
+ # embedding
910
+
911
+ class TokenEmbedding(nn.Module):
912
+ def __init__(self, dim, num_tokens, l2norm_embed = False):
913
+ super().__init__()
914
+ self.l2norm_embed = l2norm_embed
915
+ self.emb = nn.Embedding(num_tokens, dim)
916
+
917
+ def forward(self, x):
918
+ token_emb = self.emb(x)
919
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
920
+
921
+ # positional embeddings
922
+
923
+ class AbsolutePositionalEmbedding(nn.Module):
924
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
925
+ super().__init__()
926
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
927
+ self.max_seq_len = max_seq_len
928
+ self.l2norm_embed = l2norm_embed
929
+ self.emb = nn.Embedding(max_seq_len, dim)
930
+
931
+ def forward(self, x, pos = None, seq_start_pos = None):
932
+ seq_len, device = x.shape[1], x.device
933
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
934
+
935
+ if not exists(pos):
936
+ pos = torch.arange(seq_len, device = device)
937
+
938
+ if exists(seq_start_pos):
939
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
940
+
941
+ pos_emb = self.emb(pos)
942
+ pos_emb = pos_emb * self.scale
943
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
944
+
945
+ class ScaledSinusoidalEmbedding(nn.Module):
946
+ def __init__(self, dim, theta = 10000):
947
+ super().__init__()
948
+ assert divisible_by(dim, 2)
949
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
950
+
951
+ half_dim = dim // 2
952
+ freq_seq = torch.arange(half_dim).float() / half_dim
953
+ inv_freq = theta ** -freq_seq
954
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
955
+
956
+ def forward(self, x, pos = None, seq_start_pos = None):
957
+ seq_len, device = x.shape[1], x.device
958
+
959
+ if not exists(pos):
960
+ pos = torch.arange(seq_len, device = device)
961
+
962
+ if exists(seq_start_pos):
963
+ pos = pos - seq_start_pos[..., None]
964
+
965
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
966
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
967
+ return emb * self.scale
968
+
969
+ class RelativePositionBias(nn.Module):
970
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
971
+ super().__init__()
972
+ self.scale = scale
973
+ self.causal = causal
974
+ self.num_buckets = num_buckets
975
+ self.max_distance = max_distance
976
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
977
+
978
+ @staticmethod
979
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
980
+ ret = 0
981
+ n = -relative_position
982
+ if not causal:
983
+ num_buckets //= 2
984
+ ret += (n < 0).long() * num_buckets
985
+ n = torch.abs(n)
986
+ else:
987
+ n = torch.max(n, torch.zeros_like(n))
988
+
989
+ max_exact = num_buckets // 2
990
+ is_small = n < max_exact
991
+
992
+ val_if_large = max_exact + (
993
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
994
+ ).long()
995
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
996
+
997
+ ret += torch.where(is_small, n, val_if_large)
998
+ return ret
999
+
1000
+ @property
1001
+ def device(self):
1002
+ return next(self.parameters()).device
1003
+
1004
+ def forward(self, i, j):
1005
+ device = self.device
1006
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
1007
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
1008
+ rel_pos = k_pos[None, :] - q_pos[:, None]
1009
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
1010
+ values = self.relative_attention_bias(rp_bucket)
1011
+ bias = rearrange(values, 'i j h -> h i j')
1012
+ return bias * self.scale
1013
+
1014
+ class DynamicPositionBias(nn.Module):
1015
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
1016
+ super().__init__()
1017
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
1018
+ self.log_distance = log_distance
1019
+
1020
+ self.mlp = nn.ModuleList([])
1021
+
1022
+ self.mlp.append(Sequential(
1023
+ nn.Linear(1, dim),
1024
+ nn.LayerNorm(dim) if norm else None,
1025
+ nn.SiLU()
1026
+ ))
1027
+
1028
+ for _ in range(depth - 1):
1029
+ self.mlp.append(Sequential(
1030
+ nn.Linear(dim, dim),
1031
+ nn.LayerNorm(dim) if norm else None,
1032
+ nn.SiLU()
1033
+ ))
1034
+
1035
+ self.mlp.append(nn.Linear(dim, heads))
1036
+
1037
+ @property
1038
+ def device(self):
1039
+ return next(self.parameters()).device
1040
+
1041
+ def forward(self, i, j):
1042
+ assert i == j
1043
+ n, device = j, self.device
1044
+
1045
+ # get the (n x n) matrix of distances
1046
+ seq_arange = torch.arange(n, device = device)
1047
+ context_arange = torch.arange(n, device = device)
1048
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
1049
+ indices += (n - 1)
1050
+
1051
+ # input to continuous positions MLP
1052
+ pos = torch.arange(-n + 1, n, device = device).float()
1053
+ pos = rearrange(pos, '... -> ... 1')
1054
+
1055
+ if self.log_distance:
1056
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
1057
+
1058
+ for layer in self.mlp:
1059
+ pos = layer(pos)
1060
+
1061
+ # get position biases
1062
+ bias = pos[indices]
1063
+ bias = rearrange(bias, 'i j h -> h i j')
1064
+ return bias
1065
+
1066
+ class AlibiPositionalBias(nn.Module):
1067
+ def __init__(self, heads, total_heads, **kwargs):
1068
+ super().__init__()
1069
+ self.heads = heads
1070
+ self.total_heads = total_heads
1071
+
1072
+ slopes = Tensor(self._get_slopes(heads))
1073
+ slopes = rearrange(slopes, 'h -> h 1 1')
1074
+ self.register_buffer('slopes', slopes, persistent = False)
1075
+ self.register_buffer('bias', None, persistent = False)
1076
+
1077
+ def get_bias(self, i, j, device):
1078
+ i_arange = torch.arange(j - i, j, device = device)
1079
+ j_arange = torch.arange(j, device = device)
1080
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
1081
+ return bias
1082
+
1083
+ @staticmethod
1084
+ def _get_slopes(heads):
1085
+ def get_slopes_power_of_2(n):
1086
+ start = (2**(-2**-(math.log2(n)-3)))
1087
+ ratio = start
1088
+ return [start*ratio**i for i in range(n)]
1089
+
1090
+ if math.log2(heads).is_integer():
1091
+ return get_slopes_power_of_2(heads)
1092
+
1093
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
1094
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
1095
+
1096
+ @property
1097
+ def device(self):
1098
+ return next(self.buffers()).device
1099
+
1100
+ def forward(self, i, j):
1101
+ h, device = self.total_heads, self.device
1102
+
1103
+ if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
1104
+ return self.bias[..., -i:, -j:]
1105
+
1106
+ bias = self.get_bias(i, j, device)
1107
+ bias = bias * self.slopes
1108
+
1109
+ num_heads_unalibied = h - bias.shape[0]
1110
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
1111
+ self.register_buffer('bias', bias, persistent = False)
1112
+
1113
+ return self.bias
1114
+
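# A minimal sketch of the ALiBi idea used above (assuming a power-of-two head
# count): each head penalises attention to distant keys with a fixed linear
# slope, so no positional parameters are learned.
import math
import torch

def power_of_2_slopes(heads):
    start = 2 ** (-2 ** -(math.log2(heads) - 3))
    return [start * (start ** i) for i in range(heads)]

heads, n = 8, 6
slopes = torch.tensor(power_of_2_slopes(heads)).view(heads, 1, 1)
pos = torch.arange(n)
dist = -(pos[None, :] - pos[:, None]).abs()   # (n, n), zero on the diagonal
bias = slopes * dist                          # (heads, n, n), added to the logits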
1115
+ class RotaryEmbedding(nn.Module):
1116
+ def __init__(
1117
+ self,
1118
+ dim,
1119
+ use_xpos = False,
1120
+ scale_base = 512,
1121
+ interpolation_factor = 1.,
1122
+ base = 10000,
1123
+ base_rescale_factor = 1.
1124
+ ):
1125
+ super().__init__()
1126
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
1127
+ # has some connection to NTK literature
1128
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
1129
+ base *= base_rescale_factor ** (dim / (dim - 2))
1130
+
1131
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
1132
+ self.register_buffer('inv_freq', inv_freq)
1133
+
1134
+ assert interpolation_factor >= 1.
1135
+ self.interpolation_factor = interpolation_factor
1136
+
1137
+ if not use_xpos:
1138
+ self.register_buffer('scale', None)
1139
+ return
1140
+
1141
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
1142
+
1143
+ self.scale_base = scale_base
1144
+ self.register_buffer('scale', scale)
1145
+
1146
+ def forward(self, seq_len):
1147
+ device = self.inv_freq.device
1148
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
1149
+
1150
+ t = t / self.interpolation_factor
1151
+
1152
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
1153
+ freqs = torch.cat((freqs, freqs), dim = -1)
1154
+
1155
+ if not exists(self.scale):
1156
+ return freqs, 1.
1157
+
1158
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
1159
+ scale = self.scale ** rearrange(power, 'n -> n 1')
1160
+ scale = torch.cat((scale, scale), dim = -1)
1161
+
1162
+ return freqs, scale
1163
+
1164
+
1165
+ def rotate_half(x):
1166
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
1167
+ x1, x2 = x.unbind(dim = -2)
1168
+ return torch.cat((-x2, x1), dim = -1)
1169
+
1170
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
1171
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
1172
+ freqs = freqs[-seq_len:, :]
1173
+
1174
+ if t.ndim == 4 and freqs.ndim == 3:
1175
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
1176
+
1177
+ # partial rotary embeddings, Wang et al. GPT-J
1178
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
1179
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
1180
+ return torch.cat((t, t_unrotated), dim = -1)
1181
+
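# A minimal sketch (example sizes) of the rotary embedding applied above:
# feature pairs are rotated by an angle proportional to absolute position, so
# dot products between rotated queries and keys depend only on relative offset.
import torch

dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum('i, j -> i j', t, inv_freq)   # (seq_len, dim // 2)
freqs = torch.cat((freqs, freqs), dim=-1)          # (seq_len, dim)

def rotate_half_(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(seq_len, dim)
q_rot = q * freqs.cos() + rotate_half_(q) * freqs.sin()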
1182
+ # norms
1183
+
1184
+ class Scale(nn.Module):
1185
+ def __init__(self, value, fn):
1186
+ super().__init__()
1187
+ self.value = value
1188
+ self.fn = fn
1189
+
1190
+ def forward(self, x, **kwargs):
1191
+ out = self.fn(x, **kwargs)
1192
+ scale_fn = lambda t: t * self.value
1193
+
1194
+ if not isinstance(out, tuple):
1195
+ return scale_fn(out)
1196
+
1197
+ return (scale_fn(out[0]), *out[1:])
1198
+
1199
+ class ScaleNorm(nn.Module):
1200
+ def __init__(self, dim, eps = 1e-5):
1201
+ super().__init__()
1202
+ self.eps = eps
1203
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
1204
+
1205
+ def forward(self, x):
1206
+ norm = torch.norm(x, dim = -1, keepdim = True)
1207
+ return x / norm.clamp(min = self.eps) * self.g
1208
+
1209
+ class RMSNorm(nn.Module):
1210
+ def __init__(self, dim):
1211
+ super().__init__()
1212
+ self.scale = dim ** 0.5
1213
+ self.g = nn.Parameter(torch.ones(dim))
1214
+
1215
+ def forward(self, x):
1216
+ return F.normalize(x, dim = -1) * self.scale * self.g
1217
+
1218
+ class SimpleRMSNorm(nn.Module):
1219
+ def __init__(self, dim):
1220
+ super().__init__()
1221
+ self.scale = dim ** 0.5
1222
+
1223
+ def forward(self, x):
1224
+ return F.normalize(x, dim = -1) * self.scale
1225
+
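# A quick numerical check (example shapes) of what RMSNorm above computes:
# F.normalize(x, dim = -1) * sqrt(d) equals x divided by its root-mean-square,
# up to the learned gain.
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)
d = x.shape[-1]
rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
assert torch.allclose(x / rms, F.normalize(x, dim=-1) * d ** 0.5, atol=1e-5)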
1226
+ # residual and residual gates
1227
+
1228
+ class Residual(nn.Module):
1229
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
1230
+ super().__init__()
1231
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1232
+ self.scale_residual_constant = scale_residual_constant
1233
+
1234
+ def forward(self, x, residual):
1235
+ if exists(self.residual_scale):
1236
+ residual = residual * self.residual_scale
1237
+
1238
+ if self.scale_residual_constant != 1:
1239
+ residual = residual * self.scale_residual_constant
1240
+
1241
+ return x + residual
1242
+
1243
+ class GRUGating(nn.Module):
1244
+ def __init__(self, dim, scale_residual = False, **kwargs):
1245
+ super().__init__()
1246
+ self.gru = nn.GRUCell(dim, dim)
1247
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1248
+
1249
+ def forward(self, x, residual):
1250
+ if exists(self.residual_scale):
1251
+ residual = residual * self.residual_scale
1252
+
1253
+ gated_output = self.gru(
1254
+ rearrange(x, 'b n d -> (b n) d'),
1255
+ rearrange(residual, 'b n d -> (b n) d')
1256
+ )
1257
+
1258
+ return gated_output.reshape_as(x)
1259
+
1260
+ # token shifting
1261
+
1262
+ def shift(t, amount, mask = None):
1263
+ if amount == 0:
1264
+ return t
1265
+ else:
1266
+ amount = min(amount, t.shape[1])
1267
+
1268
+ if exists(mask):
1269
+ t = t.masked_fill(~mask[..., None], 0.)
1270
+
1271
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1272
+
1273
+ class ShiftTokens(nn.Module):
1274
+ def __init__(self, shifts, fn):
1275
+ super().__init__()
1276
+ self.fn = fn
1277
+ self.shifts = tuple(shifts)
1278
+
1279
+ def forward(self, x, **kwargs):
1280
+ mask = kwargs.get('mask', None)
1281
+ shifts = self.shifts
1282
+ segments = len(shifts)
1283
+ feats_per_shift = x.shape[-1] // segments
1284
+ splitted = x.split(feats_per_shift, dim = -1)
1285
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
1286
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
1287
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
1288
+ return self.fn(x, **kwargs)
1289
+
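# A minimal sketch of the padded shift used by shift() above: padding by
# (amount, -amount) along the sequence axis moves features one step forward
# in time while keeping the tensor shape.
import torch
import torch.nn.functional as F

x = torch.arange(1, 7).float().view(1, 6, 1)   # (batch, seq, feat)
shifted = F.pad(x, (0, 0, 1, -1), value=0.)    # shift right by one position
print(x.flatten().tolist())        # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(shifted.flatten().tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]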
1290
+ # feedforward
1291
+
1292
+ class GLU(nn.Module):
1293
+ def __init__(
1294
+ self,
1295
+ dim_in,
1296
+ dim_out,
1297
+ activation: Callable,
1298
+ mult_bias = False
1299
+ ):
1300
+ super().__init__()
1301
+ self.act = activation
1302
+ self.proj = nn.Linear(dim_in, dim_out * 2)
1303
+ self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.
1304
+
1305
+ def forward(self, x):
1306
+ x, gate = self.proj(x).chunk(2, dim = -1)
1307
+ return x * self.act(gate) * self.mult_bias
1308
+
1309
+ class FeedForward(nn.Module):
1310
+ def __init__(
1311
+ self,
1312
+ dim,
1313
+ dim_out = None,
1314
+ mult = 4,
1315
+ glu = False,
1316
+ glu_mult_bias = False,
1317
+ swish = False,
1318
+ relu_squared = False,
1319
+ post_act_ln = False,
1320
+ dropout = 0.,
1321
+ no_bias = False,
1322
+ zero_init_output = False
1323
+ ):
1324
+ super().__init__()
1325
+ inner_dim = int(dim * mult)
1326
+ dim_out = default(dim_out, dim)
1327
+
1328
+ if relu_squared:
1329
+ activation = ReluSquared()
1330
+ elif swish:
1331
+ activation = nn.SiLU()
1332
+ else:
1333
+ activation = nn.GELU()
1334
+
1335
+ if glu:
1336
+ project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
1337
+ else:
1338
+ project_in = nn.Sequential(
1339
+ nn.Linear(dim, inner_dim, bias = not no_bias),
1340
+ activation
1341
+ )
1342
+
1343
+ self.ff = Sequential(
1344
+ project_in,
1345
+ nn.LayerNorm(inner_dim) if post_act_ln else None,
1346
+ nn.Dropout(dropout),
1347
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
1348
+ )
1349
+
1350
+ # init last linear layer to 0
1351
+ if zero_init_output:
1352
+ init_zero_(self.ff[-1])
1353
+
1354
+ def forward(self, x):
1355
+ return self.ff(x)
1356
+
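# A minimal sketch (example widths) of the gated feed-forward variant above:
# project to twice the inner width, then multiply one half by the activated
# other half (SwiGLU when the activation is SiLU).
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, inner = 16, 64
proj_in = nn.Linear(dim, inner * 2)
proj_out = nn.Linear(inner, dim)
x = torch.randn(2, 10, dim)
h, gate = proj_in(x).chunk(2, dim=-1)
y = proj_out(h * F.silu(gate))   # (2, 10, dim)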
1357
+ # attention. it is all we need
1358
+
1359
+ class Attention(nn.Module):
1360
+ def __init__(
1361
+ self,
1362
+ dim,
1363
+ dim_head = DEFAULT_DIM_HEAD,
1364
+ heads = 8,
1365
+ causal = False,
1366
+ flash = False,
1367
+ talking_heads = False,
1368
+ head_scale = False,
1369
+ sparse_topk = None,
1370
+ num_mem_kv = 0,
1371
+ dropout = 0.,
1372
+ on_attn = False,
1373
+ gate_value_heads = False,
1374
+ gate_values = False,
1375
+ zero_init_output = False,
1376
+ max_attend_past = None,
1377
+ qk_norm = False,
1378
+ qk_norm_groups = 1,
1379
+ qk_norm_scale = 10,
1380
+ qk_norm_dim_scale = False,
1381
+ one_kv_head = False,
1382
+ kv_heads = None,
1383
+ shared_kv = False,
1384
+ value_dim_head = None,
1385
+ tensor_product = False, # https://arxiv.org/abs/2208.06061
1386
+ add_zero_kv = False, # same as add_zero_attn in pytorch
1387
+ rotary_embed_values = False,
1388
+ onnxable = False
1389
+ ):
1390
+ super().__init__()
1391
+ self.scale = dim_head ** -0.5
1392
+
1393
+ self.heads = heads
1394
+ self.causal = causal
1395
+ self.max_attend_past = max_attend_past
1396
+
1397
+ assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
1398
+
1399
+ value_dim_head = default(value_dim_head, dim_head)
1400
+ kv_heads = default(kv_heads, heads)
1401
+
1402
+ kv_heads = 1 if one_kv_head else kv_heads
1403
+ assert divisible_by(heads, kv_heads)
1404
+
1405
+ self.kv_heads = kv_heads
1406
+
1407
+ q_dim = dim_head * heads
1408
+ k_dim = dim_head * kv_heads
1409
+ v_dim = value_dim_head * kv_heads
1410
+ out_dim = value_dim_head * heads
1411
+
1412
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
1413
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
1414
+
1415
+ # shared key / values, for further memory savings during inference
1416
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
1417
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
1418
+
1419
+ # relations projection from tp-attention
1420
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
1421
+
1422
+ # add GLU gating for aggregated values, from alphafold2
1423
+ self.to_v_gate = None
1424
+ if gate_values:
1425
+ self.to_v_gate = nn.Linear(dim, out_dim)
1426
+ nn.init.constant_(self.to_v_gate.weight, 0)
1427
+ nn.init.constant_(self.to_v_gate.bias, 10)
1428
+
1429
+ # add per head gating of the output values, from 'Attend to nothing' paper
1430
+ self.to_v_head_gate = None
1431
+ if gate_value_heads:
1432
+ self.to_v_head_gate = nn.Linear(dim, heads)
1433
+ nn.init.constant_(self.to_v_head_gate.weight, 0)
1434
+ nn.init.constant_(self.to_v_head_gate.bias, 10)
1435
+
1436
+ # cosine sim attention
1437
+ self.qk_norm = qk_norm
1438
+ self.qk_norm_groups = qk_norm_groups
1439
+ self.qk_norm_scale = qk_norm_scale
1440
+
1441
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
1442
+ self.qk_norm_dim_scale = qk_norm_dim_scale
1443
+
1444
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
1445
+ if qk_norm and qk_norm_dim_scale:
1446
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1447
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1448
+
1449
+ assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1450
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1451
+
1452
+ # attend class - includes core attention algorithm + talking heads
1453
+
1454
+ self.attend = Attend(
1455
+ heads = heads,
1456
+ causal = causal,
1457
+ talking_heads = talking_heads,
1458
+ dropout = dropout,
1459
+ sparse_topk = sparse_topk,
1460
+ qk_norm = qk_norm,
1461
+ scale = qk_norm_scale if qk_norm else self.scale,
1462
+ add_zero_kv = add_zero_kv,
1463
+ flash = flash,
1464
+ onnxable = onnxable
1465
+ )
1466
+
1467
+ # head scaling
1468
+ self.head_scale = head_scale
1469
+ if head_scale:
1470
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
1471
+
1472
+ # explicit topk sparse attention
1473
+ self.sparse_topk = sparse_topk
1474
+
1475
+ # add memory key / values
1476
+ self.num_mem_kv = num_mem_kv
1477
+ if num_mem_kv > 0:
1478
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1479
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1480
+
1481
+ # attention on attention
1482
+ self.attn_on_attn = on_attn
1483
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
1484
+
1485
+ # whether to rotate positions into values, for absolute positions in addition to relative
1486
+ self.rotary_embed_values = rotary_embed_values
1487
+
1488
+ # init output projection 0
1489
+ if zero_init_output:
1490
+ init_zero_(self.to_out)
1491
+
1492
+ def forward(
1493
+ self,
1494
+ x,
1495
+ context = None,
1496
+ mask = None,
1497
+ context_mask = None,
1498
+ attn_mask = None,
1499
+ rel_pos = None,
1500
+ rotary_pos_emb = None,
1501
+ prev_attn = None,
1502
+ mem = None,
1503
+ return_intermediates = False,
1504
+ cache: Optional[Intermediates] = None,
1505
+ ):
1506
+ b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
1507
+ kv_input = default(context, x)
1508
+
1509
+ q_input = x
1510
+ k_input = kv_input
1511
+ v_input = kv_input
1512
+ r_input = x
1513
+
1514
+ if exists(mem):
1515
+ k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
1516
+ v_input, _ = pack([mem, v_input], 'b * d')
1517
+
1518
+ q = self.to_q(q_input)
1519
+ k = self.to_k(k_input)
1520
+ v = self.to_v(v_input) if exists(self.to_v) else k
1521
+ r = self.to_r(r_input) if exists(self.to_r) else None
1522
+
1523
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1524
+
1525
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
1526
+
1527
+ if exists(cache) and not has_context:
1528
+ ck, cv = cache.cached_kv
1529
+
1530
+ if exists(mem):
1531
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
1532
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')
1533
+
1534
+ k = torch.cat((ck, k), dim = -2)
1535
+ v = torch.cat((cv, v), dim = -2)
1536
+
1537
+ if exists(mem):
1538
+ k = torch.cat((mk, k), dim = -2)
1539
+ v = torch.cat((mv, v), dim = -2)
1540
+
1541
+ if return_intermediates:
1542
+ mem_len = mem.shape[-2] if exists(mem) else 0
1543
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
1544
+
1545
+ if self.qk_norm:
1546
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1547
+ q, k = map(qk_l2norm, (q, k))
1548
+ scale = self.qk_norm_scale
1549
+
1550
+ q = q * self.qk_norm_q_scale
1551
+ k = k * self.qk_norm_k_scale
1552
+
1553
+ if exists(rotary_pos_emb) and not has_context:
1554
+ freqs, xpos_scale = rotary_pos_emb
1555
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1556
+
1557
+ q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1558
+ k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1559
+
1560
+ if self.rotary_embed_values:
1561
+ v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
1562
+
1563
+ input_mask = context_mask
1564
+
1565
+ if not exists(input_mask) and not has_context:
1566
+ input_mask = mask
1567
+
1568
+ if self.num_mem_kv > 0:
1569
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1570
+
1571
+ if self.qk_norm:
1572
+ mem_k = l2norm(mem_k)
1573
+ mem_k = mem_k * self.qk_norm_k_scale
1574
+
1575
+ k = torch.cat((mem_k, k), dim = -2)
1576
+ v = torch.cat((mem_v, v), dim = -2)
1577
+
1578
+ if exists(input_mask):
1579
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1580
+
1581
+ i, j = map(lambda t: t.shape[-2], (q, k))
1582
+
1583
+ # determine masking
1584
+
1585
+ mask_value = max_neg_value(q)
1586
+ masks = []
1587
+ final_attn_mask = None
1588
+
1589
+ if exists(input_mask):
1590
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1591
+ masks.append(~input_mask)
1592
+
1593
+ if exists(attn_mask):
1594
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have between 2 and 4 dimensions (inclusive)'
1595
+ if attn_mask.ndim == 2:
1596
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1597
+ elif attn_mask.ndim == 3:
1598
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1599
+ masks.append(~attn_mask)
1600
+
1601
+ if exists(self.max_attend_past):
1602
+ range_q = torch.arange(j - i, j, device = device)
1603
+ range_k = torch.arange(j, device = device)
1604
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1605
+ max_attend_past_mask = dist > self.max_attend_past
1606
+ masks.append(max_attend_past_mask)
1607
+
1608
+ if len(masks) > 0:
1609
+ final_attn_mask = ~or_reduce(masks)
1610
+
1611
+ # prepare relative positional bias, if needed
1612
+
1613
+ attn_bias = None
1614
+ if exists(rel_pos):
1615
+ attn_bias = rel_pos(i, j)
1616
+
1617
+ # attention is all we need
1618
+
1619
+ out, intermediates = self.attend(
1620
+ q, k, v,
1621
+ mask = final_attn_mask,
1622
+ attn_bias = attn_bias,
1623
+ prev_attn = prev_attn
1624
+ )
1625
+
1626
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1627
+
1628
+ if exists(r):
1629
+ out = out * r + out
1630
+
1631
+ # normformer scaling of heads
1632
+
1633
+ if head_scale:
1634
+ out = out * self.head_scale_params
1635
+
1636
+ # per head gating, from https://arxiv.org/abs/2306.12929
1637
+
1638
+ if exists(self.to_v_head_gate):
1639
+ head_gate = self.to_v_head_gate(x)
1640
+ out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
1641
+
1642
+ # merge heads
1643
+
1644
+ out = rearrange(out, 'b h n d -> b n (h d)')
1645
+
1646
+ # alphafold2 styled gating of the values
1647
+
1648
+ if exists(self.to_v_gate):
1649
+ gates = self.to_v_gate(x)
1650
+ out = out * gates.sigmoid()
1651
+
1652
+ # combine the heads
1653
+
1654
+ out = self.to_out(out)
1655
+
1656
+ if exists(mask):
1657
+ mask = rearrange(mask, 'b n -> b n 1')
1658
+ out = out.masked_fill(~mask, 0.)
1659
+
1660
+ if not return_intermediates:
1661
+ return out
1662
+
1663
+ intermediates.cached_kv = cached_kv
1664
+
1665
+ return out, intermediates
1666
+
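# A minimal sketch (example shapes) of the num_mem_kv mechanism in the forward
# pass above: learned "memory" key/values are prepended to every sequence's
# keys/values, and the mask is extended with True so they are always visible.
import torch
import torch.nn.functional as F

b, heads, n, d, num_mem = 2, 8, 10, 64, 4
k = torch.randn(b, heads, n, d)
mem_k = torch.randn(heads, num_mem, d).unsqueeze(0).expand(b, -1, -1, -1)
k = torch.cat((mem_k, k), dim=-2)             # (b, heads, num_mem + n, d)
mask = torch.ones(b, n, dtype=torch.bool)
mask = F.pad(mask, (num_mem, 0), value=True)  # (b, num_mem + n)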
1667
+ class AttentionLayers(nn.Module):
1668
+ def __init__(
1669
+ self,
1670
+ dim,
1671
+ depth,
1672
+ heads = 8,
1673
+ causal = False,
1674
+ cross_attend = False,
1675
+ only_cross = False,
1676
+ use_scalenorm = False,
1677
+ use_rmsnorm = False,
1678
+ use_simple_rmsnorm = False,
1679
+ alibi_pos_bias = False,
1680
+ alibi_num_heads = None,
1681
+ rel_pos_bias = False,
1682
+ rel_pos_num_buckets = 32,
1683
+ rel_pos_max_distance = 128,
1684
+ dynamic_pos_bias = False,
1685
+ dynamic_pos_bias_log_distance = False,
1686
+ dynamic_pos_bias_mlp_depth = 2,
1687
+ dynamic_pos_bias_norm = False,
1688
+ rotary_pos_emb = False,
1689
+ rotary_emb_dim = None,
1690
+ rotary_xpos = False,
1691
+ rotary_interpolation_factor = 1.,
1692
+ rotary_xpos_scale_base = 512,
1693
+ rotary_base_rescale_factor = 1.,
1694
+ custom_layers = None,
1695
+ sandwich_coef = None,
1696
+ par_ratio = None,
1697
+ weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1698
+ layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1699
+ residual_attn = False,
1700
+ cross_residual_attn = False,
1701
+ macaron = False,
1702
+ pre_norm = True,
1703
+ pre_norm_has_final_norm = True,
1704
+ gate_residual = False,
1705
+ scale_residual = False,
1706
+ scale_residual_constant = 1.,
1707
+ shift_tokens = 0,
1708
+ sandwich_norm = False,
1709
+ resi_dual = False,
1710
+ resi_dual_scale = 1.,
1711
+ zero_init_branch_output = False,
1712
+ layer_dropout = 0.,
1713
+ cross_attn_tokens_dropout = 0.,
1714
+ **kwargs
1715
+ ):
1716
+ super().__init__()
1717
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
1718
+
1719
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1720
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1721
+
1722
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1723
+
1724
+ self.dim = dim
1725
+ self.depth = depth
1726
+ self.causal = causal
1727
+ self.layers = nn.ModuleList([])
1728
+
1729
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1730
+
1731
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1732
+
1733
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1734
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1735
+
1736
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1737
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
1738
+
1739
+ # relative positional bias
1740
+
1741
+ flash_attn = attn_kwargs.get('flash', False)
1742
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1743
+
1744
+ self.rel_pos = None
1745
+ if rel_pos_bias:
1746
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1747
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1748
+ elif dynamic_pos_bias:
1749
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1750
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1751
+ elif alibi_pos_bias:
1752
+ alibi_num_heads = default(alibi_num_heads, heads)
1753
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must not exceed the total number of heads'
1754
+ self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1755
+
1756
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1757
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1758
+
1759
+ if resi_dual:
1760
+ pre_norm = False
1761
+
1762
+ self.pre_norm = pre_norm
1763
+ self.sandwich_norm = sandwich_norm
1764
+
1765
+ self.resi_dual = resi_dual
1766
+ assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1767
+ self.resi_dual_scale = resi_dual_scale
1768
+
1769
+ self.residual_attn = residual_attn
1770
+ self.cross_residual_attn = cross_residual_attn
1771
+ assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1772
+
1773
+ self.cross_attend = cross_attend
1774
+
1775
+ assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1776
+
1777
+ if use_scalenorm:
1778
+ norm_class = ScaleNorm
1779
+ elif use_rmsnorm:
1780
+ norm_class = RMSNorm
1781
+ elif use_simple_rmsnorm:
1782
+ norm_class = SimpleRMSNorm
1783
+ else:
1784
+ norm_class = nn.LayerNorm
1785
+
1786
+ norm_fn = partial(norm_class, dim)
1787
+
1788
+ if cross_attend and not only_cross:
1789
+ default_block = ('a', 'c', 'f')
1790
+ elif cross_attend and only_cross:
1791
+ default_block = ('c', 'f')
1792
+ else:
1793
+ default_block = ('a', 'f')
1794
+
1795
+ if macaron:
1796
+ default_block = ('f',) + default_block
1797
+
1798
+ # zero init
1799
+
1800
+ if zero_init_branch_output:
1801
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1802
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1803
+
1804
+ # setup weight tying, which is a special case of `layer_execute_order`
1805
+
1806
+ assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
1807
+
1808
+ if weight_tie_layers:
1809
+ assert not exists(layers_execute_order)
1810
+ layers_execute_order = tuple(range(len(default_block))) * depth
1811
+ depth = 1
1812
+
1813
+ # calculate layer block order
1814
+
1815
+ if exists(custom_layers):
1816
+ layer_types = custom_layers
1817
+ elif exists(par_ratio):
1818
+ par_depth = depth * len(default_block)
1819
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1820
+ default_block = tuple(filter(not_equals('f'), default_block))
1821
+ par_attn = par_depth // par_ratio
1822
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1823
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
1824
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1825
+ par_block = default_block + ('f',) * (par_width - len(default_block))
1826
+ par_head = par_block * par_attn
1827
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
1828
+ elif exists(sandwich_coef):
1829
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient must be positive and must not exceed the depth'
1830
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1831
+ else:
1832
+ layer_types = default_block * depth
1833
+
1834
+ self.layer_types = layer_types
1835
+ self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1836
+
1837
+ assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1838
+
1839
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1840
+
1841
+ # stochastic depth
1842
+
1843
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1844
+
1845
+ # structured dropout for cross attending
1846
+
1847
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1848
+
1849
+ # calculate token shifting
1850
+
1851
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1852
+
1853
+ # whether it has post norm
1854
+
1855
+ self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1856
+
1857
+ # iterate and construct layers
1858
+
1859
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1860
+ is_last_layer = ind == (len(self.layer_types) - 1)
1861
+
1862
+ if layer_type == 'a':
1863
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1864
+ elif layer_type == 'c':
1865
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1866
+ elif layer_type == 'f':
1867
+ layer = FeedForward(dim, **ff_kwargs)
1868
+ layer = layer if not macaron else Scale(0.5, layer)
1869
+ else:
1870
+ raise Exception(f'invalid layer type {layer_type}')
1871
+
1872
+ if layer_shift_tokens > 0:
1873
+ shift_range_upper = layer_shift_tokens + 1
1874
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1875
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1876
+
1877
+ residual_fn = GRUGating if gate_residual else Residual
1878
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1879
+
1880
+ pre_branch_norm = norm_fn() if pre_norm else None
1881
+ post_branch_norm = norm_fn() if sandwich_norm else None
1882
+ post_main_norm = norm_fn() if not pre_norm else None
1883
+
1884
+ norms = nn.ModuleList([
1885
+ pre_branch_norm,
1886
+ post_branch_norm,
1887
+ post_main_norm
1888
+ ])
1889
+
1890
+ self.layers.append(nn.ModuleList([
1891
+ norms,
1892
+ layer,
1893
+ residual
1894
+ ]))
1895
+
1896
+ def forward(
1897
+ self,
1898
+ x,
1899
+ context = None,
1900
+ mask = None,
1901
+ context_mask = None,
1902
+ attn_mask = None,
1903
+ self_attn_kv_mask = None,
1904
+ mems = None,
1905
+ seq_start_pos: Optional[Tensor] = None,
1906
+ cache: Optional[LayerIntermediates] = None,
1907
+ cache_age = 1,
1908
+ return_hiddens = False
1909
+ ):
1910
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if and only if cross_attend is set to True'
1911
+
1912
+ # initialize accums
1913
+
1914
+ hiddens = []
1915
+ layer_hiddens = []
1916
+ intermediates = []
1917
+
1918
+ prev_attn = None
1919
+ prev_cross_attn = None
1920
+
1921
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1922
+
1923
+ # handle left padded sequences
1924
+
1925
+ if exists(seq_start_pos):
1926
+ seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
1927
+ left_pad_mask = seq_arange >= seq_start_pos[..., None]
1928
+
1929
+ if exists(self_attn_kv_mask):
1930
+ self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
1931
+ else:
1932
+ self_attn_kv_mask = left_pad_mask
1933
+
1934
+ # rotary positions
1935
+
1936
+ rotary_pos_emb = None
1937
+
1938
+ if exists(self.rotary_pos_emb):
1939
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1940
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
1941
+
1942
+ # assume cached key / values
1943
+
1944
+ attn_cache = []
1945
+
1946
+ if exists(cache):
1947
+ assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
1948
+
1949
+ if cache_age > 0:
1950
+ x = x[:, -cache_age:] # for spec decoding, may be greater than 1
1951
+
1952
+ attn_cache = cache.attn_intermediates
1953
+
1954
+ iter_attn_cache = iter(attn_cache)
1955
+
1956
+ # outer residual - for resiDual paper
1957
+
1958
+ outer_residual = x * self.resi_dual_scale
1959
+
1960
+ # get layers to be executed
1961
+
1962
+ layer_variables = (
1963
+ self.layer_types,
1964
+ self.layers,
1965
+ self.layer_dropouts
1966
+ )
1967
+
1968
+ layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
1969
+
1970
+ # go through the attention and feedforward layers
1971
+
1972
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
1973
+ is_last = ind == (len(self.layers) - 1)
1974
+
1975
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1976
+ continue
1977
+
1978
+ if layer_type == 'a':
1979
+ if return_hiddens:
1980
+ hiddens.append(x)
1981
+ layer_mem = mems.pop(0) if mems else None
1982
+
1983
+ if layer_type == 'c':
1984
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1985
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1986
+
1987
+ inner_residual = x
1988
+
1989
+ if return_hiddens:
1990
+ layer_hiddens.append(x)
1991
+
1992
+ pre_norm, post_branch_norm, post_main_norm = norm
1993
+
1994
+ if exists(pre_norm):
1995
+ x = pre_norm(x)
1996
+
1997
+ if layer_type == 'a':
1998
+ out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
1999
+ elif layer_type == 'c':
2000
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
2001
+ elif layer_type == 'f':
2002
+ out = block(x)
2003
+
2004
+ if self.resi_dual:
2005
+ outer_residual = outer_residual + out * self.resi_dual_scale
2006
+
2007
+ if exists(post_branch_norm):
2008
+ out = post_branch_norm(out)
2009
+
2010
+ x = residual_fn(out, inner_residual)
2011
+
2012
+ if layer_type in ('a', 'c') and return_hiddens:
2013
+ intermediates.append(inter)
2014
+
2015
+ if layer_type == 'a' and self.residual_attn:
2016
+ prev_attn = inter.pre_softmax_attn
2017
+ elif layer_type == 'c' and self.cross_residual_attn:
2018
+ prev_cross_attn = inter.pre_softmax_attn
2019
+
2020
+ if exists(post_main_norm):
2021
+ x = post_main_norm(x)
2022
+
2023
+ if return_hiddens:
2024
+ layer_hiddens.append(x)
2025
+
2026
+ if self.resi_dual:
2027
+ x = x + self.final_norm(outer_residual)
2028
+ else:
2029
+ x = self.final_norm(x)
2030
+
2031
+ if not return_hiddens:
2032
+ return x
2033
+
2034
+ intermediates = LayerIntermediates(
2035
+ hiddens = hiddens,
2036
+ attn_intermediates = intermediates,
2037
+ layer_hiddens = layer_hiddens
2038
+ )
2039
+
2040
+ return x, intermediates
2041
+
2042
+ class Encoder(AttentionLayers):
2043
+ def __init__(self, **kwargs):
2044
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
2045
+ super().__init__(causal = False, **kwargs)
2046
+
2047
+ class Decoder(AttentionLayers):
2048
+ def __init__(self, **kwargs):
2049
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
2050
+ super().__init__(causal = True, **kwargs)
2051
+
2052
+ class CrossAttender(AttentionLayers):
2053
+ def __init__(self, **kwargs):
2054
+ super().__init__(cross_attend = True, only_cross = True, **kwargs)
2055
+
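# A minimal sketch of the default layer ordering and pre-norm residual wiring
# built above (example sizes; nn.Linear stands in for an attention or
# feed-forward sub-block).
import torch
import torch.nn as nn

dim, depth = 16, 3
layer_types = ('a', 'f') * depth      # default decoder-style block order
norm = nn.LayerNorm(dim)
block = nn.Linear(dim, dim)           # stand-in sub-block
x = torch.randn(2, 10, dim)
x = x + block(norm(x))                # pre-norm residual: x = x + block(norm(x))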
2056
+ class ViTransformerWrapper(nn.Module):
2057
+ def __init__(
2058
+ self,
2059
+ *,
2060
+ image_size,
2061
+ patch_size,
2062
+ attn_layers,
2063
+ channels = 3,
2064
+ num_classes = None,
2065
+ post_emb_norm = False,
2066
+ num_register_tokens = 0,
2067
+ emb_dropout = 0.
2068
+ ):
2069
+ super().__init__()
2070
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
2071
+ assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
2072
+ dim = attn_layers.dim
2073
+ num_patches = (image_size // patch_size) ** 2
2074
+ patch_dim = channels * patch_size ** 2
2075
+
2076
+ self.patch_size = patch_size
2077
+
2078
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
2079
+
2080
+ has_register_tokens = num_register_tokens > 0
2081
+ self.has_register_tokens = has_register_tokens
2082
+
2083
+ if has_register_tokens:
2084
+ self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
2085
+
2086
+ self.patch_to_embedding = nn.Sequential(
2087
+ nn.LayerNorm(patch_dim),
2088
+ nn.Linear(patch_dim, dim),
2089
+ nn.LayerNorm(dim)
2090
+ )
2091
+
2092
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2093
+ self.dropout = nn.Dropout(emb_dropout)
2094
+
2095
+ self.attn_layers = attn_layers
2096
+
2097
+ self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
2098
+
2099
+ def forward(
2100
+ self,
2101
+ img,
2102
+ return_embeddings = False
2103
+ ):
2104
+ b, p = img.shape[0], self.patch_size
2105
+
2106
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
2107
+ x = self.patch_to_embedding(x)
2108
+ n = x.shape[1]
2109
+
2110
+ x = x + self.pos_embedding[:, :n]
2111
+
2112
+ x = self.post_emb_norm(x)
2113
+ x = self.dropout(x)
2114
+
2115
+ if self.has_register_tokens:
2116
+ r = repeat(self.register_tokens, 'n d -> b n d', b = b)
2117
+ x, ps = pack((x, r), 'b * d')
2118
+
2119
+ x = self.attn_layers(x)
2120
+
2121
+ if self.has_register_tokens:
2122
+ x, _ = unpack(x, ps, 'b * d')
2123
+
2124
+ if not exists(self.mlp_head) or return_embeddings:
2125
+ return x
2126
+
2127
+ x = x.mean(dim = -2)
2128
+ return self.mlp_head(x)
2129
+
2130
+ class TransformerWrapper(nn.Module):
2131
+ def __init__(
2132
+ self,
2133
+ *,
2134
+ num_tokens,
2135
+ max_seq_len,
2136
+ attn_layers,
2137
+ emb_dim = None,
2138
+ max_mem_len = 0,
2139
+ shift_mem_down = 0,
2140
+ emb_dropout = 0.,
2141
+ post_emb_norm = False,
2142
+ num_memory_tokens = None,
2143
+ memory_tokens_interspersed_every = None,
2144
+ tie_embedding = False,
2145
+ logits_dim = None,
2146
+ use_abs_pos_emb = True,
2147
+ scaled_sinu_pos_emb = False,
2148
+ l2norm_embed = False,
2149
+ emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
2150
+ attn_z_loss_weight = 1e-4,
2151
+ ):
2152
+ super().__init__()
2153
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2154
+
2155
+ dim = attn_layers.dim
2156
+ emb_dim = default(emb_dim, dim)
2157
+ self.emb_dim = emb_dim
2158
+ self.num_tokens = num_tokens
2159
+
2160
+ self.max_seq_len = max_seq_len
2161
+ self.max_mem_len = max_mem_len
2162
+ self.shift_mem_down = shift_mem_down
2163
+
2164
+ self.l2norm_embed = l2norm_embed
2165
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
2166
+
2167
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2168
+ self.pos_emb = always(0)
2169
+ elif scaled_sinu_pos_emb:
2170
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
2171
+ else:
2172
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
2173
+
2174
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
2175
+
2176
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
2177
+ self.emb_dropout = nn.Dropout(emb_dropout)
2178
+
2179
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
2180
+ self.attn_layers = attn_layers
2181
+
2182
+ self.init_()
2183
+
2184
+ logits_dim = default(logits_dim, num_tokens)
2185
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
2186
+
2187
+ # memory tokens (like [cls]) from Memory Transformers paper
2188
+
2189
+ num_memory_tokens = default(num_memory_tokens, 0)
2190
+ self.num_memory_tokens = num_memory_tokens
2191
+ if num_memory_tokens > 0:
2192
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
2193
+
2194
+ self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
2195
+
2196
+ # whether can do cached kv decoding
2197
+
2198
+ self.can_cache_kv = self.num_memory_tokens == 0
2199
+
2200
+ def init_(self):
2201
+ if self.l2norm_embed:
2202
+ nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
2203
+ if not isinstance(self.pos_emb, always):
2204
+ nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
2205
+ return
2206
+
2207
+ nn.init.kaiming_normal_(self.token_emb.emb.weight)
2208
+
2209
+ def forward(
2210
+ self,
2211
+ x,
2212
+ return_embeddings = False,
2213
+ return_logits_and_embeddings = False,
2214
+ return_intermediates = False,
2215
+ mask = None,
2216
+ return_mems = False,
2217
+ return_attn = False,
2218
+ mems = None,
2219
+ pos = None,
2220
+ prepend_embeds = None,
2221
+ sum_embeds = None,
2222
+ return_attn_z_loss = False,
2223
+ attn_z_loss_weight = 1e-4,
2224
+ seq_start_pos = None,
2225
+ cache: Optional[LayerIntermediates] = None,
2226
+ **kwargs
2227
+ ):
2228
+ b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
2229
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2230
+
2231
+ # absolute positional embedding
2232
+
2233
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
2234
+ pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
2235
+ x = self.token_emb(x) + pos_emb
2236
+
2237
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
2238
+
2239
+ if exists(sum_embeds):
2240
+ x = x + sum_embeds
2241
+
2242
+ # post embedding norm, purportedly leads to greater stabilization
2243
+
2244
+ x = self.post_emb_norm(x)
2245
+
2246
+ # whether to prepend embeds, as in PaLI, for image embeddings
2247
+
2248
+ if exists(prepend_embeds):
2249
+ prepend_seq, prepend_dim = prepend_embeds.shape[1:]
2250
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
2251
+
2252
+ x = torch.cat((prepend_embeds, x), dim = -2)
2253
+
2254
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
2255
+
2256
+ if emb_frac_gradient < 1:
2257
+ assert emb_frac_gradient > 0
2258
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
2259
+
2260
+ # embedding dropout
2261
+
2262
+ x = self.emb_dropout(x)
2263
+
2264
+ x = self.project_emb(x)
2265
+
2266
+ if has_memory_tokens:
2267
+ mem_every = self.memory_tokens_interspersed_every
2268
+
2269
+ if exists(mem_every):
2270
+ assert mem_every > 0
2271
+ assert isinstance(self.attn_layers, Decoder), 'only for decoder'
2272
+ next_seq_len = math.ceil(n / mem_every) * mem_every
2273
+
2274
+ x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
2275
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
2276
+
2277
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
2278
+ x, mem_packed_shape = pack((mem, x), 'b * d')
2279
+
2280
+ # auto-handle masking after prepending memory tokens
2281
+ if not exists(mem_every) and exists(mask):
2282
+ mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
2283
+
2284
+ if exists(mem_every):
2285
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2286
+
2287
+ if self.shift_mem_down and exists(mems):
2288
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
2289
+ mems = [*mems_r, *mems_l]
2290
+
2291
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
2292
+
2293
+ if has_memory_tokens:
2294
+ if exists(mem_every):
2295
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
2296
+
2297
+ mem, x = unpack(x, mem_packed_shape, 'b * d')
2298
+
2299
+ if exists(mem_every):
2300
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2301
+
2302
+ x = x[:, :n]
2303
+
2304
+ if return_logits_and_embeddings:
2305
+ out = (self.to_logits(x), x)
2306
+ elif return_embeddings:
2307
+ out = x
2308
+ else:
2309
+ out = self.to_logits(x)
2310
+
2311
+ if return_attn_z_loss:
2312
+ pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
2313
+ intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
2314
+ return_intermediates = True
2315
+
2316
+ if return_mems:
2317
+ hiddens = intermediates.hiddens
2318
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
2319
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
2320
+
2321
+ if not return_intermediates:
2322
+ return out, new_mems
2323
+
2324
+ intermediates.mems = new_mems
2325
+
2326
+ if return_intermediates:
2327
+ return out, intermediates
2328
+
2329
+ if return_attn:
2330
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2331
+ return out, attn_maps
2332
+
2333
+ return out
2334
+
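# An illustrative usage sketch for the wrapper above, assuming TransformerWrapper
# and Decoder from this file are in scope; all sizes are arbitrary example values.
import torch

model = TransformerWrapper(
    num_tokens = 512,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 256, depth = 4, heads = 8)
)
tokens = torch.randint(0, 512, (2, 128))
logits = model(tokens)   # (2, 128, 512)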
2335
+ class ContinuousTransformerWrapper(nn.Module):
2336
+ def __init__(
2337
+ self,
2338
+ *,
2339
+ max_seq_len,
2340
+ attn_layers,
2341
+ dim_in = None,
2342
+ dim_out = None,
2343
+ emb_dim = None,
2344
+ max_mem_len = 0,
2345
+ post_emb_norm = False,
2346
+ emb_dropout = 0.,
2347
+ use_abs_pos_emb = True,
2348
+ scaled_sinu_pos_emb = False
2349
+ ):
2350
+ super().__init__()
2351
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2352
+
2353
+ dim = attn_layers.dim
2354
+
2355
+ self.max_seq_len = max_seq_len
2356
+
2357
+ self.max_mem_len = max_mem_len
2358
+
2359
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2360
+ self.pos_emb = always(0)
2361
+ elif scaled_sinu_pos_emb:
2362
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
2363
+ else:
2364
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
2365
+
2366
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2367
+ self.emb_dropout = nn.Dropout(emb_dropout)
2368
+
2369
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
2370
+
2371
+ self.attn_layers = attn_layers
2372
+
2373
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
2374
+
2375
+ def forward(
2376
+ self,
2377
+ x,
2378
+ return_embeddings = False,
2379
+ return_intermediates = False,
2380
+ return_mems = False,
2381
+ mask = None,
2382
+ return_attn = False,
2383
+ mems = None,
2384
+ pos = None,
2385
+ prepend_embeds = None,
2386
+ **kwargs
2387
+ ):
2388
+ x = self.project_in(x)
2389
+ x = x + self.pos_emb(x, pos = pos)
2390
+
2391
+ x = self.post_emb_norm(x)
2392
+
2393
+ # whether to prepend embeds, as in PaLI, for image embeddings
2394
+
2395
+ if exists(prepend_embeds):
2396
+ _, prepend_dim = prepend_embeds.shape[1:]
2397
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
2398
+
2399
+ x = torch.cat((prepend_embeds, x), dim = -2)
2400
+
2401
+ x = self.emb_dropout(x)
2402
+
2403
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
2404
+
2405
+ out = self.project_out(x) if not return_embeddings else x
2406
+
2407
+ if return_intermediates:
2408
+ return out, intermediates
2409
+
2410
+ if return_mems:
2411
+ hiddens = intermediates.hiddens
2412
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
2413
+ return out, new_mems
2414
+
2415
+ if return_attn:
2416
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2417
+ return out, attn_maps
2418
+
2419
+ return out
2420
+
2421
+ class XTransformer(nn.Module):
2422
+ def __init__(
2423
+ self,
2424
+ *,
2425
+ dim,
2426
+ tie_token_emb = False,
2427
+ ignore_index = -100,
2428
+ pad_value = 0,
2429
+ cross_attn_tokens_dropout = 0.,
2430
+ **kwargs
2431
+ ):
2432
+ super().__init__()
2433
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
2434
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
2435
+
2436
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
2437
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
2438
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
2439
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
2440
+ enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
2441
+ enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
2442
+
2443
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
2444
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
2445
+ dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
2446
+ dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
2447
+
2448
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many encoder tokens to drop out when the decoder cross-attends - seen in a couple of papers, including Perceiver AR - also a very effective regularizer when cross-attending to very long memories
2449
+
2450
+ self.encoder = TransformerWrapper(
2451
+ **enc_transformer_kwargs,
2452
+ attn_layers = Encoder(dim = dim, **enc_kwargs)
2453
+ )
2454
+
2455
+ self.decoder = TransformerWrapper(
2456
+ **dec_transformer_kwargs,
2457
+ attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
2458
+ )
2459
+
2460
+ if tie_token_emb:
2461
+ self.decoder.token_emb = self.encoder.token_emb
2462
+
2463
+ self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
2464
+
2465
+ @torch.no_grad()
2466
+ def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
2467
+ encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
2468
+ return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
2469
+
2470
+ def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
2471
+
2472
+ if exists(src_prepend_embeds) and exists(mask):
2473
+ mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
2474
+
2475
+ enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
2476
+
2477
+ if self.training and self.cross_attn_tokens_dropout > 0:
2478
+ enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
2479
+
2480
+ out = self.decoder(tgt, context = enc, context_mask = mask)
2481
+ return out
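# An illustrative usage sketch for XTransformer above; sizes are arbitrary
# example values, and the AutoregressiveWrapper defined earlier in this file is
# assumed to return a training loss from (src, tgt) pairs.
import torch

model = XTransformer(
    dim = 256,
    enc_num_tokens = 256, enc_depth = 3, enc_heads = 8, enc_max_seq_len = 512,
    dec_num_tokens = 256, dec_depth = 3, dec_heads = 8, dec_max_seq_len = 512
)
src = torch.randint(0, 256, (2, 512))
tgt = torch.randint(0, 256, (2, 512))
loss = model(src, tgt)
loss.backward()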