roychao19477 committed
Commit bd9ffb1 · 1 Parent(s): a66fd6a

Add application file

README.md CHANGED
@@ -1,13 +1,12 @@
 ---
-title: Avse Dev Only
-emoji: 🔥
-colorFrom: green
+title: Dev
+colorFrom: purple
 colorTo: indigo
 sdk: gradio
-sdk_version: 5.35.0
+sdk_version: 5.31.0
 app_file: app.py
 pinned: false
-short_description: 'avse_dev_only (Free HF user version : limited resolution)'
+short_description: Dev
+tags:
+- Useless
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,376 @@
import shlex
import subprocess
import spaces
import torch
import os
import shutil
import glob
import gradio as gr

# install packages for mamba
def install_mamba():
    subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))

def clone_github():
    subprocess.run([
        "git", "clone",
        f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git",
    ])
    # move all files except README.md
    for item in glob.glob("for_HF_AVSEMamba/*"):
        if os.path.basename(item) != "README.md":
            if os.path.isdir(item):
                shutil.move(item, ".")
            else:
                shutil.move(item, os.path.join(".", os.path.basename(item)))

    #shutil.rmtree("tmp_repo")
    #subprocess.run(["ls"], check=True)

install_mamba()
clone_github()

ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""


import torch
import ffmpeg
import torchaudio
import torchaudio.transforms as T
import yaml
import librosa
import librosa.display
import matplotlib
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
from ultralytics import YOLO
import supervision as sv

import gradio as gr
import cv2
import os
import tempfile
from ultralytics import YOLO
from moviepy import ImageSequenceClip
from moviepy.video import fx as vfx
from scipy.io import wavfile
from avse_code import run_avse

# Load face detector
model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available


from decord import VideoReader, cpu
from model import AVSEModule
from config import sampling_rate
import spaces

# Load model once globally
#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
#model = AVSEModule.load_from_checkpoint(ckpt_path)
avse_model = AVSEModule()
#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
avse_model.load_state_dict(avse_state_dict, strict=True)
avse_model.to("cuda")
avse_model.eval()

@spaces.GPU
def run_avse_inference(video_path, audio_path):
    estimated = run_avse(video_path, audio_path)
    # Load audio
    #noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
    #noisy = torch.tensor(noisy).unsqueeze(0)  # (1, N)
    noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)

    # Norm.
    #noisy = noisy * (0.8 / np.max(np.abs(noisy)))

    # Load grayscale video
    vr = VideoReader(video_path, ctx=cpu(0))
    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
    bg_frames = np.array([
        cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
    ]).astype(np.float32)
    bg_frames /= 255.0


    # Combine into input dict (match what model.enhance expects)
    data = {
        "noisy_audio": noisy,
        "video_frames": bg_frames[np.newaxis, ...]
    }

    with torch.no_grad():
        estimated = avse_model.enhance(data).reshape(-1)

    # Save result
    tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
    sf.write(tmp_wav, estimated, samplerate=sampling_rate)

    return tmp_wav


def extract_resampled_audio(video_path, target_sr=16000):
    # Step 1: extract audio with ffmpeg
    # (written to a temporary 44.1 kHz wav file first)
    tmp_audio_path = tempfile.mktemp(suffix=".wav")
    subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", tmp_audio_path])

    # Step 2: Load and resample
    waveform, sr = torchaudio.load(tmp_audio_path)
    if sr != target_sr:
        resampler = T.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)

    # Step 3: Save resampled audio
    resampled_audio_path = tempfile.mktemp(suffix="_16k.wav")
    torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
    return resampled_audio_path

@spaces.GPU
def extract_faces(video_file):
    # Step 0: Check resolution
    cap = cv2.VideoCapture(video_file)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    cap.release()

    # Step 1: Downsample if needed
    if width > 1280 or height > 720:
        resized_path = tempfile.mktemp(suffix=".mp4")
        subprocess.run([
            "ffmpeg", "-y", "-i", video_file,
            "-vf", "scale='min(1280,iw)':-2",
            "-c:v", "libx264", "-crf", "28",
            "-preset", "fast", "-an", resized_path
        ])
        video_file = resized_path

    cap = cv2.VideoCapture(video_file)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Inference
        results = model(frame, verbose=False)[0]
        for box in results.boxes:
            # version 1
            # x1, y1, x2, y2 = map(int, box.xyxy[0])

            # version 2
            h, w, _ = frame.shape
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            pad_ratio = 0.5  # 50% padding around the detected box

            dx = (x2 - x1) * pad_ratio
            dy = (y2 - y1) * pad_ratio

            x1 = int(max(0, x1 - dx))
            y1 = int(max(0, y1 - dy))
            x2 = int(min(w, x2 + dx))
            y2 = int(min(h, y2 + dy))
            # Added for v3
            shift_down = int(0.1 * (y2 - y1))
            y1 = int(min(max(0, y1 + shift_down), h))
            y2 = int(min(max(0, y2 + shift_down), h))
            face_crop = frame[y1:y2, x1:x2]
            if face_crop.size != 0:
                resized = cv2.resize(face_crop, (224, 224))
                frames.append(resized)

                #h_crop, w_crop = face_crop.shape[:2]
                #side = min(h_crop, w_crop)
                #start_y = (h_crop - side) // 2
                #start_x = (w_crop - side) // 2
                #square_crop = face_crop[start_y:start_y+side, start_x:start_x+side]
                #resized = cv2.resize(square_crop, (224, 224))
                #frames.append(resized)

            break  # only one face per frame

    cap.release()

    # Save as video
    tmpdir = tempfile.mkdtemp()
    output_path = os.path.join(tmpdir, "face_only_video.mp4")
    #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=25)
    #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=fps)
    clip = ImageSequenceClip(
        [cv2.cvtColor(cv2.resize(f, (224, 224)), cv2.COLOR_BGR2RGB) for f in frames],
        fps=fps
    ).fx(vfx.flip_vertical)
    clip.write_videofile(output_path, codec="libx264", audio=False, fps=25)

    # Save audio from original, resampled to 16kHz
    audio_path = os.path.join(tmpdir, "audio_16k.wav")

    # Extract audio using ffmpeg-python (more robust than moviepy)
    ffmpeg.input(video_file).output(
        audio_path,
        ar=16000,        # resample to 16k
        ac=1,            # mono
        format='wav',
        vn=None          # no video
    ).run(overwrite_output=True)


    # ------------------------------- #
    # AVSE models

    enhanced_audio_path = run_avse_inference(output_path, audio_path)

    return output_path, enhanced_audio_path
    #return output_path, audio_path

iface = gr.Interface(
    fn=extract_faces,
    inputs=gr.Video(label="Upload or record your video"),
    outputs=[
        gr.Video(label="Detected Face Only Video"),
        #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
        gr.Audio(label="Enhanced Audio", type="filepath")
    ],
    title="Face Detector",
    description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
)

iface.launch()

# NOTE: iface.launch() blocks while the Space is serving, so the SEMamba
# audio-only demo below is effectively unreachable in this app.

ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# load config
with open(cfg_f, 'r') as f:
    cfg = yaml.safe_load(f)


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"
model = SEMamba(cfg).to(device)
#sdict = torch.load(ckpt, map_location=device)
#model.load_state_dict(sdict["generator"])
#model.eval()

@spaces.GPU
def enhance(filepath, model_name):
    # Load model based on selection
    ckpt_path = {
        "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
        "VCTK+DNS": "ckpts/vd.pth"
    }[model_name]

    print("Loading:", ckpt_path)
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()
    with torch.no_grad():
        # load & resample
        wav, orig_sr = librosa.load(filepath, sr=None)
        noisy_wav = wav.copy()
        if orig_sr != 16000:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
        x = torch.from_numpy(wav).float().to(device)
        norm = torch.sqrt(len(x) / torch.sum(x ** 2))
        #x = (x * norm).unsqueeze(0)
        x = (x * norm)

        # split into 4s segments (64000 samples)
        segment_len = 4 * 16000
        chunks = x.split(segment_len)
        enhanced_chunks = []

        for chunk in chunks:
            if len(chunk) < segment_len:
                #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
                pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
                chunk = torch.cat([chunk, pad])
            chunk = chunk.unsqueeze(0)

            amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
            amp2, pha2, _ = model(amp, pha)
            out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
            out = (out / norm).squeeze(0)
            enhanced_chunks.append(out)

        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding

        # back to original rate
        if orig_sr != 16000:
            out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)

        # Normalize
        peak = np.max(np.abs(out))
        if peak > 0.05:
            out = out / peak * 0.85

        # write file
        sf.write("enhanced.wav", out, orig_sr)

        # spectrograms
        fig, axs = plt.subplots(1, 2, figsize=(16, 4))

        # noisy
        D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
        S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
        librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
        axs[0].set_title("Noisy Spectrogram")

        # enhanced
        D_clean = librosa.stft(out, n_fft=512, hop_length=256)
        S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
        librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
        #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
        axs[1].set_title("Enhanced Spectrogram")

        plt.tight_layout()

    return "enhanced.wav", fig

#with gr.Blocks() as demo:
#    gr.Markdown(ABOUT)
#    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
#    enhance_btn = gr.Button("Enhance")
#    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
#    plot_output = gr.Plot(label="Spectrograms")
#
#    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
#
#demo.queue().launch()

with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    model_choice = gr.Radio(
        label="Choose Model (The use of VCTK+DNS is recommended)",
        choices=["VCTK-Demand", "VCTK+DNS"],
        value="VCTK-Demand"
    )
    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    enhance_btn.click(
        fn=enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output]
    )
    gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")

demo.queue().launch()
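For reference, the segment-wise scheme that `enhance()` uses (split the waveform into 4-second windows, pad the last window with very low-level noise, process each window, then trim back to the input length) can be isolated into a small helper. This is only a sketch: `enhance_in_chunks` and the identity `enhance_fn` below are illustrative stand-ins, not part of the repository.

import torch

def enhance_in_chunks(x: torch.Tensor, enhance_fn, segment_len: int = 4 * 16000) -> torch.Tensor:
    """Process a 1-D waveform in fixed-size segments and trim back to the input length."""
    outputs = []
    for chunk in x.split(segment_len):
        if len(chunk) < segment_len:
            # Pad the tail with very low-level noise, as app.py does,
            # so the model never sees an exact-zero segment.
            pad = torch.randn(segment_len - len(chunk)) * 1e-4
            chunk = torch.cat([chunk, pad])
        outputs.append(enhance_fn(chunk.unsqueeze(0)).squeeze(0))
    return torch.cat(outputs)[: len(x)]  # drop the padded samples

# Example with an identity "enhancer" standing in for the SEMamba STFT/ISTFT pipeline.
noisy = torch.randn(5 * 16000)            # 5 s at 16 kHz
enhanced = enhance_in_chunks(noisy, lambda c: c)
assert enhanced.shape == noisy.shape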
mamba_ssm/.DS_Store ADDED
Binary file (6.15 kB)
 
mamba_ssm/__init__.py ADDED
@@ -0,0 +1,5 @@
__version__ = "1.2.2"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
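The package re-exports the pieces used by the rest of this Space: the scan functions, the `Mamba` block, and `MambaLMHeadModel`. A minimal smoke test of the `Mamba` block might look like the sketch below; it assumes a CUDA device and the compiled `mamba_ssm`/`causal_conv1d` wheels installed by `install_mamba()` in app.py, and the sizes are arbitrary.

import torch
from mamba_ssm import Mamba

# The fused selective-scan kernels are CUDA-only, so a GPU is required.
block = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).to("cuda")

x = torch.randn(2, 1000, 256, device="cuda")   # (batch, seq_len, d_model)
y = block(x)                                   # same shape as the input
print(y.shape)                                 # torch.Size([2, 1000, 256])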
mamba_ssm/models/__init__.py ADDED
File without changes
mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,15 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
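The defaults above describe the large configuration. One detail worth noting: `MambaLMHeadModel` pads `vocab_size` up to a multiple of `pad_vocab_size_multiple`, so the default 50277 becomes 50280. A tiny sketch of that arithmetic (the variable names are illustrative):

from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig()                                   # d_model=2560, n_layer=64, vocab_size=50277
pad = cfg.pad_vocab_size_multiple
padded_vocab = cfg.vocab_size + (-cfg.vocab_size) % pad
print(cfg.vocab_size, "->", padded_vocab)             # 50277 -> 50280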
mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+
8
+ from collections import namedtuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from mamba_ssm.models.config_mamba import MambaConfig
14
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
15
+ from mamba_ssm.utils.generation import GenerationMixin
16
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
17
+
18
+ try:
19
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
20
+ except ImportError:
21
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
22
+
23
+
24
+ def create_block(
25
+ d_model,
26
+ ssm_cfg=None,
27
+ norm_epsilon=1e-5,
28
+ rms_norm=False,
29
+ residual_in_fp32=False,
30
+ fused_add_norm=False,
31
+ layer_idx=None,
32
+ device=None,
33
+ dtype=None,
34
+ ):
35
+ if ssm_cfg is None:
36
+ ssm_cfg = {}
37
+ factory_kwargs = {"device": device, "dtype": dtype}
38
+ mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
39
+ norm_cls = partial(
40
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
41
+ )
42
+ block = Block(
43
+ d_model,
44
+ mixer_cls,
45
+ norm_cls=norm_cls,
46
+ fused_add_norm=fused_add_norm,
47
+ residual_in_fp32=residual_in_fp32,
48
+ )
49
+ block.layer_idx = layer_idx
50
+ return block
51
+
52
+
53
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
54
+ def _init_weights(
55
+ module,
56
+ n_layer,
57
+ initializer_range=0.02, # Now only used for embedding layer.
58
+ rescale_prenorm_residual=True,
59
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
60
+ ):
61
+ if isinstance(module, nn.Linear):
62
+ if module.bias is not None:
63
+ if not getattr(module.bias, "_no_reinit", False):
64
+ nn.init.zeros_(module.bias)
65
+ elif isinstance(module, nn.Embedding):
66
+ nn.init.normal_(module.weight, std=initializer_range)
67
+
68
+ if rescale_prenorm_residual:
69
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
70
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
71
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
72
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
73
+ #
74
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
75
+ for name, p in module.named_parameters():
76
+ if name in ["out_proj.weight", "fc2.weight"]:
77
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
78
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
79
+ # We need to reinit p since this code could be called multiple times
80
+ # Having just p *= scale would repeatedly scale it down
81
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
82
+ with torch.no_grad():
83
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
84
+
85
+
86
+ class MixerModel(nn.Module):
87
+ def __init__(
88
+ self,
89
+ d_model: int,
90
+ n_layer: int,
91
+ vocab_size: int,
92
+ ssm_cfg=None,
93
+ norm_epsilon: float = 1e-5,
94
+ rms_norm: bool = False,
95
+ initializer_cfg=None,
96
+ fused_add_norm=False,
97
+ residual_in_fp32=False,
98
+ device=None,
99
+ dtype=None,
100
+ ) -> None:
101
+ factory_kwargs = {"device": device, "dtype": dtype}
102
+ super().__init__()
103
+ self.residual_in_fp32 = residual_in_fp32
104
+
105
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
106
+
107
+ # We change the order of residual and layer norm:
108
+ # Instead of LN -> Attn / MLP -> Add, we do:
109
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
110
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
111
+ # This is for performance reason: we can fuse add + layer_norm.
112
+ self.fused_add_norm = fused_add_norm
113
+ if self.fused_add_norm:
114
+ if layer_norm_fn is None or rms_norm_fn is None:
115
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
116
+
117
+ self.layers = nn.ModuleList(
118
+ [
119
+ create_block(
120
+ d_model,
121
+ ssm_cfg=ssm_cfg,
122
+ norm_epsilon=norm_epsilon,
123
+ rms_norm=rms_norm,
124
+ residual_in_fp32=residual_in_fp32,
125
+ fused_add_norm=fused_add_norm,
126
+ layer_idx=i,
127
+ **factory_kwargs,
128
+ )
129
+ for i in range(n_layer)
130
+ ]
131
+ )
132
+
133
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
134
+ d_model, eps=norm_epsilon, **factory_kwargs
135
+ )
136
+
137
+ self.apply(
138
+ partial(
139
+ _init_weights,
140
+ n_layer=n_layer,
141
+ **(initializer_cfg if initializer_cfg is not None else {}),
142
+ )
143
+ )
144
+
145
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
146
+ return {
147
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
148
+ for i, layer in enumerate(self.layers)
149
+ }
150
+
151
+ def forward(self, input_ids, inference_params=None):
152
+ hidden_states = self.embedding(input_ids)
153
+ residual = None
154
+ for layer in self.layers:
155
+ hidden_states, residual = layer(
156
+ hidden_states, residual, inference_params=inference_params
157
+ )
158
+ if not self.fused_add_norm:
159
+ residual = (hidden_states + residual) if residual is not None else hidden_states
160
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
161
+ else:
162
+ # Set prenorm=False here since we don't need the residual
163
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
164
+ hidden_states = fused_add_norm_fn(
165
+ hidden_states,
166
+ self.norm_f.weight,
167
+ self.norm_f.bias,
168
+ eps=self.norm_f.eps,
169
+ residual=residual,
170
+ prenorm=False,
171
+ residual_in_fp32=self.residual_in_fp32,
172
+ )
173
+ return hidden_states
174
+
175
+
176
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
177
+
178
+ def __init__(
179
+ self,
180
+ config: MambaConfig,
181
+ initializer_cfg=None,
182
+ device=None,
183
+ dtype=None,
184
+ ) -> None:
185
+ self.config = config
186
+ d_model = config.d_model
187
+ n_layer = config.n_layer
188
+ vocab_size = config.vocab_size
189
+ ssm_cfg = config.ssm_cfg
190
+ rms_norm = config.rms_norm
191
+ residual_in_fp32 = config.residual_in_fp32
192
+ fused_add_norm = config.fused_add_norm
193
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
194
+ factory_kwargs = {"device": device, "dtype": dtype}
195
+
196
+ super().__init__()
197
+ if vocab_size % pad_vocab_size_multiple != 0:
198
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
199
+ self.backbone = MixerModel(
200
+ d_model=d_model,
201
+ n_layer=n_layer,
202
+ vocab_size=vocab_size,
203
+ ssm_cfg=ssm_cfg,
204
+ rms_norm=rms_norm,
205
+ initializer_cfg=initializer_cfg,
206
+ fused_add_norm=fused_add_norm,
207
+ residual_in_fp32=residual_in_fp32,
208
+ **factory_kwargs,
209
+ )
210
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
211
+
212
+ # Initialize weights and apply final processing
213
+ self.apply(
214
+ partial(
215
+ _init_weights,
216
+ n_layer=n_layer,
217
+ **(initializer_cfg if initializer_cfg is not None else {}),
218
+ )
219
+ )
220
+ self.tie_weights()
221
+
222
+ def tie_weights(self):
223
+ if self.config.tie_embeddings:
224
+ self.lm_head.weight = self.backbone.embedding.weight
225
+
226
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
227
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
228
+
229
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
230
+ """
231
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
232
+ num_last_tokens: if > 0, only return the logits for the last n tokens
233
+ """
234
+ hidden_states = self.backbone(input_ids, inference_params=inference_params)
235
+ if num_last_tokens > 0:
236
+ hidden_states = hidden_states[:, -num_last_tokens:]
237
+ lm_logits = self.lm_head(hidden_states)
238
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
239
+ return CausalLMOutput(logits=lm_logits)
240
+
241
+ @classmethod
242
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
243
+ config_data = load_config_hf(pretrained_model_name)
244
+ config = MambaConfig(**config_data)
245
+ model = cls(config, device=device, dtype=dtype, **kwargs)
246
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
247
+ return model
248
+
249
+ def save_pretrained(self, save_directory):
250
+ """
251
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
252
+ Save the model and its configuration file to a directory.
253
+ """
254
+ # Ensure save_directory exists
255
+ os.makedirs(save_directory, exist_ok=True)
256
+
257
+ # Save the model's state_dict
258
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
259
+ torch.save(self.state_dict(), model_path)
260
+
261
+ # Save the configuration of the model
262
+ config_path = os.path.join(save_directory, 'config.json')
263
+ with open(config_path, 'w') as f:
264
+ json.dump(self.config.__dict__, f)
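Putting `MambaConfig` and `MambaLMHeadModel` together, a reduced-size language model can be instantiated and run on token ids; the result is a `CausalLMOutput` namedtuple whose `.logits` has shape (batch, seq_len, padded_vocab_size). The sketch below uses arbitrary small dimensions and assumes a CUDA device with the fused kernels installed.

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

cfg = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)    # small, arbitrary sizes
lm = MambaLMHeadModel(cfg, device="cuda")

input_ids = torch.randint(0, 1000, (2, 64), device="cuda")    # (batch, seq_len)
out = lm(input_ids)                                           # CausalLMOutput namedtuple
print(out.logits.shape)                                       # torch.Size([2, 64, 1000])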
mamba_ssm/modules/__init__.py ADDED
File without changes
mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
63
+
64
+ self.conv1d = nn.Conv1d(
65
+ in_channels=self.d_inner,
66
+ out_channels=self.d_inner,
67
+ bias=conv_bias,
68
+ kernel_size=d_conv,
69
+ groups=self.d_inner,
70
+ padding=d_conv - 1,
71
+ **factory_kwargs,
72
+ )
73
+
74
+ self.activation = "silu"
75
+ self.act = nn.SiLU()
76
+
77
+ self.x_proj = nn.Linear(
78
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
79
+ )
80
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
81
+
82
+ # Initialize special dt projection to preserve variance at initialization
83
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
84
+ if dt_init == "constant":
85
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
86
+ elif dt_init == "random":
87
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
92
+ dt = torch.exp(
93
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
94
+ + math.log(dt_min)
95
+ ).clamp(min=dt_init_floor)
96
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
97
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
98
+ with torch.no_grad():
99
+ self.dt_proj.bias.copy_(inv_dt)
100
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
101
+ self.dt_proj.bias._no_reinit = True
102
+
103
+ # S4D real initialization
104
+ A = repeat(
105
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
106
+ "n -> d n",
107
+ d=self.d_inner,
108
+ ).contiguous()
109
+ A_log = torch.log(A) # Keep A_log in fp32
110
+ self.A_log = nn.Parameter(A_log)
111
+ self.A_log._no_weight_decay = True
112
+
113
+ # D "skip" parameter
114
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
115
+ self.D._no_weight_decay = True
116
+
117
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
118
+
119
+ def forward(self, hidden_states, inference_params=None):
120
+ """
121
+ hidden_states: (B, L, D)
122
+ Returns: same shape as hidden_states
123
+ """
124
+ batch, seqlen, dim = hidden_states.shape
125
+
126
+ conv_state, ssm_state = None, None
127
+ if inference_params is not None:
128
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
129
+ if inference_params.seqlen_offset > 0:
130
+ # The states are updated inplace
131
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
132
+ return out
133
+
134
+ # We do matmul and transpose BLH -> HBL at the same time
135
+ xz = rearrange(
136
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
137
+ "d (b l) -> b d l",
138
+ l=seqlen,
139
+ )
140
+ if self.in_proj.bias is not None:
141
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
142
+
143
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
144
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
145
+ if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
146
+ out = mamba_inner_fn(
147
+ xz,
148
+ self.conv1d.weight,
149
+ self.conv1d.bias,
150
+ self.x_proj.weight,
151
+ self.dt_proj.weight,
152
+ self.out_proj.weight,
153
+ self.out_proj.bias,
154
+ A,
155
+ None, # input-dependent B
156
+ None, # input-dependent C
157
+ self.D.float(),
158
+ delta_bias=self.dt_proj.bias.float(),
159
+ delta_softplus=True,
160
+ )
161
+ else:
162
+ x, z = xz.chunk(2, dim=1)
163
+ # Compute short convolution
164
+ if conv_state is not None:
165
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
166
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
167
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
168
+ if causal_conv1d_fn is None:
169
+ x = self.act(self.conv1d(x)[..., :seqlen])
170
+ else:
171
+ assert self.activation in ["silu", "swish"]
172
+ x = causal_conv1d_fn(
173
+ x=x,
174
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
175
+ bias=self.conv1d.bias,
176
+ activation=self.activation,
177
+ )
178
+
179
+ # We're careful here about the layout, to avoid extra transposes.
180
+ # We want dt to have d as the slowest moving dimension
181
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
182
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
183
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
184
+ dt = self.dt_proj.weight @ dt.t()
185
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
186
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
187
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
188
+ assert self.activation in ["silu", "swish"]
189
+ y = selective_scan_fn(
190
+ x,
191
+ dt,
192
+ A,
193
+ B,
194
+ C,
195
+ self.D.float(),
196
+ z=z,
197
+ delta_bias=self.dt_proj.bias.float(),
198
+ delta_softplus=True,
199
+ return_last_state=ssm_state is not None,
200
+ )
201
+ if ssm_state is not None:
202
+ y, last_state = y
203
+ ssm_state.copy_(last_state)
204
+ y = rearrange(y, "b d l -> b l d")
205
+ out = self.out_proj(y)
206
+ return out
207
+
208
+ def step(self, hidden_states, conv_state, ssm_state):
209
+ dtype = hidden_states.dtype
210
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
211
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
212
+ x, z = xz.chunk(2, dim=-1) # (B D)
213
+
214
+ # Conv step
215
+ if causal_conv1d_update is None:
216
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
217
+ conv_state[:, :, -1] = x
218
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
219
+ if self.conv1d.bias is not None:
220
+ x = x + self.conv1d.bias
221
+ x = self.act(x).to(dtype=dtype)
222
+ else:
223
+ x = causal_conv1d_update(
224
+ x,
225
+ conv_state,
226
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
227
+ self.conv1d.bias,
228
+ self.activation,
229
+ )
230
+
231
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
232
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
233
+ # Don't add dt_bias here
234
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
235
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
236
+
237
+ # SSM step
238
+ if selective_state_update is None:
239
+ # Discretize A and B
240
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
241
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
242
+ dB = torch.einsum("bd,bn->bdn", dt, B)
243
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
244
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
245
+ y = y + self.D.to(dtype) * x
246
+ y = y * self.act(z) # (B D)
247
+ else:
248
+ y = selective_state_update(
249
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
250
+ )
251
+
252
+ out = self.out_proj(y)
253
+ return out.unsqueeze(1), conv_state, ssm_state
254
+
255
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
256
+ device = self.out_proj.weight.device
257
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
258
+ conv_state = torch.zeros(
259
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
260
+ )
261
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
262
+ # ssm_dtype = torch.float32
263
+ ssm_state = torch.zeros(
264
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
265
+ )
266
+ return conv_state, ssm_state
267
+
268
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
269
+ assert self.layer_idx is not None
270
+ if self.layer_idx not in inference_params.key_value_memory_dict:
271
+ batch_shape = (batch_size,)
272
+ conv_state = torch.zeros(
273
+ batch_size,
274
+ self.d_model * self.expand,
275
+ self.d_conv,
276
+ device=self.conv1d.weight.device,
277
+ dtype=self.conv1d.weight.dtype,
278
+ )
279
+ ssm_state = torch.zeros(
280
+ batch_size,
281
+ self.d_model * self.expand,
282
+ self.d_state,
283
+ device=self.dt_proj.weight.device,
284
+ dtype=self.dt_proj.weight.dtype,
285
+ # dtype=torch.float32,
286
+ )
287
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
288
+ else:
289
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
290
+ # TODO: What if batch size changes between generation, and we reuse the same states?
291
+ if initialize_states:
292
+ conv_state.zero_()
293
+ ssm_state.zero_()
294
+ return conv_state, ssm_state
295
+
296
+
297
+ class Block(nn.Module):
298
+ def __init__(
299
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
300
+ ):
301
+ """
302
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
303
+
304
+ This Block has a slightly different structure compared to a regular
305
+ prenorm Transformer block.
306
+ The standard block is: LN -> MHA/MLP -> Add.
307
+ [Ref: https://arxiv.org/abs/2002.04745]
308
+ Here we have: Add -> LN -> Mixer, returning both
309
+ the hidden_states (output of the mixer) and the residual.
310
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
311
+ The residual needs to be provided (except for the very first block).
312
+ """
313
+ super().__init__()
314
+ self.residual_in_fp32 = residual_in_fp32
315
+ self.fused_add_norm = fused_add_norm
316
+ self.mixer = mixer_cls(dim)
317
+ self.norm = norm_cls(dim)
318
+ if self.fused_add_norm:
319
+ assert RMSNorm is not None, "RMSNorm import fails"
320
+ assert isinstance(
321
+ self.norm, (nn.LayerNorm, RMSNorm)
322
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
323
+
324
+ def forward(
325
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
326
+ ):
327
+ r"""Pass the input through the encoder layer.
328
+
329
+ Args:
330
+ hidden_states: the sequence to the encoder layer (required).
331
+ residual: hidden_states = Mixer(LN(residual))
332
+ """
333
+ if not self.fused_add_norm:
334
+ residual = (hidden_states + residual) if residual is not None else hidden_states
335
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
336
+ if self.residual_in_fp32:
337
+ residual = residual.to(torch.float32)
338
+ else:
339
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
340
+ hidden_states, residual = fused_add_norm_fn(
341
+ hidden_states,
342
+ self.norm.weight,
343
+ self.norm.bias,
344
+ residual=residual,
345
+ prenorm=True,
346
+ residual_in_fp32=self.residual_in_fp32,
347
+ eps=self.norm.eps,
348
+ )
349
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
350
+ return hidden_states, residual
351
+
352
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
353
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
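One easy-to-miss detail in `Mamba.__init__` above is the dt bias initialization: dt is sampled log-uniformly in [dt_min, dt_max] and the bias stores its inverse softplus, so that `F.softplus(dt_proj.bias)` lands back in that range. A standalone check of that identity, with values matching the defaults above:

import math
import torch
import torch.nn.functional as F

dt_min, dt_max, d_inner, dt_init_floor = 0.001, 0.1, 512, 1e-4

# Sample dt log-uniformly in [dt_min, dt_max], exactly as Mamba.__init__ does.
dt = torch.exp(
    torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)

# Inverse of softplus, stored in dt_proj.bias so that softplus(bias) recovers dt.
inv_dt = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)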
mamba_ssm/ops/__init__.py ADDED
File without changes
mamba_ssm/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.cuda.amp import custom_bwd, custom_fwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ import selective_scan_cuda
17
+
18
+
19
+ class SelectiveScanFn(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
23
+ return_last_state=False):
24
+ if u.stride(-1) != 1:
25
+ u = u.contiguous()
26
+ if delta.stride(-1) != 1:
27
+ delta = delta.contiguous()
28
+ if D is not None:
29
+ D = D.contiguous()
30
+ if B.stride(-1) != 1:
31
+ B = B.contiguous()
32
+ if C.stride(-1) != 1:
33
+ C = C.contiguous()
34
+ if z is not None and z.stride(-1) != 1:
35
+ z = z.contiguous()
36
+ if B.dim() == 3:
37
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
38
+ ctx.squeeze_B = True
39
+ if C.dim() == 3:
40
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
41
+ ctx.squeeze_C = True
42
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
43
+ ctx.delta_softplus = delta_softplus
44
+ ctx.has_z = z is not None
45
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
46
+ if not ctx.has_z:
47
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
48
+ return out if not return_last_state else (out, last_state)
49
+ else:
50
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
51
+ out_z = rest[0]
52
+ return out_z if not return_last_state else (out_z, last_state)
53
+
54
+ @staticmethod
55
+ def backward(ctx, dout, *args):
56
+ if not ctx.has_z:
57
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
58
+ z = None
59
+ out = None
60
+ else:
61
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
62
+ if dout.stride(-1) != 1:
63
+ dout = dout.contiguous()
64
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
65
+ # backward of selective_scan_cuda with the backward of chunk).
66
+ # Here we just pass in None and dz will be allocated in the C++ code.
67
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
68
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
69
+ False # option to recompute out_z, not used here
70
+ )
71
+ dz = rest[0] if ctx.has_z else None
72
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
73
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
74
+ return (du, ddelta, dA, dB, dC,
75
+ dD if D is not None else None,
76
+ dz,
77
+ ddelta_bias if delta_bias is not None else None,
78
+ None,
79
+ None)
80
+
81
+
82
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
83
+ return_last_state=False):
84
+ """if return_last_state is True, returns (out, last_state)
85
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
86
+ not considered in the backward pass.
87
+ """
88
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
89
+
90
+
91
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
92
+ return_last_state=False):
93
+ """
94
+ u: r(B D L)
95
+ delta: r(B D L)
96
+ A: c(D N) or r(D N)
97
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
98
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
99
+ D: r(D)
100
+ z: r(B D L)
101
+ delta_bias: r(D), fp32
102
+
103
+ out: r(B D L)
104
+ last_state (optional): r(B D dstate) or c(B D dstate)
105
+ """
106
+ dtype_in = u.dtype
107
+ u = u.float()
108
+ delta = delta.float()
109
+ if delta_bias is not None:
110
+ delta = delta + delta_bias[..., None].float()
111
+ if delta_softplus:
112
+ delta = F.softplus(delta)
113
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
114
+ is_variable_B = B.dim() >= 3
115
+ is_variable_C = C.dim() >= 3
116
+ if A.is_complex():
117
+ if is_variable_B:
118
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
119
+ if is_variable_C:
120
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
121
+ else:
122
+ B = B.float()
123
+ C = C.float()
124
+ x = A.new_zeros((batch, dim, dstate))
125
+ ys = []
126
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
127
+ if not is_variable_B:
128
+ deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
129
+ else:
130
+ if B.dim() == 3:
131
+ deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
132
+ else:
133
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
134
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
135
+ if is_variable_C and C.dim() == 4:
136
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
137
+ last_state = None
138
+ for i in range(u.shape[2]):
139
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
140
+ if not is_variable_C:
141
+ y = torch.einsum('bdn,dn->bd', x, C)
142
+ else:
143
+ if C.dim() == 3:
144
+ y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
145
+ else:
146
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
147
+ if i == u.shape[2] - 1:
148
+ last_state = x
149
+ if y.is_complex():
150
+ y = y.real * 2
151
+ ys.append(y)
152
+ y = torch.stack(ys, dim=2) # (batch dim L)
153
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
154
+ if z is not None:
155
+ out = out * F.silu(z)
156
+ out = out.to(dtype=dtype_in)
157
+ return out if not return_last_state else (out, last_state)
158
+
159
+
160
+ class MambaInnerFn(torch.autograd.Function):
161
+
162
+ @staticmethod
163
+ @custom_fwd
164
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
165
+ out_proj_weight, out_proj_bias,
166
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
167
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
168
+ """
169
+ xz: (batch, dim, seqlen)
170
+ """
171
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
172
+ assert checkpoint_lvl in [0, 1]
173
+ L = xz.shape[-1]
174
+ delta_rank = delta_proj_weight.shape[1]
175
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
176
+ if torch.is_autocast_enabled():
177
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
178
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
179
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
180
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
181
+ if out_proj_bias is not None else None)
182
+ if xz.stride(-1) != 1:
183
+ xz = xz.contiguous()
184
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
185
+ x, z = xz.chunk(2, dim=1)
186
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
187
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
188
+ x, conv1d_weight, conv1d_bias, None, None, None, True
189
+ )
190
+ # We're being very careful here about the layout, to avoid extra transposes.
191
+ # We want delta to have d as the slowest moving dimension
192
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
193
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
194
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
195
+ ctx.is_variable_B = B is None
196
+ ctx.is_variable_C = C is None
197
+ ctx.B_proj_bias_is_None = B_proj_bias is None
198
+ ctx.C_proj_bias_is_None = C_proj_bias is None
199
+ if B is None: # variable B
200
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
201
+ if B_proj_bias is not None:
202
+ B = B + B_proj_bias.to(dtype=B.dtype)
203
+ if not A.is_complex():
204
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
205
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
206
+ else:
207
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
208
+ else:
209
+ if B.stride(-1) != 1:
210
+ B = B.contiguous()
211
+ if C is None: # variable C
212
+ C = x_dbl[:, -d_state:] # (bl dstate)
213
+ if C_proj_bias is not None:
214
+ C = C + C_proj_bias.to(dtype=C.dtype)
215
+ if not A.is_complex():
216
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
217
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
218
+ else:
219
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
220
+ else:
221
+ if C.stride(-1) != 1:
222
+ C = C.contiguous()
223
+ if D is not None:
224
+ D = D.contiguous()
225
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
226
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
227
+ )
228
+ ctx.delta_softplus = delta_softplus
229
+ ctx.out_proj_bias_is_None = out_proj_bias is None
230
+ ctx.checkpoint_lvl = checkpoint_lvl
231
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
232
+ conv1d_out, delta = None, None
233
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
234
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
235
+ A, B, C, D, delta_bias, scan_intermediates, out)
236
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
237
+
238
+ @staticmethod
239
+ @custom_bwd
240
+ def backward(ctx, dout):
241
+ # dout: (batch, seqlen, dim)
242
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
243
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
244
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
245
+ L = xz.shape[-1]
246
+ delta_rank = delta_proj_weight.shape[1]
247
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
248
+ x, z = xz.chunk(2, dim=1)
249
+ if dout.stride(-1) != 1:
250
+ dout = dout.contiguous()
251
+ if ctx.checkpoint_lvl == 1:
252
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
253
+ x, conv1d_weight, conv1d_bias, None, None, None, True
254
+ )
255
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
256
+ "d (b l) -> b d l", l = L)
257
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
258
+ # backward of selective_scan_cuda with the backward of chunk).
259
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
260
+ dx, dz = dxz.chunk(2, dim=1)
261
+ dout = rearrange(dout, "b l e -> e (b l)")
262
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
263
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
264
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
265
+ ctx.delta_softplus,
266
+ True # option to recompute out_z
267
+ )
268
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
269
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
270
+ dD = dD if D is not None else None
271
+ dx_dbl = torch.empty_like(x_dbl)
272
+ dB_proj_bias = None
273
+ if ctx.is_variable_B:
274
+ if not A.is_complex():
275
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
276
+ else:
277
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
278
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
279
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
280
+ dB = None
281
+ dC_proj_bias = None
282
+ if ctx.is_variable_C:
283
+ if not A.is_complex():
284
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
285
+ else:
286
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
287
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
288
+ dx_dbl[:, -d_state:] = dC # (bl d)
289
+ dC = None
290
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
291
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
292
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
293
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
294
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
295
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
296
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
297
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
298
+ # backward of conv1d with the backward of chunk).
299
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
300
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
301
+ )
302
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
303
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
304
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
305
+ dout_proj_weight, dout_proj_bias,
306
+ dA, dB, dC, dD,
307
+ ddelta_bias if delta_bias is not None else None,
308
+ dB_proj_bias, dC_proj_bias, None)
309
+
310
+
311
+ def mamba_inner_fn(
312
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
313
+ out_proj_weight, out_proj_bias,
314
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
315
+ C_proj_bias=None, delta_softplus=True
316
+ ):
317
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
318
+ out_proj_weight, out_proj_bias,
319
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
320
+
321
+
322
+ def mamba_inner_ref(
323
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
324
+ out_proj_weight, out_proj_bias,
325
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
326
+ C_proj_bias=None, delta_softplus=True
327
+ ):
328
+ assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
329
+ L = xz.shape[-1]
330
+ delta_rank = delta_proj_weight.shape[1]
331
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
332
+ x, z = xz.chunk(2, dim=1)
333
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
334
+ # We're being very careful here about the layout, to avoid extra transposes.
335
+ # We want delta to have d as the slowest moving dimension
336
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
337
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
338
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
339
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
340
+ if B is None: # variable B
341
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
342
+ if B_proj_bias is not None:
343
+ B = B + B_proj_bias.to(dtype=B.dtype)
344
+ if not A.is_complex():
345
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
346
+ else:
347
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
348
+ if C is None: # variable B
349
+ C = x_dbl[:, -d_state:] # (bl d)
350
+ if C_proj_bias is not None:
351
+ C = C + C_proj_bias.to(dtype=C.dtype)
352
+ if not A.is_complex():
353
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
354
+ else:
355
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
356
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
357
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
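`selective_scan_ref` documents the expected shapes and serves as a reference implementation for the fused kernel. A quick numerical comparison between the two could be run as below; this is only a sketch, it assumes a CUDA device with the compiled `selective_scan_cuda` extension, and the shapes are arbitrary.

import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

# Shapes follow the selective_scan_ref docstring: u/delta/z are (B, D, L),
# A is (D, N), variable B/C are (B, N, L), D is (D,).
b, d, n, l = 2, 64, 16, 128
dev = "cuda"
u = torch.randn(b, d, l, device=dev)
delta = torch.rand(b, d, l, device=dev)
A = -torch.rand(d, n, device=dev)          # negative A for a stable scan
B = torch.randn(b, n, l, device=dev)
C = torch.randn(b, n, l, device=dev)
D = torch.randn(d, device=dev)
z = torch.randn(b, d, l, device=dev)

out_cuda = selective_scan_fn(u, delta, A, B, C, D=D, z=z, delta_softplus=True)
out_ref = selective_scan_ref(u, delta, A, B, C, D=D, z=z, delta_softplus=True)
print((out_cuda - out_ref).abs().max())    # kernel vs. reference, should be small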
mamba_ssm/ops/triton/__init__.py ADDED
File without changes
mamba_ssm/ops/triton/layernorm.py ADDED
@@ -0,0 +1,635 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Implement residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.cuda.amp import custom_fwd, custom_bwd
14
+
15
+ import triton
16
+ import triton.language as tl
17
+
18
+
19
+ def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
20
+ dtype = x.dtype
21
+ if upcast:
22
+ weight = weight.float()
23
+ bias = bias.float() if bias is not None else None
24
+ if upcast:
25
+ x = x.float()
26
+ residual = residual.float() if residual is not None else residual
27
+ if residual is not None:
28
+ x = (x + residual).to(x.dtype)
29
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
30
+ dtype
31
+ )
32
+ return out if not prenorm else (out, x)
33
+
34
+
35
+ def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ weight = weight.float()
39
+ bias = bias.float() if bias is not None else None
40
+ if upcast:
41
+ x = x.float()
42
+ residual = residual.float() if residual is not None else residual
43
+ if residual is not None:
44
+ x = (x + residual).to(x.dtype)
45
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
46
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
47
+ out = out.to(dtype)
48
+ return out if not prenorm else (out, x)
49
+
50
+
51
+ @triton.autotune(
52
+ configs=[
53
+ triton.Config({}, num_warps=1),
54
+ triton.Config({}, num_warps=2),
55
+ triton.Config({}, num_warps=4),
56
+ triton.Config({}, num_warps=8),
57
+ triton.Config({}, num_warps=16),
58
+ triton.Config({}, num_warps=32),
59
+ ],
60
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
61
+ )
62
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
63
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
64
+ @triton.jit
65
+ def _layer_norm_fwd_1pass_kernel(
66
+ X, # pointer to the input
67
+ Y, # pointer to the output
68
+ W, # pointer to the weights
69
+ B, # pointer to the biases
70
+ RESIDUAL, # pointer to the residual
71
+ RESIDUAL_OUT, # pointer to the residual
72
+ Mean, # pointer to the mean
73
+ Rstd, # pointer to the 1/std
74
+ stride_x_row, # how much to increase the pointer when moving by 1 row
75
+ stride_y_row,
76
+ stride_res_row,
77
+ stride_res_out_row,
78
+ N, # number of columns in X
79
+ eps, # epsilon to avoid division by zero
80
+ IS_RMS_NORM: tl.constexpr,
81
+ BLOCK_N: tl.constexpr,
82
+ HAS_RESIDUAL: tl.constexpr,
83
+ STORE_RESIDUAL_OUT: tl.constexpr,
84
+ HAS_BIAS: tl.constexpr,
85
+ ):
86
+ # Map the program id to the row of X and Y it should compute.
87
+ row = tl.program_id(0)
88
+ X += row * stride_x_row
89
+ Y += row * stride_y_row
90
+ if HAS_RESIDUAL:
91
+ RESIDUAL += row * stride_res_row
92
+ if STORE_RESIDUAL_OUT:
93
+ RESIDUAL_OUT += row * stride_res_out_row
94
+ # Compute mean and variance
95
+ cols = tl.arange(0, BLOCK_N)
96
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
97
+ if HAS_RESIDUAL:
98
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
99
+ x += residual
100
+ if STORE_RESIDUAL_OUT:
101
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
102
+ if not IS_RMS_NORM:
103
+ mean = tl.sum(x, axis=0) / N
104
+ tl.store(Mean + row, mean)
105
+ xbar = tl.where(cols < N, x - mean, 0.0)
106
+ var = tl.sum(xbar * xbar, axis=0) / N
107
+ else:
108
+ xbar = tl.where(cols < N, x, 0.0)
109
+ var = tl.sum(xbar * xbar, axis=0) / N
110
+ rstd = 1 / tl.sqrt(var + eps)
111
+ tl.store(Rstd + row, rstd)
112
+ # Normalize and apply linear transformation
113
+ mask = cols < N
114
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
115
+ if HAS_BIAS:
116
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
117
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
118
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
119
+ # Write output
120
+ tl.store(Y + cols, y, mask=mask)
121
+
122
+
123
+ def _layer_norm_fwd(
124
+ x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
125
+ ):
126
+ if residual is not None:
127
+ residual_dtype = residual.dtype
128
+ M, N = x.shape
129
+ assert x.stride(-1) == 1
130
+ if residual is not None:
131
+ assert residual.stride(-1) == 1
132
+ assert residual.shape == (M, N)
133
+ assert weight.shape == (N,)
134
+ assert weight.stride(-1) == 1
135
+ if bias is not None:
136
+ assert bias.stride(-1) == 1
137
+ assert bias.shape == (N,)
138
+ # allocate output
139
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
140
+ assert y.stride(-1) == 1
141
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
142
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
143
+ assert residual_out.stride(-1) == 1
144
+ else:
145
+ residual_out = None
146
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
147
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
148
+ # Less than 64KB per feature: enqueue fused kernel
149
+ MAX_FUSED_SIZE = 65536 // x.element_size()
150
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
151
+ if N > BLOCK_N:
152
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
153
+ # heuristics for number of warps
154
+ with torch.cuda.device(x.device.index):
155
+ _layer_norm_fwd_1pass_kernel[(M,)](
156
+ x,
157
+ y,
158
+ weight,
159
+ bias,
160
+ residual,
161
+ residual_out,
162
+ mean,
163
+ rstd,
164
+ x.stride(0),
165
+ y.stride(0),
166
+ residual.stride(0) if residual is not None else 0,
167
+ residual_out.stride(0) if residual_out is not None else 0,
168
+ N,
169
+ eps,
170
+ is_rms_norm,
171
+ BLOCK_N,
172
+ residual is not None,
173
+ residual_out is not None,
174
+ bias is not None,
175
+ )
176
+ # residual_out is None if residual is None and residual_dtype == input_dtype
177
+ return y, mean, rstd, residual_out if residual_out is not None else x
178
+
179
+
180
+ @triton.autotune(
181
+ configs=[
182
+ triton.Config({}, num_warps=1),
183
+ triton.Config({}, num_warps=2),
184
+ triton.Config({}, num_warps=4),
185
+ triton.Config({}, num_warps=8),
186
+ triton.Config({}, num_warps=16),
187
+ triton.Config({}, num_warps=32),
188
+ ],
189
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
190
+ )
191
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
192
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
193
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
194
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
195
+ @triton.jit
196
+ def _layer_norm_bwd_kernel(
197
+ X, # pointer to the input
198
+ W, # pointer to the weights
199
+ B, # pointer to the biases
200
+ Y, # pointer to the output to be recomputed
201
+ DY, # pointer to the output gradient
202
+ DX, # pointer to the input gradient
203
+ DW, # pointer to the partial sum of weights gradient
204
+ DB, # pointer to the partial sum of biases gradient
205
+ DRESIDUAL,
206
+ DRESIDUAL_IN,
207
+ Mean, # pointer to the mean
208
+ Rstd, # pointer to the 1/std
209
+ stride_x_row, # how much to increase the pointer when moving by 1 row
210
+ stride_y_row,
211
+ stride_dy_row,
212
+ stride_dx_row,
213
+ stride_dres_row,
214
+ stride_dres_in_row,
215
+ M, # number of rows in X
216
+ N, # number of columns in X
217
+ eps, # epsilon to avoid division by zero
218
+ rows_per_program,
219
+ IS_RMS_NORM: tl.constexpr,
220
+ BLOCK_N: tl.constexpr,
221
+ HAS_DRESIDUAL: tl.constexpr,
222
+ STORE_DRESIDUAL: tl.constexpr,
223
+ HAS_BIAS: tl.constexpr,
224
+ RECOMPUTE_OUTPUT: tl.constexpr,
225
+ ):
226
+ # Map the program id to the elements of X, DX, and DY it should compute.
227
+ row_block_id = tl.program_id(0)
228
+ row_start = row_block_id * rows_per_program
229
+ cols = tl.arange(0, BLOCK_N)
230
+ mask = cols < N
231
+ X += row_start * stride_x_row
232
+ if HAS_DRESIDUAL:
233
+ DRESIDUAL += row_start * stride_dres_row
234
+ if STORE_DRESIDUAL:
235
+ DRESIDUAL_IN += row_start * stride_dres_in_row
236
+ DY += row_start * stride_dy_row
237
+ DX += row_start * stride_dx_row
238
+ if RECOMPUTE_OUTPUT:
239
+ Y += row_start * stride_y_row
240
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
241
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
242
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
243
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
244
+ if HAS_BIAS:
245
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
246
+ row_end = min((row_block_id + 1) * rows_per_program, M)
247
+ for row in range(row_start, row_end):
248
+ # Load data to SRAM
249
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
250
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
251
+ if not IS_RMS_NORM:
252
+ mean = tl.load(Mean + row)
253
+ rstd = tl.load(Rstd + row)
254
+ # Compute dx
255
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
256
+ xhat = tl.where(mask, xhat, 0.0)
257
+ if RECOMPUTE_OUTPUT:
258
+ y = xhat * w + b if HAS_BIAS else xhat * w
259
+ tl.store(Y + cols, y, mask=mask)
260
+ wdy = w * dy
261
+ dw += dy * xhat
262
+ if HAS_BIAS:
263
+ db += dy
264
+ if not IS_RMS_NORM:
265
+ c1 = tl.sum(xhat * wdy, axis=0) / N
266
+ c2 = tl.sum(wdy, axis=0) / N
267
+ dx = (wdy - (xhat * c1 + c2)) * rstd
268
+ else:
269
+ c1 = tl.sum(xhat * wdy, axis=0) / N
270
+ dx = (wdy - xhat * c1) * rstd
271
+ if HAS_DRESIDUAL:
272
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
273
+ dx += dres
274
+ # Write dx
275
+ if STORE_DRESIDUAL:
276
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
277
+ tl.store(DX + cols, dx, mask=mask)
278
+
279
+ X += stride_x_row
280
+ if HAS_DRESIDUAL:
281
+ DRESIDUAL += stride_dres_row
282
+ if STORE_DRESIDUAL:
283
+ DRESIDUAL_IN += stride_dres_in_row
284
+ if RECOMPUTE_OUTPUT:
285
+ Y += stride_y_row
286
+ DY += stride_dy_row
287
+ DX += stride_dx_row
288
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
289
+ if HAS_BIAS:
290
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
291
+
292
+
293
+ def _layer_norm_bwd(
294
+ dy,
295
+ x,
296
+ weight,
297
+ bias,
298
+ eps,
299
+ mean,
300
+ rstd,
301
+ dresidual=None,
302
+ has_residual=False,
303
+ is_rms_norm=False,
304
+ x_dtype=None,
305
+ recompute_output=False,
306
+ ):
307
+ M, N = x.shape
308
+ assert x.stride(-1) == 1
309
+ assert dy.stride(-1) == 1
310
+ assert dy.shape == (M, N)
311
+ if dresidual is not None:
312
+ assert dresidual.stride(-1) == 1
313
+ assert dresidual.shape == (M, N)
314
+ assert weight.shape == (N,)
315
+ assert weight.stride(-1) == 1
316
+ if bias is not None:
317
+ assert bias.stride(-1) == 1
318
+ assert bias.shape == (N,)
319
+ # allocate output
320
+ dx = (
321
+ torch.empty_like(x)
322
+ if x_dtype is None
323
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
324
+ )
325
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
326
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
327
+
328
+ # Less than 64KB per feature: enqueue fused kernel
329
+ MAX_FUSED_SIZE = 65536 // x.element_size()
330
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
331
+ if N > BLOCK_N:
332
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
333
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
334
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
335
+ _db = (
336
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
337
+ if bias is not None
338
+ else None
339
+ )
340
+ rows_per_program = math.ceil(M / sm_count)
341
+ grid = (sm_count,)
342
+ with torch.cuda.device(x.device.index):
343
+ _layer_norm_bwd_kernel[grid](
344
+ x,
345
+ weight,
346
+ bias,
347
+ y,
348
+ dy,
349
+ dx,
350
+ _dw,
351
+ _db,
352
+ dresidual,
353
+ dresidual_in,
354
+ mean,
355
+ rstd,
356
+ x.stride(0),
357
+ 0 if not recompute_output else y.stride(0),
358
+ dy.stride(0),
359
+ dx.stride(0),
360
+ dresidual.stride(0) if dresidual is not None else 0,
361
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
362
+ M,
363
+ N,
364
+ eps,
365
+ rows_per_program,
366
+ is_rms_norm,
367
+ BLOCK_N,
368
+ dresidual is not None,
369
+ dresidual_in is not None,
370
+ bias is not None,
371
+ )
372
+ dw = _dw.sum(0).to(weight.dtype)
373
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
374
+ # Don't need to compute dresidual_in separately in this case
375
+ if has_residual and dx.dtype == x.dtype:
376
+ dresidual_in = dx
377
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
378
+
379
+
380
+ class LayerNormFn(torch.autograd.Function):
381
+ @staticmethod
382
+ def forward(
383
+ ctx,
384
+ x,
385
+ weight,
386
+ bias,
387
+ residual=None,
388
+ eps=1e-6,
389
+ prenorm=False,
390
+ residual_in_fp32=False,
391
+ is_rms_norm=False,
392
+ ):
393
+ x_shape_og = x.shape
394
+ # reshape input data into 2D tensor
395
+ x = x.reshape(-1, x.shape[-1])
396
+ if x.stride(-1) != 1:
397
+ x = x.contiguous()
398
+ if residual is not None:
399
+ assert residual.shape == x_shape_og
400
+ residual = residual.reshape(-1, residual.shape[-1])
401
+ if residual.stride(-1) != 1:
402
+ residual = residual.contiguous()
403
+ weight = weight.contiguous()
404
+ if bias is not None:
405
+ bias = bias.contiguous()
406
+ residual_dtype = (
407
+ residual.dtype
408
+ if residual is not None
409
+ else (torch.float32 if residual_in_fp32 else None)
410
+ )
411
+ y, mean, rstd, residual_out = _layer_norm_fwd(
412
+ x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
413
+ )
414
+ ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
415
+ ctx.x_shape_og = x_shape_og
416
+ ctx.eps = eps
417
+ ctx.is_rms_norm = is_rms_norm
418
+ ctx.has_residual = residual is not None
419
+ ctx.prenorm = prenorm
420
+ ctx.x_dtype = x.dtype
421
+ y = y.reshape(x_shape_og)
422
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
423
+
424
+ @staticmethod
425
+ def backward(ctx, dy, *args):
426
+ x, weight, bias, mean, rstd = ctx.saved_tensors
427
+ dy = dy.reshape(-1, dy.shape[-1])
428
+ if dy.stride(-1) != 1:
429
+ dy = dy.contiguous()
430
+ assert dy.shape == x.shape
431
+ if ctx.prenorm:
432
+ dresidual = args[0]
433
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
434
+ if dresidual.stride(-1) != 1:
435
+ dresidual = dresidual.contiguous()
436
+ assert dresidual.shape == x.shape
437
+ else:
438
+ dresidual = None
439
+ dx, dw, db, dresidual_in = _layer_norm_bwd(
440
+ dy,
441
+ x,
442
+ weight,
443
+ bias,
444
+ ctx.eps,
445
+ mean,
446
+ rstd,
447
+ dresidual,
448
+ ctx.has_residual,
449
+ ctx.is_rms_norm,
450
+ x_dtype=ctx.x_dtype,
451
+ )
452
+ return (
453
+ dx.reshape(ctx.x_shape_og),
454
+ dw,
455
+ db,
456
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
457
+ None,
458
+ None,
459
+ None,
460
+ None,
461
+ )
462
+
463
+
464
+ def layer_norm_fn(
465
+ x,
466
+ weight,
467
+ bias,
468
+ residual=None,
469
+ eps=1e-6,
470
+ prenorm=False,
471
+ residual_in_fp32=False,
472
+ is_rms_norm=False,
473
+ ):
474
+ return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
475
+
476
+
477
+ def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
478
+ return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
479
+
480
+
481
+ class RMSNorm(torch.nn.Module):
482
+ def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
483
+ factory_kwargs = {"device": device, "dtype": dtype}
484
+ super().__init__()
485
+ self.eps = eps
486
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
487
+ self.register_parameter("bias", None)
488
+ self.reset_parameters()
489
+
490
+ def reset_parameters(self):
491
+ torch.nn.init.ones_(self.weight)
492
+
493
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
494
+ return rms_norm_fn(
495
+ x,
496
+ self.weight,
497
+ self.bias,
498
+ residual=residual,
499
+ eps=self.eps,
500
+ prenorm=prenorm,
501
+ residual_in_fp32=residual_in_fp32,
502
+ )
503
+
504
+
505
+ class LayerNormLinearFn(torch.autograd.Function):
506
+ @staticmethod
507
+ @custom_fwd
508
+ def forward(
509
+ ctx,
510
+ x,
511
+ norm_weight,
512
+ norm_bias,
513
+ linear_weight,
514
+ linear_bias,
515
+ residual=None,
516
+ eps=1e-6,
517
+ prenorm=False,
518
+ residual_in_fp32=False,
519
+ is_rms_norm=False,
520
+ ):
521
+ x_shape_og = x.shape
522
+ # reshape input data into 2D tensor
523
+ x = x.reshape(-1, x.shape[-1])
524
+ if x.stride(-1) != 1:
525
+ x = x.contiguous()
526
+ if residual is not None:
527
+ assert residual.shape == x_shape_og
528
+ residual = residual.reshape(-1, residual.shape[-1])
529
+ if residual.stride(-1) != 1:
530
+ residual = residual.contiguous()
531
+ norm_weight = norm_weight.contiguous()
532
+ if norm_bias is not None:
533
+ norm_bias = norm_bias.contiguous()
534
+ residual_dtype = (
535
+ residual.dtype
536
+ if residual is not None
537
+ else (torch.float32 if residual_in_fp32 else None)
538
+ )
539
+ y, mean, rstd, residual_out = _layer_norm_fwd(
540
+ x,
541
+ norm_weight,
542
+ norm_bias,
543
+ eps,
544
+ residual,
545
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
546
+ residual_dtype=residual_dtype,
547
+ is_rms_norm=is_rms_norm,
548
+ )
549
+ y = y.reshape(x_shape_og)
550
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
551
+ linear_weight = linear_weight.to(dtype)
552
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
553
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
554
+ # We don't store y, will be recomputed in the backward pass to save memory
555
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
556
+ ctx.x_shape_og = x_shape_og
557
+ ctx.eps = eps
558
+ ctx.is_rms_norm = is_rms_norm
559
+ ctx.has_residual = residual is not None
560
+ ctx.prenorm = prenorm
561
+ ctx.x_dtype = x.dtype
562
+ ctx.linear_bias_is_none = linear_bias is None
563
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
564
+
565
+ @staticmethod
566
+ @custom_bwd
567
+ def backward(ctx, dout, *args):
568
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
569
+ dout = dout.reshape(-1, dout.shape[-1])
570
+ dy = F.linear(dout, linear_weight.t())
571
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
572
+ if dy.stride(-1) != 1:
573
+ dy = dy.contiguous()
574
+ assert dy.shape == x.shape
575
+ if ctx.prenorm:
576
+ dresidual = args[0]
577
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
578
+ if dresidual.stride(-1) != 1:
579
+ dresidual = dresidual.contiguous()
580
+ assert dresidual.shape == x.shape
581
+ else:
582
+ dresidual = None
583
+ dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
584
+ dy,
585
+ x,
586
+ norm_weight,
587
+ norm_bias,
588
+ ctx.eps,
589
+ mean,
590
+ rstd,
591
+ dresidual,
592
+ ctx.has_residual,
593
+ ctx.is_rms_norm,
594
+ x_dtype=ctx.x_dtype,
595
+ recompute_output=True,
596
+ )
597
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
598
+ return (
599
+ dx.reshape(ctx.x_shape_og),
600
+ dnorm_weight,
601
+ dnorm_bias,
602
+ dlinear_weight,
603
+ dlinear_bias,
604
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
605
+ None,
606
+ None,
607
+ None,
608
+ None,
609
+ )
610
+
611
+
612
+ def layer_norm_linear_fn(
613
+ x,
614
+ norm_weight,
615
+ norm_bias,
616
+ linear_weight,
617
+ linear_bias,
618
+ residual=None,
619
+ eps=1e-6,
620
+ prenorm=False,
621
+ residual_in_fp32=False,
622
+ is_rms_norm=False,
623
+ ):
624
+ return LayerNormLinearFn.apply(
625
+ x,
626
+ norm_weight,
627
+ norm_bias,
628
+ linear_weight,
629
+ linear_bias,
630
+ residual,
631
+ eps,
632
+ prenorm,
633
+ residual_in_fp32,
634
+ is_rms_norm,
635
+ )
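
Usage sketch for the reference functions above: `rms_norm_ref` is plain PyTorch, so it can be sanity-checked on CPU, while `RMSNorm` / `rms_norm_fn` go through the Triton kernel and need a GPU. This assumes the vendored package is importable (the module still imports `triton` at load time).

```python
import torch
from mamba_ssm.ops.triton.layernorm import rms_norm_ref  # pure-PyTorch reference

x = torch.randn(2, 10, 64)           # (batch, seqlen, hidden)
weight = torch.ones(64)
residual = torch.randn(2, 10, 64)

# Plain RMSNorm (no bias).
y = rms_norm_ref(x, weight, bias=None)

# Fused "add residual, then normalize" pattern used by the Mamba blocks:
# prenorm=True also returns the updated residual stream (x + residual).
y2, new_residual = rms_norm_ref(x, weight, bias=None, residual=residual, prenorm=True)
print(y.shape, y2.shape, new_residual.shape)  # all (2, 10, 64)
```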
mamba_ssm/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19
+ @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20
+ @triton.jit
21
+ def _selective_scan_update_kernel(
22
+ # Pointers to matrices
23
+ state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24
+ # Matrix dimensions
25
+ batch, nheads, dim, dstate, nheads_ngroups_ratio,
26
+ # Strides
27
+ stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
28
+ stride_x_batch, stride_x_head, stride_x_dim,
29
+ stride_dt_batch, stride_dt_head, stride_dt_dim,
30
+ stride_dt_bias_head, stride_dt_bias_dim,
31
+ stride_A_head, stride_A_dim, stride_A_dstate,
32
+ stride_B_batch, stride_B_group, stride_B_dstate,
33
+ stride_C_batch, stride_C_group, stride_C_dstate,
34
+ stride_D_head, stride_D_dim,
35
+ stride_z_batch, stride_z_head, stride_z_dim,
36
+ stride_out_batch, stride_out_head, stride_out_dim,
37
+ # Meta-parameters
38
+ DT_SOFTPLUS: tl.constexpr,
39
+ TIE_HDIM: tl.constexpr,
40
+ BLOCK_SIZE_M: tl.constexpr,
41
+ HAS_DT_BIAS: tl.constexpr,
42
+ HAS_D: tl.constexpr,
43
+ HAS_Z: tl.constexpr,
44
+ BLOCK_SIZE_DSTATE: tl.constexpr,
45
+ ):
46
+ pid_m = tl.program_id(axis=0)
47
+ pid_b = tl.program_id(axis=1)
48
+ pid_h = tl.program_id(axis=2)
49
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
50
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
51
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
52
+ if HAS_DT_BIAS:
53
+ dt_bias_ptr += pid_h * stride_dt_bias_head
54
+ A_ptr += pid_h * stride_A_head
55
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
56
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
57
+ if HAS_Z:
58
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
59
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
60
+
61
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
62
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
63
+ state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
64
+ x_ptrs = x_ptr + offs_m * stride_x_dim
65
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
66
+ if HAS_DT_BIAS:
67
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
68
+ if HAS_D:
69
+ D_ptr += pid_h * stride_D_head
70
+ A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
71
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
72
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
73
+ if HAS_D:
74
+ D_ptrs = D_ptr + offs_m * stride_D_dim
75
+ if HAS_Z:
76
+ z_ptrs = z_ptr + offs_m * stride_z_dim
77
+ out_ptrs = out_ptr + offs_m * stride_out_dim
78
+
79
+ state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
80
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
81
+ if not TIE_HDIM:
82
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
83
+ if HAS_DT_BIAS:
84
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85
+ if DT_SOFTPLUS:
86
+ dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
87
+ A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
88
+ dA = tl.exp(A * dt[:, None])
89
+ else:
90
+ dt = tl.load(dt_ptr).to(tl.float32)
91
+ if HAS_DT_BIAS:
92
+ dt += tl.load(dt_bias_ptr).to(tl.float32)
93
+ if DT_SOFTPLUS:
94
+ dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
95
+ A = tl.load(A_ptr).to(tl.float32)
96
+ dA = tl.exp(A * dt) # scalar, not a matrix
97
+
98
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
99
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
100
+ if HAS_D:
101
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
102
+ if HAS_Z:
103
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
104
+
105
+ if not TIE_HDIM:
106
+ dB = B[None, :] * dt[:, None]
107
+ else:
108
+ dB = B * dt # vector of size (dstate,)
109
+ state = state * dA + dB * x[:, None]
110
+ tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
111
+ out = tl.sum(state * C[None, :], axis=1)
112
+ if HAS_D:
113
+ out += x * D
114
+ if HAS_Z:
115
+ out *= z * tl.sigmoid(z)
116
+ tl.store(out_ptrs, out, mask=offs_m < dim)
117
+
118
+
119
+ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
120
+ """
121
+ Argument:
122
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
123
+ x: (batch, dim) or (batch, nheads, dim)
124
+ dt: (batch, dim) or (batch, nheads, dim)
125
+ A: (dim, dstate) or (nheads, dim, dstate)
126
+ B: (batch, dstate) or (batch, ngroups, dstate)
127
+ C: (batch, dstate) or (batch, ngroups, dstate)
128
+ D: (dim,) or (nheads, dim)
129
+ z: (batch, dim) or (batch, nheads, dim)
130
+ dt_bias: (dim,) or (nheads, dim)
131
+ Return:
132
+ out: (batch, dim) or (batch, nheads, dim)
133
+ """
134
+ has_heads = state.dim() > 3
135
+ if state.dim() == 3:
136
+ state = state.unsqueeze(1)
137
+ if x.dim() == 2:
138
+ x = x.unsqueeze(1)
139
+ if dt.dim() == 2:
140
+ dt = dt.unsqueeze(1)
141
+ if A.dim() == 2:
142
+ A = A.unsqueeze(0)
143
+ if B.dim() == 2:
144
+ B = B.unsqueeze(1)
145
+ if C.dim() == 2:
146
+ C = C.unsqueeze(1)
147
+ if D is not None and D.dim() == 1:
148
+ D = D.unsqueeze(0)
149
+ if z is not None and z.dim() == 2:
150
+ z = z.unsqueeze(1)
151
+ if dt_bias is not None and dt_bias.dim() == 1:
152
+ dt_bias = dt_bias.unsqueeze(0)
153
+ batch, nheads, dim, dstate = state.shape
154
+ assert x.shape == (batch, nheads, dim)
155
+ assert dt.shape == x.shape
156
+ assert A.shape == (nheads, dim, dstate)
157
+ ngroups = B.shape[1]
158
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
159
+ assert B.shape == (batch, ngroups, dstate)
160
+ assert C.shape == B.shape
161
+ if D is not None:
162
+ assert D.shape == (nheads, dim)
163
+ if z is not None:
164
+ assert z.shape == x.shape
165
+ if dt_bias is not None:
166
+ assert dt_bias.shape == (nheads, dim)
167
+ out = torch.empty_like(x)
168
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
169
+ z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
170
+ # We don't want autotune since it will overwrite the state
171
+ # We instead tune by hand.
172
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
173
+ else ((16, 4) if dstate <= 32 else
174
+ ((8, 4) if dstate <= 64 else
175
+ ((4, 4) if dstate <= 128 else
176
+ ((4, 8))))))
177
+ tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
178
+ with torch.cuda.device(x.device.index):
179
+ _selective_scan_update_kernel[grid](
180
+ state, x, dt, dt_bias, A, B, C, D, z, out,
181
+ batch, nheads, dim, dstate, nheads // ngroups,
182
+ state.stride(0), state.stride(1), state.stride(2), state.stride(3),
183
+ x.stride(0), x.stride(1), x.stride(2),
184
+ dt.stride(0), dt.stride(1), dt.stride(2),
185
+ *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
186
+ A.stride(0), A.stride(1), A.stride(2),
187
+ B.stride(0), B.stride(1), B.stride(2),
188
+ C.stride(0), C.stride(1), C.stride(2),
189
+ *(D.stride(0), D.stride(1)) if D is not None else 0,
190
+ z_strides[0], z_strides[1], z_strides[2],
191
+ out.stride(0), out.stride(1), out.stride(2),
192
+ dt_softplus,
193
+ tie_hdim,
194
+ BLOCK_SIZE_M,
195
+ num_warps=num_warps,
196
+ )
197
+ if not has_heads:
198
+ out = out.squeeze(1)
199
+ return out
200
+
201
+
202
+ def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
203
+ """
204
+ Argument:
205
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
206
+ x: (batch, dim) or (batch, nheads, dim)
207
+ dt: (batch, dim) or (batch, nheads, dim)
208
+ A: (dim, dstate) or (nheads, dim, dstate)
209
+ B: (batch, dstate) or (batch, ngroups, dstate)
210
+ C: (batch, dstate) or (batch, ngroups, dstate)
211
+ D: (dim,) or (nheads, dim)
212
+ z: (batch, dim) or (batch, nheads, dim)
213
+ dt_bias: (dim,) or (nheads, dim)
214
+ Return:
215
+ out: (batch, dim) or (batch, nheads, dim)
216
+ """
217
+ has_heads = state.dim() > 3
218
+ if state.dim() == 3:
219
+ state = state.unsqueeze(1)
220
+ if x.dim() == 2:
221
+ x = x.unsqueeze(1)
222
+ if dt.dim() == 2:
223
+ dt = dt.unsqueeze(1)
224
+ if A.dim() == 2:
225
+ A = A.unsqueeze(0)
226
+ if B.dim() == 2:
227
+ B = B.unsqueeze(1)
228
+ if C.dim() == 2:
229
+ C = C.unsqueeze(1)
230
+ if D is not None and D.dim() == 1:
231
+ D = D.unsqueeze(0)
232
+ if z is not None and z.dim() == 2:
233
+ z = z.unsqueeze(1)
234
+ if dt_bias is not None and dt_bias.dim() == 1:
235
+ dt_bias = dt_bias.unsqueeze(0)
236
+ batch, nheads, dim, dstate = state.shape
237
+ assert x.shape == (batch, nheads, dim)
238
+ assert dt.shape == x.shape
239
+ assert A.shape == (nheads, dim, dstate)
240
+ ngroups = B.shape[1]
241
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
242
+ assert B.shape == (batch, ngroups, dstate)
243
+ assert C.shape == B.shape
244
+ if D is not None:
245
+ assert D.shape == (nheads, dim)
246
+ if z is not None:
247
+ assert z.shape == x.shape
248
+ if dt_bias is not None:
249
+ assert dt_bias.shape == (nheads, dim)
250
+ dt = dt + dt_bias
251
+ dt = F.softplus(dt) if dt_softplus else dt
252
+ dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
253
+ B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
254
+ C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
255
+ dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
256
+ state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
257
+ out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
258
+ if D is not None:
259
+ out += (x * D).to(out.dtype)
260
+ out = (out if z is None else out * F.silu(z)).to(x.dtype)
261
+ if not has_heads:
262
+ out = out.squeeze(1)
263
+ return out
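
Usage sketch: `selective_state_update_ref` is the pure-PyTorch fallback for a single decoding step, so its shapes are easy to check on CPU; the Triton `selective_state_update` takes the same arguments but requires CUDA. Sizes below are illustrative.

```python
import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, dim, dstate = 2, 32, 16
state = torch.zeros(batch, dim, dstate)   # recurrent SSM state, updated in place
x = torch.randn(batch, dim)               # input for the current step
dt = torch.rand(batch, dim)               # per-channel step size
A = -torch.rand(dim, dstate)              # negative real A
B = torch.randn(batch, dstate)
C = torch.randn(batch, dstate)
D = torch.ones(dim)

out = selective_state_update_ref(state, x, dt, A, B, C, D=D, dt_softplus=True)
print(out.shape)                          # (batch, dim); `state` now holds the updated SSM state
```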
mamba_ssm/utils/__init__.py ADDED
File without changes
mamba_ssm/utils/generation.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficienly calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
37
+ def modify_logits_for_min_p_filtering(logits, min_p):
38
+ """Set the logits for none min_p values to -inf. Done in-place."""
39
+ if min_p <= 0.0 or min_p >= 1.0:
40
+ return
41
+ indices_to_remove = logits < min_p
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
+ def modify_logits_for_top_k_filtering(logits, top_k):
46
+ """Set the logits for none top-k values to -inf. Done in-place."""
47
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
49
+
50
+
51
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
+ def modify_logits_for_top_p_filtering(logits, top_p):
54
+ """Set the logits for none top-p values to -inf. Done in-place."""
55
+ if top_p <= 0.0 or top_p >= 1.0:
56
+ return
57
+ # First sort and calculate cumulative sum of probabilities.
58
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
61
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
+ # scatter sorted tensors to original indexing
63
+ indices_to_remove = sorted_indices_to_remove.scatter(
64
+ 1, sorted_indices, sorted_indices_to_remove
65
+ )
66
+ logits.masked_fill_(indices_to_remove, float("-inf"))
67
+
68
+
69
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
+ logits: (batch_size, vocab_size)
72
+ prev_output_tokens: (batch_size, seq_len)
73
+ """
74
+ if repetition_penalty == 1.0:
75
+ return logits
76
+ score = torch.gather(logits, 1, prev_output_tokens)
77
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
78
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
+ logits.scatter_(1, prev_output_tokens, score)
80
+ return logits
81
+
82
+
83
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
+ """Sample from top-k logits.
85
+ Arguments:
86
+ logits: Tensor of shape (batch_size, vocab_size)
87
+ """
88
+ if top_k == 1: # Short-circuit for greedy decoding
89
+ return logits.argmax(dim=-1)
90
+ else:
91
+ if top_p > 0.0:
92
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
93
+ if top_k > 0:
94
+ top_k = min(top_k, logits.size(-1)) # Safety check
95
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
+ if temperature != 1.0:
97
+ logits_top /= temperature
98
+ modify_logits_for_top_p_filtering(logits_top, top_p)
99
+ return indices[
100
+ torch.arange(indices.shape[0], device=indices.device),
101
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
+ ]
103
+ else:
104
+ if min_p > 0.0:
105
+ logits_top = logits.clone()
106
+ max_prob = logits_top[..., 0].item()
107
+ min_prob = max_prob * min_p
108
+ modify_logits_for_min_p_filtering(logits_top, min_p)
109
+ if temperature != 1.0:
110
+ logits_top /= temperature
111
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
+ # Clone so that when we modify for top_p we don't change the original logits
113
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
+ modify_logits_for_top_p_filtering(logits_top, top_p)
115
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
+ dim=-1
117
+ )
118
+
119
+
120
+ @torch.inference_mode()
121
+ def decode(
122
+ input_ids,
123
+ model,
124
+ max_length,
125
+ top_k=1,
126
+ top_p=0.0,
127
+ min_p=0.0,
128
+ temperature=1.0,
129
+ repetition_penalty=1.0,
130
+ eos_token_id=None,
131
+ teacher_outputs=None,
132
+ vocab_size=None,
133
+ cg=False,
134
+ enable_timing=False,
135
+ streamer: Optional[TextStreamer] = None
136
+ ):
137
+ """Decoding, either greedy or with top-k or top-p sampling.
138
+ If top-k = 0, don't limit the number of candidates (pure sampling).
139
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
140
+ then top-p.
141
+ We assume that all sequences in the same batch have the same length.
142
+
143
+ Arguments:
144
+ input_ids: (batch, seq_len)
145
+ max_length: int
146
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
147
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
148
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
149
+ sequences: (batch, max_length)
150
+ scores: tuples of (batch, vocab_size)
151
+ """
152
+ if streamer is not None:
153
+ streamer.put(input_ids.cpu())
154
+
155
+ batch_size, seqlen_og = input_ids.shape
156
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
157
+ if cg:
158
+ if not hasattr(model, "_decoding_cache"):
159
+ model._decoding_cache = None
160
+ model._decoding_cache = update_graph_cache(
161
+ model,
162
+ model._decoding_cache,
163
+ batch_size,
164
+ seqlen_og,
165
+ max_length,
166
+ )
167
+ inference_params = model._decoding_cache.inference_params
168
+ inference_params.reset(max_length, batch_size)
169
+ else:
170
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
171
+
172
+ def get_logits(input_ids, inference_params):
173
+ decoding = inference_params.seqlen_offset > 0
174
+ if decoding:
175
+ position_ids = torch.full(
176
+ (batch_size, 1),
177
+ inference_params.seqlen_offset,
178
+ dtype=torch.long,
179
+ device=input_ids.device,
180
+ )
181
+ else:
182
+ position_ids = None
183
+ if not cg or not decoding:
184
+ logits = model(
185
+ input_ids,
186
+ position_ids=position_ids,
187
+ inference_params=inference_params,
188
+ num_last_tokens=1,
189
+ ).logits.squeeze(dim=1)
190
+ else:
191
+ logits = model._decoding_cache.run(
192
+ input_ids, position_ids, inference_params.seqlen_offset
193
+ ).squeeze(dim=1)
194
+ return logits[..., :vocab_size] if vocab_size is not None else logits
195
+
196
+ def sample_tokens(logits, inference_params):
197
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
198
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
199
+ else:
200
+ token = teacher_outputs[:, inference_params.seqlen_offset]
201
+ # return rearrange(token, "b -> b 1")
202
+ return token.unsqueeze(1)
203
+
204
+ def should_stop(current_token, inference_params):
205
+ if inference_params.seqlen_offset == 0:
206
+ return False
207
+ if eos_token_id is not None and (current_token == eos_token_id).all():
208
+ return True
209
+ if inference_params.seqlen_offset >= max_length - 1:
210
+ return True
211
+ return False
212
+
213
+ start = torch.cuda.Event(enable_timing=enable_timing)
214
+ end = torch.cuda.Event(enable_timing=enable_timing)
215
+
216
+ if enable_timing:
217
+ start.record()
218
+ scores, sequences = [], [input_ids]
219
+ sequences_cat = input_ids
220
+ while not should_stop(sequences[-1], inference_params):
221
+ scores.append(get_logits(sequences[-1], inference_params))
222
+ inference_params.seqlen_offset += sequences[-1].shape[1]
223
+ if repetition_penalty == 1.0:
224
+ sampled_tokens = sample_tokens(scores[-1], inference_params)
225
+ else:
226
+ logits = modify_logit_for_repetition_penalty(
227
+ scores[-1].clone(), sequences_cat, repetition_penalty
228
+ )
229
+ sampled_tokens = sample_tokens(logits, inference_params)
230
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
231
+ sequences.append(sampled_tokens)
232
+ if streamer is not None:
233
+ streamer.put(sampled_tokens.cpu())
234
+ if streamer is not None:
235
+ streamer.end()
236
+ if enable_timing:
237
+ end.record()
238
+ torch.cuda.synchronize()
239
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
240
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
241
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
242
+
243
+
244
+ class GenerationMixin:
245
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
246
+ raise NotImplementedError
247
+
248
+ def generate(
249
+ self,
250
+ input_ids,
251
+ max_length,
252
+ top_k=1,
253
+ top_p=0.0,
254
+ min_p=0.0,
255
+ temperature=1.0,
256
+ return_dict_in_generate=False,
257
+ output_scores=False,
258
+ **kwargs,
259
+ ):
260
+ output = decode(
261
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
262
+ )
263
+ if not output_scores:
264
+ output.scores = None
265
+ return output if return_dict_in_generate else output.sequences
266
+
267
+
268
+ @dataclass
269
+ class DecodingCGCache:
270
+ max_batch_size: int = 0
271
+ max_seqlen: int = 0
272
+ device = None
273
+ dtype = None
274
+ callables: dict = field(default_factory=dict)
275
+ mempool = None
276
+ inference_params: Optional[InferenceParams] = None
277
+ run: Optional[Callable] = None
278
+
279
+
280
+ @torch.inference_mode()
281
+ def update_graph_cache(
282
+ model,
283
+ cache,
284
+ batch_size,
285
+ seqlen_og,
286
+ max_seqlen,
287
+ decoding_seqlens=(1,),
288
+ dtype=None,
289
+ n_warmups=2,
290
+ ):
291
+ if cache is None:
292
+ cache = DecodingCGCache()
293
+ param_example = next(iter(model.parameters()))
294
+ device = param_example.device
295
+ if dtype is None:
296
+ dtype = param_example.dtype
297
+ if (
298
+ (device, dtype) != (cache.device, cache.dtype)
299
+ or batch_size > cache.max_batch_size
300
+ or max_seqlen > cache.max_seqlen
301
+ ): # Invalidate the cache
302
+ cache.callables = {}
303
+ cache.mempool = None
304
+ cache.inference_params = None
305
+ gc.collect()
306
+ cache.device, cache.dtype = device, dtype
307
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
308
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
309
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
310
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
311
+ cache.inference_params = InferenceParams(
312
+ max_seqlen=max_seqlen,
313
+ max_batch_size=batch_size,
314
+ seqlen_offset=seqlen_og,
315
+ key_value_memory_dict=inf_cache,
316
+ lengths_per_sample=lengths_per_sample,
317
+ )
318
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
319
+ for decoding_seqlen in decoding_seqlens:
320
+ if (batch_size, decoding_seqlen) not in cache.callables:
321
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
322
+ model,
323
+ cache.inference_params,
324
+ batch_size,
325
+ max_seqlen,
326
+ decoding_seqlen=decoding_seqlen,
327
+ mempool=cache.mempool,
328
+ n_warmups=n_warmups,
329
+ )
330
+
331
+ def dispatch(input_ids, position_ids, seqlen):
332
+ batch_size, decoding_seqlen = input_ids.shape[:2]
333
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
334
+
335
+ cache.run = dispatch
336
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
337
+ return cache
338
+
339
+
340
+ def capture_graph(
341
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
342
+ ):
343
+ device = next(iter(model.parameters())).device
344
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
345
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
346
+ seqlen_offset_og = inference_params.seqlen_offset
347
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
348
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
349
+
350
+ # Warmup before capture
351
+ s = torch.cuda.Stream()
352
+ s.wait_stream(torch.cuda.current_stream())
353
+ with torch.cuda.stream(s):
354
+ for _ in range(n_warmups):
355
+ logits = model(
356
+ input_ids,
357
+ position_ids=position_ids,
358
+ inference_params=inference_params,
359
+ num_last_tokens=decoding_seqlen,
360
+ ).logits
361
+ s.synchronize()
362
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
363
+ # which requires that graph launch and non-captured launch to not overlap (I think,
364
+ # that's how I interpret the documentation). I'm not sure if this is required.
365
+ if torch.distributed.is_initialized():
366
+ torch.distributed.barrier()
367
+ torch.cuda.current_stream().wait_stream(s)
368
+ # Captures the graph
369
+ # To allow capture, automatically sets a side stream as the current stream in the context
370
+ graph = torch.cuda.CUDAGraph()
371
+ with torch.cuda.graph(graph, pool=mempool):
372
+ logits = model(
373
+ input_ids,
374
+ position_ids=position_ids,
375
+ inference_params=inference_params,
376
+ num_last_tokens=decoding_seqlen,
377
+ ).logits
378
+
379
+ def run(new_input_ids, new_position_ids, seqlen):
380
+ inference_params.lengths_per_sample[:] = seqlen
381
+ input_ids.copy_(new_input_ids)
382
+ position_ids.copy_(new_position_ids)
383
+ graph.replay()
384
+ return logits.clone()
385
+
386
+ inference_params.seqlen_offset = seqlen_offset_og
387
+ return run
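
Usage sketch: the `sample` helper above is self-contained and CPU-friendly, which makes it a convenient way to see how `top_k`, `top_p`, and `temperature` interact before wiring up `decode` or `GenerationMixin.generate`.

```python
import torch
from mamba_ssm.utils.generation import sample

logits = torch.randn(4, 50_000)                   # (batch, vocab_size)

greedy = sample(logits, top_k=1)                  # plain argmax decoding
top_k = sample(logits, top_k=50, temperature=0.8) # temperature-scaled top-k sampling
nucleus = sample(logits, top_k=0, top_p=0.9)      # pure top-p (nucleus) sampling
print(greedy.shape, top_k.shape, nucleus.shape)   # each (4,)
```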
mamba_ssm/utils/hf.py ADDED
@@ -0,0 +1,23 @@
+ import json
+
+ import torch
+
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
+ from transformers.utils.hub import cached_file
+
+
+ def load_config_hf(model_name):
+     resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
+     return json.load(open(resolved_archive_file))
+
+
+ def load_state_dict_hf(model_name, device=None, dtype=None):
+     # If not fp32, then we don't want to load directly to the GPU
+     mapped_device = "cpu" if dtype not in [torch.float32, None] else device
+     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
+     state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
+     # Convert dtype before moving to GPU to save memory
+     if dtype is not None:
+         state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+     state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
+     return state_dict
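
Usage sketch for the two helpers above; the checkpoint name is only illustrative, and the calls download from the Hugging Face Hub.

```python
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

repo = "state-spaces/mamba-130m"               # illustrative checkpoint name
config = load_config_hf(repo)                  # plain dict parsed from config.json
state_dict = load_state_dict_hf(repo, device="cpu", dtype=None)
print(config.get("d_model"), len(state_dict))  # e.g. hidden size and number of tensors
```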
models/codec_module.py ADDED
@@ -0,0 +1,183 @@
1
+ # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from .lsigmoid import LearnableSigmoid2D
7
+
8
+ def get_padding(kernel_size, dilation=1):
9
+ """
10
+ Calculate the padding size for a convolutional layer.
11
+
12
+ Args:
13
+ - kernel_size (int): Size of the convolutional kernel.
14
+ - dilation (int, optional): Dilation rate of the convolution. Defaults to 1.
15
+
16
+ Returns:
17
+ - int: Calculated padding size.
18
+ """
19
+ return int((kernel_size * dilation - dilation) / 2)
20
+
21
+ def get_padding_2d(kernel_size, dilation=(1, 1)):
22
+ """
23
+ Calculate the padding size for a 2D convolutional layer.
24
+
25
+ Args:
26
+ - kernel_size (tuple): Size of the convolutional kernel (height, width).
27
+ - dilation (tuple, optional): Dilation rate of the convolution (height, width). Defaults to (1, 1).
28
+
29
+ Returns:
30
+ - tuple: Calculated padding size (height, width).
31
+ """
32
+ return (int((kernel_size[0] * dilation[0] - dilation[0]) / 2),
33
+ int((kernel_size[1] * dilation[1] - dilation[1]) / 2))
34
+
35
+ class DenseBlock(nn.Module):
36
+ """
37
+ DenseBlock module consisting of multiple convolutional layers with dilation.
38
+ """
39
+ def __init__(self, cfg, kernel_size=(3, 3), depth=4):
40
+ super(DenseBlock, self).__init__()
41
+ self.cfg = cfg
42
+ self.depth = depth
43
+ self.dense_block = nn.ModuleList()
44
+ self.hid_feature = cfg['model_cfg']['hid_feature']
45
+
46
+ for i in range(depth):
47
+ dil = 2 ** i
48
+ dense_conv = nn.Sequential(
49
+ nn.Conv2d(self.hid_feature * (i + 1), self.hid_feature, kernel_size,
50
+ dilation=(dil, 1), padding=get_padding_2d(kernel_size, (dil, 1))),
51
+ nn.InstanceNorm2d(self.hid_feature, affine=True),
52
+ nn.PReLU(self.hid_feature)
53
+ )
54
+ self.dense_block.append(dense_conv)
55
+
56
+ def forward(self, x):
57
+ """
58
+ Forward pass for the DenseBlock module.
59
+
60
+ Args:
61
+ - x (torch.Tensor): Input tensor.
62
+
63
+ Returns:
64
+ - torch.Tensor: Output tensor after processing through the dense block.
65
+ """
66
+ skip = x
67
+ for i in range(self.depth):
68
+ x = self.dense_block[i](skip)
69
+ skip = torch.cat([x, skip], dim=1)
70
+ return x
71
+
72
+ class DenseEncoder(nn.Module):
73
+ """
74
+ DenseEncoder module consisting of initial convolution, dense block, and a final convolution.
75
+ """
76
+ def __init__(self, cfg):
77
+ super(DenseEncoder, self).__init__()
78
+ self.cfg = cfg
79
+ self.input_channel = cfg['model_cfg']['input_channel']
80
+ self.hid_feature = cfg['model_cfg']['hid_feature']
81
+
82
+ self.dense_conv_1 = nn.Sequential(
83
+ nn.Conv2d(self.input_channel, self.hid_feature, (1, 1)),
84
+ nn.InstanceNorm2d(self.hid_feature, affine=True),
85
+ nn.PReLU(self.hid_feature)
86
+ )
87
+
88
+ self.dense_block = DenseBlock(cfg, depth=4)
89
+
90
+ self.dense_conv_2 = nn.Sequential(
91
+ nn.Conv2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
92
+ nn.InstanceNorm2d(self.hid_feature, affine=True),
93
+ nn.PReLU(self.hid_feature)
94
+ )
95
+
96
+ def forward(self, x):
97
+ """
98
+ Forward pass for the DenseEncoder module.
99
+
100
+ Args:
101
+ - x (torch.Tensor): Input tensor.
102
+
103
+ Returns:
104
+ - torch.Tensor: Encoded tensor.
105
+ """
106
+ x = self.dense_conv_1(x) # [batch, hid_feature, time, freq]
107
+ x = self.dense_block(x) # [batch, hid_feature, time, freq]
108
+ x = self.dense_conv_2(x) # [batch, hid_feature, time, freq//2]
109
+ return x
110
+
111
+ class MagDecoder(nn.Module):
112
+ """
113
+ MagDecoder module for decoding magnitude information.
114
+ """
115
+ def __init__(self, cfg):
116
+ super(MagDecoder, self).__init__()
117
+ self.dense_block = DenseBlock(cfg, depth=4)
118
+ self.hid_feature = cfg['model_cfg']['hid_feature']
119
+ self.output_channel = cfg['model_cfg']['output_channel']
120
+ self.n_fft = cfg['stft_cfg']['n_fft']
121
+ self.beta = cfg['model_cfg']['beta']
122
+
123
+ self.mask_conv = nn.Sequential(
124
+ nn.ConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
125
+ nn.Conv2d(self.hid_feature, self.output_channel, (1, 1)),
126
+ nn.InstanceNorm2d(self.output_channel, affine=True),
127
+ nn.PReLU(self.output_channel),
128
+ nn.Conv2d(self.output_channel, self.output_channel, (1, 1))
129
+ )
130
+ self.lsigmoid = LearnableSigmoid2D(self.n_fft // 2 + 1, beta=self.beta)
131
+
132
+ def forward(self, x):
133
+ """
134
+ Forward pass for the MagDecoder module.
135
+
136
+ Args:
137
+ - x (torch.Tensor): Input tensor.
138
+
139
+ Returns:
140
+ - torch.Tensor: Decoded tensor with magnitude information.
141
+ """
142
+ x = self.dense_block(x)
143
+ x = self.mask_conv(x)
144
+ x = rearrange(x, 'b c t f -> b f t c').squeeze(-1)
145
+ x = self.lsigmoid(x)
146
+ x = rearrange(x, 'b f t -> b t f').unsqueeze(1)
147
+ return x
148
+
149
+ class PhaseDecoder(nn.Module):
150
+ """
151
+ PhaseDecoder module for decoding phase information.
152
+ """
153
+ def __init__(self, cfg):
154
+ super(PhaseDecoder, self).__init__()
155
+ self.dense_block = DenseBlock(cfg, depth=4)
156
+ self.hid_feature = cfg['model_cfg']['hid_feature']
157
+ self.output_channel = cfg['model_cfg']['output_channel']
158
+
159
+ self.phase_conv = nn.Sequential(
160
+ nn.ConvTranspose2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2)),
161
+ nn.InstanceNorm2d(self.hid_feature, affine=True),
162
+ nn.PReLU(self.hid_feature)
163
+ )
164
+
165
+ self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
166
+ self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1))
167
+
168
+ def forward(self, x):
169
+ """
170
+ Forward pass for the PhaseDecoder module.
171
+
172
+ Args:
173
+ - x (torch.Tensor): Input tensor.
174
+
175
+ Returns:
176
+ - torch.Tensor: Decoded tensor with phase information.
177
+ """
178
+ x = self.dense_block(x)
179
+ x = self.phase_conv(x)
180
+ x_r = self.phase_conv_r(x)
181
+ x_i = self.phase_conv_i(x)
182
+ x = torch.atan2(x_i, x_r)
183
+ return x
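
Shape sketch for the encoder/decoder trio above. The config values are illustrative only (the real ones come from this repo's YAML config), and `models/lsigmoid.py` is assumed to be importable from the repo root.

```python
import torch
from models.codec_module import DenseEncoder, MagDecoder, PhaseDecoder

cfg = {
    "model_cfg": {"input_channel": 2, "hid_feature": 32, "output_channel": 1, "beta": 2.0},
    "stft_cfg": {"n_fft": 400},
}
enc, mag_dec, pha_dec = DenseEncoder(cfg), MagDecoder(cfg), PhaseDecoder(cfg)

x = torch.randn(1, 2, 50, 201)   # [B, mag+phase, T, F] with F = n_fft // 2 + 1
h = enc(x)                       # [1, 32, 50, 100]  frequency axis halved by the strided conv
mask = mag_dec(h)                # [1, 1, 50, 201]   bounded magnitude mask (learnable sigmoid)
phase = pha_dec(h)               # [1, 1, 50, 201]   wrapped phase from atan2
print(h.shape, mask.shape, phase.shape)
```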
models/discriminator.py ADDED
@@ -0,0 +1,56 @@
1
+ # References: https://github.com/yxlu-0102/MP-SENet/blob/main/models/discriminator.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from pesq import pesq
7
+ from joblib import Parallel, delayed
8
+ from models.lsigmoid import LearnableSigmoid1D
9
+
10
+ def pesq_loss(clean, noisy, sr=16000):
11
+ try:
12
+ pesq_score = pesq(sr, clean, noisy, 'wb')
13
+ except:
14
+ # error can happen due to silent period
15
+ pesq_score = -1
16
+ return pesq_score
17
+
18
+
19
+ def batch_pesq(clean, noisy, cfg):
20
+ num_worker = cfg['env_setting']['num_workers']
21
+ pesq_score = Parallel(n_jobs=num_worker)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
22
+ pesq_score = np.array(pesq_score)
23
+ if -1 in pesq_score:
24
+ return None
25
+ pesq_score = (pesq_score - 1) / 3.5
26
+ return torch.FloatTensor(pesq_score)
27
+
28
+
29
+ class MetricDiscriminator(nn.Module):
30
+ def __init__(self, dim=16, in_channel=2):
31
+ super(MetricDiscriminator, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
34
+ nn.InstanceNorm2d(dim, affine=True),
35
+ nn.PReLU(dim),
36
+ nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
37
+ nn.InstanceNorm2d(dim*2, affine=True),
38
+ nn.PReLU(dim*2),
39
+ nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
40
+ nn.InstanceNorm2d(dim*4, affine=True),
41
+ nn.PReLU(dim*4),
42
+ nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
43
+ nn.InstanceNorm2d(dim*8, affine=True),
44
+ nn.PReLU(dim*8),
45
+ nn.AdaptiveMaxPool2d(1),
46
+ nn.Flatten(),
47
+ nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
48
+ nn.Dropout(0.3),
49
+ nn.PReLU(dim*4),
50
+ nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
51
+ LearnableSigmoid1D(1)
52
+ )
53
+
54
+ def forward(self, x, y):
55
+ xy = torch.stack((x, y), dim=1)
56
+ return self.layers(xy)
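
A brief, hedged usage sketch (not part of this commit): `batch_pesq` turns per-utterance PESQ scores into 0-1 training targets, and `MetricDiscriminator` predicts that score from a pair of magnitude spectrograms. The variable names, shapes, and the minimal config dict below are illustrative assumptions.

```python
# Hedged sketch: data and shapes are placeholders, not values from this repo.
import numpy as np
import torch
from models.discriminator import MetricDiscriminator, batch_pesq

cfg = {'env_setting': {'num_workers': 2}}      # assumed minimal config
disc = MetricDiscriminator()

# Real training passes actual clean/enhanced waveforms; PESQ may fail on
# synthetic noise like this, in which case batch_pesq returns None.
clean = [np.random.randn(16000).astype(np.float32) for _ in range(2)]
enhanced = [c + 0.01 * np.random.randn(16000).astype(np.float32) for c in clean]
targets = batch_pesq(clean, enhanced, cfg)     # FloatTensor in [0, 1], or None

clean_mag = torch.rand(2, 201, 321)            # [B, F, T] magnitude placeholders
enhanced_mag = torch.rand(2, 201, 321)
pred = disc(clean_mag, enhanced_mag)           # [B, 1] predicted quality score
```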
models/generator.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from .mamba_block import TFMambaBlock
+ from .codec_module import DenseEncoder, MagDecoder, PhaseDecoder
+ 
+ class SEMamba(nn.Module):
+     """
+     SEMamba model for speech enhancement using Mamba blocks.
+ 
+     This model uses a dense encoder, multiple Mamba blocks, and separate magnitude
+     and phase decoders to process noisy magnitude and phase inputs.
+     """
+     def __init__(self, cfg):
+         """
+         Initialize the SEMamba model.
+ 
+         Args:
+         - cfg: Configuration object containing model parameters.
+         """
+         super(SEMamba, self).__init__()
+         self.cfg = cfg
+         self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4  # default tfmamba: 4
+ 
+         # Initialize dense encoder
+         self.dense_encoder = DenseEncoder(cfg)
+ 
+         # Initialize Mamba blocks
+         self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])
+ 
+         # Initialize decoders
+         self.mask_decoder = MagDecoder(cfg)
+         self.phase_decoder = PhaseDecoder(cfg)
+ 
+     def forward(self, noisy_mag, noisy_pha):
+         """
+         Forward pass for the SEMamba model.
+ 
+         Args:
+         - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T].
+         - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T].
+ 
+         Returns:
+         - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T].
+         - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T].
+         - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2].
+         """
+         # Reshape inputs
+         noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]
+         noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]
+ 
+         # Concatenate magnitude and phase inputs
+         x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]
+ 
+         # Encode input
+         x = self.dense_encoder(x)
+ 
+         # Apply Mamba blocks
+         for block in self.TSMamba:
+             x = block(x)
+ 
+         # Decode magnitude and phase
+         denoised_mag = rearrange(self.mask_decoder(x) * noisy_mag, 'b c t f -> b f t c').squeeze(-1)
+         denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)
+ 
+         # Combine denoised magnitude and phase into a complex representation
+         denoised_com = torch.stack(
+             (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
+             dim=-1
+         )
+ 
+         return denoised_mag, denoised_pha, denoised_com
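
A minimal, hedged inference sketch showing how the generator ties into the STFT helpers and the recipe. The recipe path, the dummy waveform, and the absence of checkpoint loading are assumptions for illustration; the Mamba kernels used by this repo require a CUDA device.

```python
import torch
import yaml
from models.generator import SEMamba
from models.stfts import mag_phase_stft, mag_phase_istft

with open("recipes/SEMamba_advanced.yaml") as f:   # assumed recipe path
    cfg = yaml.safe_load(f)
stft = cfg['stft_cfg']
comp = cfg['model_cfg']['compress_factor']

device = "cuda"                                    # mamba_ssm kernels are CUDA-only
model = SEMamba(cfg).to(device).eval()
# In practice, trained weights would be loaded here from a checkpoint.

wav = torch.randn(1, 32000, device=device)         # placeholder 2 s of 16 kHz audio
mag, pha, _ = mag_phase_stft(wav, stft['n_fft'], stft['hop_size'], stft['win_size'], comp)
with torch.no_grad():
    mag_g, pha_g, _ = model(mag, pha)
enhanced = mag_phase_istft(mag_g, pha_g, stft['n_fft'], stft['hop_size'], stft['win_size'], comp)
```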
models/loss.py ADDED
@@ -0,0 +1,145 @@
+ # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py
+ 
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from pesq import pesq
+ from joblib import Parallel, delayed
+ 
+ def phase_losses(phase_r, phase_g, cfg):
+     """
+     Calculate phase losses including in-phase loss, gradient delay loss,
+     and integrated absolute frequency loss between reference and generated phases.
+ 
+     Args:
+         phase_r (torch.Tensor): Reference phase tensor of shape (batch, freq, time).
+         phase_g (torch.Tensor): Generated phase tensor of shape (batch, freq, time).
+         cfg (dict): Configuration dictionary containing parameters like n_fft.
+ 
+     Returns:
+         tuple: Tuple containing in-phase loss, gradient delay loss, and integrated absolute frequency loss.
+     """
+     dim_freq = cfg['stft_cfg']['n_fft'] // 2 + 1  # Frequency dimension
+     dim_time = phase_r.size(-1)                   # Time dimension
+ 
+     # Construct gradient delay matrix
+     gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) -
+                  torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) -
+                  torch.eye(dim_freq)).to(phase_g.device)
+ 
+     # Apply gradient delay matrix to reference and generated phases
+     gd_r = torch.matmul(phase_r.permute(0, 2, 1), gd_matrix)
+     gd_g = torch.matmul(phase_g.permute(0, 2, 1), gd_matrix)
+ 
+     # Construct integrated absolute frequency matrix
+     iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) -
+                   torch.triu(torch.ones(dim_time, dim_time), diagonal=2) -
+                   torch.eye(dim_time)).to(phase_g.device)
+ 
+     # Apply integrated absolute frequency matrix to reference and generated phases
+     iaf_r = torch.matmul(phase_r, iaf_matrix)
+     iaf_g = torch.matmul(phase_g, iaf_matrix)
+ 
+     # Calculate losses
+     ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
+     gd_loss = torch.mean(anti_wrapping_function(gd_r - gd_g))
+     iaf_loss = torch.mean(anti_wrapping_function(iaf_r - iaf_g))
+ 
+     return ip_loss, gd_loss, iaf_loss
+ 
+ def anti_wrapping_function(x):
+     """
+     Anti-wrapping function to adjust phase values within the range of -pi to pi.
+ 
+     Args:
+         x (torch.Tensor): Input tensor representing phase differences.
+ 
+     Returns:
+         torch.Tensor: Adjusted tensor with phase values wrapped within -pi to pi.
+     """
+     return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
+ 
+ def compute_stft(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int, center: bool, compress_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Compute the Short-Time Fourier Transform (STFT) and return magnitude, phase, and complex components.
+ 
+     Args:
+         y (torch.Tensor): Input signal tensor.
+         n_fft (int): Number of FFT points.
+         hop_size (int): Hop size for STFT.
+         win_size (int): Window size for STFT.
+         center (bool): Whether to pad the input on both sides.
+         compress_factor (float, optional): Compression factor for magnitude. Defaults to 1.0.
+ 
+     Returns:
+         tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Magnitude, phase, and complex components.
+     """
+     eps = torch.finfo(y.dtype).eps
+     hann_window = torch.hann_window(win_size).to(y.device)
+ 
+     stft_spec = torch.stft(
+         y,
+         n_fft=n_fft,
+         hop_length=hop_size,
+         win_length=win_size,
+         window=hann_window,
+         center=center,
+         pad_mode='reflect',
+         normalized=False,
+         return_complex=True
+     )
+ 
+     real_part = stft_spec.real
+     imag_part = stft_spec.imag
+ 
+     mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
+     pha = torch.atan2(imag_part + eps, real_part + eps)
+ 
+     mag = torch.pow(mag, compress_factor)
+     com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
+ 
+     return mag, pha, com
+ 
+ def pesq_score(utts_r, utts_g, cfg):
+     """
+     Calculate PESQ (Perceptual Evaluation of Speech Quality) score for pairs of reference and generated utterances.
+ 
+     Args:
+         utts_r (list of torch.Tensor): List of reference utterances.
+         utts_g (list of torch.Tensor): List of generated utterances.
+         cfg (dict): Configuration dictionary containing parameters like sampling_rate.
+ 
+     Returns:
+         float: Mean PESQ score across all pairs of utterances.
+     """
+     def eval_pesq(clean_utt, esti_utt, sr):
+         """
+         Evaluate PESQ score for a single pair of clean and estimated utterances.
+ 
+         Args:
+             clean_utt (np.ndarray): Clean reference utterance.
+             esti_utt (np.ndarray): Estimated generated utterance.
+             sr (int): Sampling rate.
+ 
+         Returns:
+             float: PESQ score or -1 in case of an error.
+         """
+         try:
+             pesq_score = pesq(sr, clean_utt, esti_utt)
+         except Exception as e:
+             # Error can happen due to silent period or other issues
+             print(f"Error computing PESQ score: {e}")
+             pesq_score = -1
+         return pesq_score
+ 
+     # Parallel processing of PESQ score computation
+     pesq_scores = Parallel(n_jobs=30)(delayed(eval_pesq)(
+         utts_r[i].squeeze().cpu().numpy(),
+         utts_g[i].squeeze().cpu().numpy(),
+         cfg['stft_cfg']['sampling_rate']
+     ) for i in range(len(utts_r)))
+ 
+     # Calculate mean PESQ score
+     pesq_score = np.mean(pesq_scores)
+     return pesq_score
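
A small, hedged sanity-check sketch for `phase_losses`: tensors follow the (batch, freq, time) convention, and the random phases plus the minimal config dict are placeholders rather than values from this commit.

```python
import math
import torch
from models.loss import phase_losses

cfg = {'stft_cfg': {'n_fft': 400}}             # assumed minimal config
freq_bins = cfg['stft_cfg']['n_fft'] // 2 + 1  # 201 bins for n_fft=400

phase_ref = (torch.rand(2, freq_bins, 100) * 2 - 1) * math.pi
phase_gen = (torch.rand(2, freq_bins, 100) * 2 - 1) * math.pi

ip_loss, gd_loss, iaf_loss = phase_losses(phase_ref, phase_gen, cfg)
# During training these terms are weighted by training_cfg.loss before summing.
total_phase_loss = ip_loss + gd_loss + iaf_loss
```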
models/lsigmoid.py ADDED
@@ -0,0 +1,66 @@
+ # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/utils.py
+ 
+ import torch
+ import torch.nn as nn
+ 
+ class LearnableSigmoid1D(nn.Module):
+     """
+     Learnable Sigmoid Activation Function for 1D inputs.
+ 
+     This module applies a learnable slope parameter to the sigmoid activation function.
+     """
+     def __init__(self, in_features, beta=1):
+         """
+         Initialize the LearnableSigmoid1D module.
+ 
+         Args:
+         - in_features (int): Number of input features.
+         - beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1.
+         """
+         super(LearnableSigmoid1D, self).__init__()
+         self.beta = beta
+         self.slope = nn.Parameter(torch.ones(in_features))
+         self.slope.requires_grad = True
+ 
+     def forward(self, x):
+         """
+         Forward pass for the LearnableSigmoid1D module.
+ 
+         Args:
+         - x (torch.Tensor): Input tensor.
+ 
+         Returns:
+         - torch.Tensor: Output tensor after applying the learnable sigmoid activation.
+         """
+         return self.beta * torch.sigmoid(self.slope * x)
+ 
+ class LearnableSigmoid2D(nn.Module):
+     """
+     Learnable Sigmoid Activation Function for 2D inputs.
+ 
+     This module applies a learnable slope parameter to the sigmoid activation function for 2D inputs.
+     """
+     def __init__(self, in_features, beta=1):
+         """
+         Initialize the LearnableSigmoid2D module.
+ 
+         Args:
+         - in_features (int): Number of input features.
+         - beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1.
+         """
+         super(LearnableSigmoid2D, self).__init__()
+         self.beta = beta
+         self.slope = nn.Parameter(torch.ones(in_features, 1))
+         self.slope.requires_grad = True
+ 
+     def forward(self, x):
+         """
+         Forward pass for the LearnableSigmoid2D module.
+ 
+         Args:
+         - x (torch.Tensor): Input tensor.
+ 
+         Returns:
+         - torch.Tensor: Output tensor after applying the learnable sigmoid activation.
+         """
+         return self.beta * torch.sigmoid(self.slope * x)
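
For reference, a tiny shape sketch (values are placeholders): the 2D variant keeps one learnable slope per frequency bin, and `beta` sets the upper bound of the output, which is why the mask decoder above can exceed 1.0 with the recipe's `beta: 2.0`.

```python
import torch
from models.lsigmoid import LearnableSigmoid2D

act = LearnableSigmoid2D(201, beta=2.0)   # 201 bins for n_fft=400; beta from the recipe
mask = act(torch.randn(4, 201, 321))      # [B, F, T] -> values in (0, beta)
```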
models/mamba_block.py ADDED
@@ -0,0 +1,110 @@
+ # Reference: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ from torch.nn.parameter import Parameter
+ from functools import partial
+ from einops import rearrange
+ 
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
+ from mamba_ssm.models.mixer_seq_simple import _init_weights
+ from mamba_ssm.ops.triton.layernorm import RMSNorm
+ 
+ # github: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py
+ def create_block(
+     d_model, cfg, layer_idx=0, rms_norm=True, fused_add_norm=False, residual_in_fp32=False,
+ ):
+     d_state = cfg['model_cfg']['d_state']            # 16
+     d_conv = cfg['model_cfg']['d_conv']              # 4
+     expand = cfg['model_cfg']['expand']              # 4
+     norm_epsilon = cfg['model_cfg']['norm_epsilon']  # 0.00001
+ 
+     mixer_cls = partial(Mamba, layer_idx=layer_idx, d_state=d_state, d_conv=d_conv, expand=expand)
+     norm_cls = partial(
+         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon
+     )
+     block = Block(
+         d_model,
+         mixer_cls,
+         norm_cls=norm_cls,
+         fused_add_norm=fused_add_norm,
+         residual_in_fp32=residual_in_fp32,
+     )
+     block.layer_idx = layer_idx
+     return block
+ 
+ class MambaBlock(nn.Module):
+     def __init__(self, in_channels, cfg):
+         super(MambaBlock, self).__init__()
+         n_layer = 1
+         self.forward_blocks = nn.ModuleList(create_block(in_channels, cfg) for i in range(n_layer))
+         self.backward_blocks = nn.ModuleList(create_block(in_channels, cfg) for i in range(n_layer))
+ 
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=n_layer,
+             )
+         )
+ 
+     def forward(self, x):
+         x_forward, x_backward = x.clone(), torch.flip(x, [1])
+         resi_forward, resi_backward = None, None
+ 
+         # Forward
+         for layer in self.forward_blocks:
+             x_forward, resi_forward = layer(x_forward, resi_forward)
+         y_forward = (x_forward + resi_forward) if resi_forward is not None else x_forward
+ 
+         # Backward
+         for layer in self.backward_blocks:
+             x_backward, resi_backward = layer(x_backward, resi_backward)
+         y_backward = torch.flip((x_backward + resi_backward), [1]) if resi_backward is not None else torch.flip(x_backward, [1])
+ 
+         return torch.cat([y_forward, y_backward], -1)
+ 
+ class TFMambaBlock(nn.Module):
+     """
+     Temporal-Frequency Mamba block for sequence modeling.
+ 
+     Attributes:
+         cfg (Config): Configuration for the block.
+         time_mamba (MambaBlock): Mamba block for temporal dimension.
+         freq_mamba (MambaBlock): Mamba block for frequency dimension.
+         tlinear (ConvTranspose1d): ConvTranspose1d layer for temporal dimension.
+         flinear (ConvTranspose1d): ConvTranspose1d layer for frequency dimension.
+     """
+     def __init__(self, cfg):
+         super(TFMambaBlock, self).__init__()
+         self.cfg = cfg
+         self.hid_feature = cfg['model_cfg']['hid_feature']
+ 
+         # Initialize Mamba blocks
+         self.time_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg)
+         self.freq_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg)
+ 
+         # Initialize ConvTranspose1d layers
+         self.tlinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1)
+         self.flinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1)
+ 
+     def forward(self, x):
+         """
+         Forward pass of the TFMamba block.
+ 
+         Parameters:
+         x (Tensor): Input tensor with shape (batch, channels, time, freq).
+ 
+         Returns:
+         Tensor: Output tensor after applying temporal and frequency Mamba blocks.
+         """
+         b, c, t, f = x.size()
+ 
+         x = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
+         x = self.tlinear(self.time_mamba(x).permute(0, 2, 1)).permute(0, 2, 1) + x
+         x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
+         x = self.flinear(self.freq_mamba(x).permute(0, 2, 1)).permute(0, 2, 1) + x
+         x = x.view(b, t, f, c).permute(0, 3, 1, 2)
+         return x
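
TFMambaBlock is shape-preserving: it scans the (B*F, T, C) view with the time Mamba and the (B*T, F, C) view with the frequency Mamba, fusing each bidirectional output back to C channels. A hedged shape check (CUDA-only because of the mamba_ssm kernels; the recipe path and the time/frequency sizes are placeholders):

```python
import torch
import yaml
from models.mamba_block import TFMambaBlock

with open("recipes/SEMamba_advanced.yaml") as f:   # assumed recipe path
    cfg = yaml.safe_load(f)

block = TFMambaBlock(cfg).cuda()                   # mamba_ssm kernels require CUDA
x = torch.randn(1, cfg['model_cfg']['hid_feature'], 321, 101, device="cuda")
y = block(x)
assert y.shape == x.shape                          # (B, hid_feature, T, F) in and out
```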
models/pcs400.py ADDED
@@ -0,0 +1,53 @@
+ import os
+ import torch
+ import torchaudio
+ import numpy as np
+ import argparse
+ import librosa
+ import scipy
+ 
+ # PCS400 parameters
+ PCS400 = np.ones(201)
+ PCS400[0:3] = 1
+ PCS400[3:5] = 1.070175439
+ PCS400[5:8] = 1.182456140
+ PCS400[8:10] = 1.287719298
+ PCS400[10:110] = 1.4  # Pre Set
+ PCS400[110:130] = 1.322807018
+ PCS400[130:160] = 1.238596491
+ PCS400[160:190] = 1.161403509
+ PCS400[190:202] = 1.077192982
+ 
+ maxv = np.iinfo(np.int16).max
+ 
+ def Sp_and_phase(signal):
+     signal_length = signal.shape[0]
+     n_fft = 400
+     hop_length = 100
+     y_pad = librosa.util.fix_length(signal, size=signal_length + n_fft // 2)
+ 
+     F = librosa.stft(y_pad, n_fft=400, hop_length=100, win_length=400, window=scipy.signal.windows.hamming(400))
+     Lp = PCS400 * np.transpose(np.log1p(np.abs(F)), (1, 0))
+     phase = np.angle(F)
+ 
+     NLp = np.transpose(Lp, (1, 0))
+ 
+     return NLp, phase, signal_length
+ 
+ 
+ def SP_to_wav(mag, phase, signal_length):
+     mag = np.expm1(mag)
+     Rec = np.multiply(mag, np.exp(1j * phase))
+     result = librosa.istft(Rec,
+                            hop_length=100,
+                            win_length=400,
+                            window=scipy.signal.windows.hamming(400),
+                            length=signal_length)
+     return result
+ 
+ def cal_pcs(signal_wav):
+     noisy_LP, Nphase, signal_length = Sp_and_phase(signal_wav.squeeze())
+     enhanced_wav = SP_to_wav(noisy_LP, Nphase, signal_length)
+     enhanced_wav = enhanced_wav / np.max(abs(enhanced_wav))
+ 
+     return enhanced_wav
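
`cal_pcs` is an offline post-processing step (Perceptual Contrast Stretching) applied to an already-enhanced 16 kHz mono waveform. A hedged sketch of typical use; the file names below are assumptions, not paths from this commit:

```python
import librosa
import soundfile as sf
from models.pcs400 import cal_pcs

wav, sr = librosa.load("enhanced.wav", sr=16000)   # assumed input: mono 16 kHz audio
out = cal_pcs(wav)                                  # contrast-stretched, peak-normalised
sf.write("enhanced_pcs.wav", out, sr)
```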
models/stfts.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ 
+ def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False):
+     """
+     Compute magnitude and phase using STFT.
+ 
+     Args:
+         y (torch.Tensor): Input audio signal.
+         n_fft (int): FFT size.
+         hop_size (int): Hop size.
+         win_size (int): Window size.
+         compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0.
+         center (bool, optional): Whether to center the signal before padding. Defaults to True.
+         addeps (bool, optional): Whether to add a small epsilon when computing magnitude and phase. Defaults to False.
+ 
+     Returns:
+         tuple: Magnitude, phase, and complex representation of the STFT.
+     """
+     # eps = torch.finfo(y.dtype).eps
+     eps = 1e-10
+     hann_window = torch.hann_window(win_size).to(y.device)
+     stft_spec = torch.stft(
+         y, n_fft,
+         hop_length=hop_size,
+         win_length=win_size,
+         window=hann_window,
+         center=center,
+         pad_mode='reflect',
+         normalized=False,
+         return_complex=True)
+ 
+     if not addeps:
+         mag = torch.abs(stft_spec)
+         pha = torch.angle(stft_spec)
+     else:
+         real_part = stft_spec.real
+         imag_part = stft_spec.imag
+         mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
+         pha = torch.atan2(imag_part + eps, real_part + eps)
+     # Compress the magnitude
+     mag = torch.pow(mag, compress_factor)
+     com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
+     return mag, pha, com
+ 
+ 
+ def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
+     """
+     Inverse STFT to reconstruct the audio signal from magnitude and phase.
+ 
+     Args:
+         mag (torch.Tensor): Magnitude of the STFT.
+         pha (torch.Tensor): Phase of the STFT.
+         n_fft (int): FFT size.
+         hop_size (int): Hop size.
+         win_size (int): Window size.
+         compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0.
+         center (bool, optional): Whether to center the signal before padding. Defaults to True.
+ 
+     Returns:
+         torch.Tensor: Reconstructed audio signal.
+     """
+     mag = torch.pow(mag, 1.0 / compress_factor)
+     com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
+     hann_window = torch.hann_window(win_size).to(com.device)
+     wav = torch.istft(
+         com,
+         n_fft,
+         hop_length=hop_size,
+         win_length=win_size,
+         window=hann_window,
+         center=center)
+     return wav
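
A quick round-trip sketch with the recipe's STFT settings (the waveform is a placeholder): the magnitude compression applied by `mag_phase_stft` is undone by `mag_phase_istft`, so the pair should reconstruct the signal up to numerical error when `center=True`.

```python
import torch
from models.stfts import mag_phase_stft, mag_phase_istft

n_fft, hop_size, win_size = 400, 100, 400      # values from recipes/SEMamba_advanced.yaml
x = torch.randn(1, 16000)                      # 1 s of placeholder 16 kHz audio

mag, pha, com = mag_phase_stft(x, n_fft, hop_size, win_size, compress_factor=0.3)
y = mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=0.3)

print((x - y).abs().max())                     # small reconstruction error
```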
recipes/SEMamba_advanced.yaml ADDED
@@ -0,0 +1,66 @@
+ # Environment Settings
+ # These settings specify the hardware and distributed setup for the model training.
+ # Adjust `num_gpus` and `dist_cfg` according to your distributed training environment.
+ env_setting:
+   num_gpus: 2                        # Number of GPUs. CPU-only training is not supported.
+   num_workers: 20                    # Number of worker threads for data loading.
+   seed: 1234                         # Seed for random number generators to ensure reproducibility.
+   stdout_interval: 10
+   checkpoint_interval: 1000          # Save the model to a checkpoint every N steps.
+   validation_interval: 1000
+   summary_interval: 100
+   dist_cfg:
+     dist_backend: nccl               # Distributed training backend, 'nccl' for NVIDIA GPUs.
+     dist_url: tcp://localhost:19477  # URL for initializing distributed training.
+     world_size: 1                    # Total number of processes in the distributed training.
+ 
+ # Datapath Configuration
+ data_cfg:
+   train_clean_json: data/train_clean.json
+   train_noisy_json: data/train_noisy.json
+   valid_clean_json: data/valid_clean.json
+   valid_noisy_json: data/valid_noisy.json
+   test_clean_json: data/test_clean.json
+   test_noisy_json: data/test_noisy.json
+ 
+ # Training Configuration
+ # This section details parameters that directly influence the training process,
+ # including batch sizes, learning rates, and optimizer specifics.
+ training_cfg:
+   training_epochs: 200               # Number of training epochs.
+   batch_size: 4                      # Training batch size.
+   learning_rate: 0.0005              # Initial learning rate.
+   adam_b1: 0.8                       # Beta1 hyperparameter for the AdamW optimizer.
+   adam_b2: 0.99                      # Beta2 hyperparameter for the AdamW optimizer.
+   lr_decay: 0.99                     # Learning rate decay per epoch.
+   segment_size: 32000                # Audio segment size used during training, dependent on sampling rate.
+   loss:
+     metric: 0.05
+     magnitude: 0.9
+     phase: 0.3
+     complex: 0.1
+     time: 0.2
+     consistancy: 0.1
+   use_PCS400: False                  # Whether to apply PCS400 post-processing.
+ 
+ # STFT Configuration
+ # Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
+ stft_cfg:
+   sampling_rate: 16000               # Audio sampling rate in Hz.
+   n_fft: 400                         # FFT components for transforming audio signals.
+   hop_size: 100                      # Samples between successive frames.
+   win_size: 400                      # Window size used in FFT.
+ 
+ # Model Configuration
+ # Defines the architecture specifics of the model, including layer configurations and feature compression.
+ model_cfg:
+   hid_feature: 64                    # Channels in dense layers.
+   compress_factor: 0.3               # Compression factor applied to extracted features.
+   num_tfmamba: 4                     # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
+   d_state: 16                        # Dimensionality of the state vector in Mamba blocks.
+   d_conv: 4                          # Convolutional layer dimensionality within Mamba blocks.
+   expand: 4                          # Expansion factor for the layers within the Mamba blocks.
+   norm_epsilon: 0.00001              # Numerical stability in normalization layers within the Mamba blocks.
+   beta: 2.0                          # Hyperparameter for the Learnable Sigmoid function.
+   input_channel: 2                   # Magnitude and phase.
+   output_channel: 1                  # Single-channel speech enhancement.
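
The recipe is consumed as a plain nested dictionary: every module above indexes it as `cfg['model_cfg'][...]`, `cfg['stft_cfg'][...]`, and so on. A minimal loading sketch (path assumed to be relative to the Space root):

```python
import yaml

with open("recipes/SEMamba_advanced.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg['stft_cfg']['n_fft'])            # 400
print(cfg['model_cfg']['hid_feature'])     # 64
print(cfg['training_cfg']['loss'])         # per-term loss weights used during training
```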
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ packaging
+ librosa
+ soundfile
+ pyyaml
+ argparse
+ tensorboard
+ pesq
+ einops
+ matplotlib
+ torch==2.5.1
+ torchaudio==2.5.1
+ numpy==1.26.4
+ ultralytics
+ moviepy
+ supervision
+ opencv-python
+ ffmpeg-python
+ decord==0.6.0
+ pytorch_lightning==1.9.0
+ typeguard==2.13.3
+ torch_complex
+ rich
yolov8n-face.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d17b38523a994b13ee604b67f02791ca0f43b9f446a32fd7bc44e17c56ead077
+ size 6250099